OSDN Git Service

[MLIR][Shape] Fold `shape.shape_eq`
authorFrederik Gossen <frgossen@google.com>
Mon, 20 Jul 2020 12:24:02 +0000 (12:24 +0000)
committerFrederik Gossen <frgossen@google.com>
Mon, 20 Jul 2020 12:25:53 +0000 (12:25 +0000)
Fold `shape.shape_eq`.

Differential Revision: https://reviews.llvm.org/D82533

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir

index 090b4c6..8508241 100644 (file)
@@ -154,6 +154,7 @@ def Shape_ShapeEqOp : Shape_Op<"shape_eq", [Commutative, NoSideEffect]> {
   let results = (outs I1:$result);
 
   let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs)";
+  let hasFolder = 1;
 }
 
 def Shape_FromExtentsOp : Shape_Op<"from_extents", [NoSideEffect]> {
index b983968..92392b0 100644 (file)
@@ -476,6 +476,20 @@ void ConstSizeOp::getAsmResultNames(
 OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
 
 //===----------------------------------------------------------------------===//
+// ShapeEqOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ShapeEqOp::fold(ArrayRef<Attribute> operands) {
+  auto lhs = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (lhs == nullptr)
+    return {};
+  auto rhs = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
+  if (rhs == nullptr)
+    return {};
+  return BoolAttr::get(lhs == rhs, getContext());
+}
+
+//===----------------------------------------------------------------------===//
 // IndexToSizeOp
 //===----------------------------------------------------------------------===//
 
index 4e320f3..a58b230 100644 (file)
@@ -596,3 +596,56 @@ func @cstr_broadcastable_scalar_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<i
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }
+
+// -----
+
+// Fold `shape_eq` for equal and constant shapes.
+// CHECK-LABEL: @shape_eq_fold_1
+func @shape_eq_fold_1() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant true
+  // CHECK: return %[[RESULT]] : i1
+  %a = shape.const_shape [1, 2, 3]
+  %b = shape.const_shape [1, 2, 3]
+  %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
+  return %result : i1
+}
+
+// -----
+
+// Fold `shape_eq` for different but constant shapes of same length.
+// CHECK-LABEL: @shape_eq_fold_0
+func @shape_eq_fold_0() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant false
+  // CHECK: return %[[RESULT]] : i1
+  %a = shape.const_shape [1, 2, 3]
+  %b = shape.const_shape [4, 5, 6]
+  %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
+  return %result : i1
+}
+
+// -----
+
+// Fold `shape_eq` for different but constant shapes of different length.
+// CHECK-LABEL: @shape_eq_fold_0
+func @shape_eq_fold_0() -> i1 {
+  // CHECK: %[[RESULT:.*]] = constant false
+  // CHECK: return %[[RESULT]] : i1
+  %a = shape.const_shape [1, 2, 3, 4, 5, 6]
+  %b = shape.const_shape [1, 2, 3]
+  %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
+  return %result : i1
+}
+
+// -----
+
+// Do not fold `shape_eq` for non-constant shapes.
+// CHECK-LABEL: @shape_eq_do_not_fold
+// CHECK-SAME: (%[[A:.*]]: !shape.shape) -> i1
+func @shape_eq_do_not_fold(%a : !shape.shape) -> i1 {
+  // CHECK: %[[B:.*]] = shape.const_shape [4, 5, 6]
+  // CHECK: %[[RESULT:.*]] = shape.shape_eq %[[A]], %[[B]] : !shape.shape, !shape.shape
+  // CHECK: return %[[RESULT]] : i1
+  %b = shape.const_shape [4, 5, 6]
+  %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
+  return %result : i1
+}