OSDN Git Service

[MLIR][SPIRV] Add `UnsignedOp` trait.
authorKareemErgawy-TomTom <kareem.ergawy@gmail.com>
Wed, 6 Jan 2021 13:56:53 +0000 (14:56 +0100)
committerKareemErgawy-TomTom <kareem.ergawy@gmail.com>
Wed, 6 Jan 2021 14:28:41 +0000 (15:28 +0100)
This commit adds a new trait that can be attached to ops that have
unsigned semantics.

TODO:
- Check if other places in code can use the new attribute (possibly in this patch).
- Add a similar `SignedOp` attribute (in a new patch).

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D94068

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBitOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h [new file with mode: 0644]
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h
mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp

index 0d6dd01..609f510 100644 (file)
@@ -514,7 +514,7 @@ def SPV_SRemOp : SPV_ArithmeticBinaryOp<"SRem", SPV_Integer, []> {
 
 // -----
 
-def SPV_UDivOp : SPV_ArithmeticBinaryOp<"UDiv", SPV_Integer, []> {
+def SPV_UDivOp : SPV_ArithmeticBinaryOp<"UDiv", SPV_Integer, [UnsignedOp]> {
   let summary = "Unsigned-integer division of Operand 1 divided by Operand 2.";
 
   let description = [{
@@ -546,7 +546,7 @@ def SPV_UDivOp : SPV_ArithmeticBinaryOp<"UDiv", SPV_Integer, []> {
 
 // -----
 
-def SPV_UModOp : SPV_ArithmeticBinaryOp<"UMod", SPV_Integer> {
+def SPV_UModOp : SPV_ArithmeticBinaryOp<"UMod", SPV_Integer, [UnsignedOp]> {
   let summary = "Unsigned modulo operation of Operand 1 modulo Operand 2.";
 
   let description = [{
index 1c9dbd7..289e9a2 100644 (file)
@@ -438,7 +438,7 @@ def SPV_AtomicSMinOp : SPV_AtomicUpdateWithValueOp<"AtomicSMin", []> {
 
 // -----
 
-def SPV_AtomicUMaxOp : SPV_AtomicUpdateWithValueOp<"AtomicUMax", []> {
+def SPV_AtomicUMaxOp : SPV_AtomicUpdateWithValueOp<"AtomicUMax", [UnsignedOp]> {
   let summary = [{
     Perform the following steps atomically with respect to any other atomic
     accesses within Scope to the same location:
@@ -480,7 +480,7 @@ def SPV_AtomicUMaxOp : SPV_AtomicUpdateWithValueOp<"AtomicUMax", []> {
 
 // -----
 
-def SPV_AtomicUMinOp : SPV_AtomicUpdateWithValueOp<"AtomicUMin", []> {
+def SPV_AtomicUMinOp : SPV_AtomicUpdateWithValueOp<"AtomicUMin", [UnsignedOp]> {
   let summary = [{
     Perform the following steps atomically with respect to any other atomic
     accesses within Scope to the same location:
index 2ed1101..a9603ad 100644 (file)
@@ -3115,6 +3115,8 @@ def InModuleScope : PredOpTrait<
   "op must appear in a module-like op's block",
   CPred<"isDirectInModuleLikeOp($_op.getParentOp())">>;
 
+def UnsignedOp : NativeOpTrait<"spirv::UnsignedOp">;
+
 //===----------------------------------------------------------------------===//
 // SPIR-V opcode specification
 //===----------------------------------------------------------------------===//
index 3df9798..173a031 100644 (file)
@@ -232,7 +232,8 @@ def SPV_BitFieldSExtractOp : SPV_BitFieldExtractOp<"BitFieldSExtract", []> {
 
 // -----
 
-def SPV_BitFieldUExtractOp : SPV_BitFieldExtractOp<"BitFieldUExtract", []> {
+def SPV_BitFieldUExtractOp : SPV_BitFieldExtractOp<"BitFieldUExtract",
+                                                   [UnsignedOp]> {
   let summary = "Extract a bit field from an object, without sign extension.";
 
   let description = [{
index 726f79f..20d4afd 100644 (file)
@@ -196,7 +196,10 @@ def SPV_ConvertSToFOp : SPV_CastOp<"ConvertSToF", SPV_Float, SPV_Integer, []> {
 
 // -----
 
-def SPV_ConvertUToFOp : SPV_CastOp<"ConvertUToF", SPV_Float, SPV_Integer, []> {
+def SPV_ConvertUToFOp : SPV_CastOp<"ConvertUToF",
+                                   SPV_Float,
+                                   SPV_Integer,
+                                   [UnsignedOp]> {
   let summary = [{
     Convert value numerically from unsigned integer to floating point.
   }];
@@ -298,7 +301,10 @@ def SPV_SConvertOp : SPV_CastOp<"SConvert", SPV_Integer, SPV_Integer, []> {
 
 // -----
 
-def SPV_UConvertOp : SPV_CastOp<"UConvert", SPV_Integer, SPV_Integer, []> {
+def SPV_UConvertOp : SPV_CastOp<"UConvert",
+                                SPV_Integer,
+                                SPV_Integer,
+                                [UnsignedOp]> {
   let summary = [{
     Convert unsigned width. This is either a truncate or a zero extend.
   }];
index 71254c1..46ff3d9 100644 (file)
@@ -869,7 +869,9 @@ def SPV_SelectOp : SPV_Op<"Select",
 
 // -----
 
-def SPV_UGreaterThanOp : SPV_LogicalBinaryOp<"UGreaterThan", SPV_Integer, []> {
+def SPV_UGreaterThanOp : SPV_LogicalBinaryOp<"UGreaterThan",
+                                             SPV_Integer,
+                                             [UnsignedOp]> {
   let summary = [{
     Unsigned-integer comparison if Operand 1 is greater than  Operand 2.
   }];
@@ -902,7 +904,9 @@ def SPV_UGreaterThanOp : SPV_LogicalBinaryOp<"UGreaterThan", SPV_Integer, []> {
 
 // -----
 
-def SPV_UGreaterThanEqualOp : SPV_LogicalBinaryOp<"UGreaterThanEqual", SPV_Integer, []> {
+def SPV_UGreaterThanEqualOp : SPV_LogicalBinaryOp<"UGreaterThanEqual",
+                                                  SPV_Integer,
+                                                  [UnsignedOp]> {
   let summary = [{
     Unsigned-integer comparison if Operand 1 is greater than or equal to
     Operand 2.
@@ -936,7 +940,9 @@ def SPV_UGreaterThanEqualOp : SPV_LogicalBinaryOp<"UGreaterThanEqual", SPV_Integ
 
 // -----
 
-def SPV_ULessThanOp : SPV_LogicalBinaryOp<"ULessThan", SPV_Integer, []> {
+def SPV_ULessThanOp : SPV_LogicalBinaryOp<"ULessThan",
+                                          SPV_Integer,
+                                          [UnsignedOp]> {
   let summary = [{
     Unsigned-integer comparison if Operand 1 is less than Operand 2.
   }];
@@ -970,7 +976,7 @@ def SPV_ULessThanOp : SPV_LogicalBinaryOp<"ULessThan", SPV_Integer, []> {
 // -----
 
 def SPV_ULessThanEqualOp :
-  SPV_LogicalBinaryOp<"ULessThanEqual", SPV_Integer, []> {
+  SPV_LogicalBinaryOp<"ULessThanEqual", SPV_Integer, [UnsignedOp]> {
   let summary = [{
     Unsigned-integer comparison if Operand 1 is less than or equal to
     Operand 2.
index 9a22a93..89eaf1d 100644 (file)
@@ -631,7 +631,9 @@ def SPV_GroupNonUniformSMinOp :
 // -----
 
 def SPV_GroupNonUniformUMaxOp :
-    SPV_GroupNonUniformArithmeticOp<"GroupNonUniformUMax", SPV_Integer, []> {
+    SPV_GroupNonUniformArithmeticOp<"GroupNonUniformUMax",
+                                    SPV_Integer,
+                                    [UnsignedOp]> {
   let summary = [{
     An unsigned integer maximum group operation of all Value operands
     contributed by active invocations in the group.
@@ -681,7 +683,9 @@ def SPV_GroupNonUniformUMaxOp :
 // -----
 
 def SPV_GroupNonUniformUMinOp :
-    SPV_GroupNonUniformArithmeticOp<"GroupNonUniformUMin", SPV_Integer, []> {
+    SPV_GroupNonUniformArithmeticOp<"GroupNonUniformUMin",
+                                    SPV_Integer,
+                                    [UnsignedOp]> {
   let summary = [{
     An unsigned integer minimum group operation of all Value operands
     contributed by active invocations in the group.
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h
new file mode 100644 (file)
index 0000000..3e67923
--- /dev/null
@@ -0,0 +1,30 @@
+//===- SPIRVOps.h - MLIR SPIR-V operation traits ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares C++ classes for some of operation traits in the SPIR-V
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVOPTRAITS_H_
+#define MLIR_DIALECT_SPIRV_IR_SPIRVOPTRAITS_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace OpTrait {
+namespace spirv {
+
+template <typename ConcreteType>
+class UnsignedOp : public TraitBase<ConcreteType, UnsignedOp> {};
+
+} // namespace spirv
+} // namespace OpTrait
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_IR_SPIRVOPTRAITS_H_
index 5e9a46d..2de2bc0 100644 (file)
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_
 #define MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_
 
+#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
index 164ed36..88d0a81 100644 (file)
@@ -187,36 +187,6 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                                                    offset);
 }
 
-/// Returns true if the operator is operating on unsigned integers.
-/// TODO: Have a TreatOperandsAsUnsignedInteger trait and bake the information
-/// to the ops themselves.
-template <typename SPIRVOp>
-bool isUnsignedOp() {
-  return false;
-}
-
-#define CHECK_UNSIGNED_OP(SPIRVOp)                                             \
-  template <>                                                                  \
-  bool isUnsignedOp<SPIRVOp>() {                                               \
-    return true;                                                               \
-  }
-
-CHECK_UNSIGNED_OP(spirv::AtomicUMaxOp)
-CHECK_UNSIGNED_OP(spirv::AtomicUMinOp)
-CHECK_UNSIGNED_OP(spirv::BitFieldUExtractOp)
-CHECK_UNSIGNED_OP(spirv::ConvertUToFOp)
-CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMaxOp)
-CHECK_UNSIGNED_OP(spirv::GroupNonUniformUMinOp)
-CHECK_UNSIGNED_OP(spirv::UConvertOp)
-CHECK_UNSIGNED_OP(spirv::UDivOp)
-CHECK_UNSIGNED_OP(spirv::UGreaterThanEqualOp)
-CHECK_UNSIGNED_OP(spirv::UGreaterThanOp)
-CHECK_UNSIGNED_OP(spirv::ULessThanEqualOp)
-CHECK_UNSIGNED_OP(spirv::ULessThanOp)
-CHECK_UNSIGNED_OP(spirv::UModOp)
-
-#undef CHECK_UNSIGNED_OP
-
 /// Returns true if the allocations of type `t` can be lowered to SPIR-V.
 static bool isAllocationSupported(MemRefType t) {
   // Currently only support workgroup local memory allocations with static
@@ -334,7 +304,8 @@ public:
     auto dstType = this->typeConverter.convertType(operation.getType());
     if (!dstType)
       return failure();
-    if (isUnsignedOp<SPIRVOp>() && dstType != operation.getType()) {
+    if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
+        dstType != operation.getType()) {
       return operation.emitError(
           "bitwidth emulation is not implemented yet on unsigned op");
     }
@@ -799,7 +770,7 @@ CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
   switch (cmpIOp.getPredicate()) {
 #define DISPATCH(cmpPredicate, spirvOp)                                        \
   case cmpPredicate:                                                           \
-    if (isUnsignedOp<spirvOp>() &&                                             \
+    if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \
         operandType != this->typeConverter.convertType(operandType)) {         \
       return cmpIOp.emitError(                                                 \
           "bitwidth emulation is not implemented yet on unsigned op");         \