OSDN Git Service

[MLIR][Shape] Allow `cstr_broadcastable` to accept extent tensors
authorFrederik Gossen <frgossen@google.com>
Mon, 20 Jul 2020 14:37:19 +0000 (14:37 +0000)
committerFrederik Gossen <frgossen@google.com>
Mon, 20 Jul 2020 14:39:44 +0000 (14:39 +0000)
Differential Revision: https://reviews.llvm.org/D84155

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/ops.mlir
mlir/test/Dialect/Shape/remove-shape-constraints.mlir

index 8508241..46400c8 100644 (file)
@@ -610,8 +610,9 @@ def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
 def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
   let summary = "Determines if 2 shapes can be successfully broadcasted";
   let description = [{
-    Given 2 input shapes, return a witness specifying if they are broadcastable.
-    This broadcastable follows the same logic as what shape.broadcast documents.
+    Given two input shapes or extent tensors, return a witness specifying if
+    they are broadcastable. This broadcastable follows the same logic as what
+    shape.broadcast documents.
 
     "cstr" operations represent runtime assertions.
 
@@ -622,10 +623,11 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", [Commutative]> {
     ```
   }];
 
-  let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
+                       Shape_ShapeOrExtentTensorType:$rhs);
   let results = (outs Shape_WitnessType:$result);
 
-  let assemblyFormat = "$lhs `,` $rhs attr-dict";
+  let assemblyFormat = "$lhs `,` $rhs `:` type($lhs) `,` type($rhs) attr-dict";
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
index a58b230..156063e 100644 (file)
@@ -431,7 +431,7 @@ func @f() {
   // CHECK-NEXT: return
   %cs0 = shape.const_shape [3, 1]
   %cs1 = shape.const_shape [1, 5]
-  %0 = shape.cstr_broadcastable %cs0, %cs1
+  %0 = shape.cstr_broadcastable %cs0, %cs1 : !shape.shape, !shape.shape
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
@@ -447,7 +447,7 @@ func @static_non_broadcastable() {
   // CHECK-NEXT: return
   %cs0 = shape.const_shape [1, 3]
   %cs1 = shape.const_shape [1, 5]
-  %0 = shape.cstr_broadcastable %cs0, %cs1
+  %0 = shape.cstr_broadcastable %cs0, %cs1 : !shape.shape, !shape.shape
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
@@ -461,7 +461,7 @@ func @f(%arg0 : !shape.shape) {
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
   %cs0 = shape.const_shape [1,3]
-  %0 = shape.cstr_broadcastable %arg0, %cs0
+  %0 = shape.cstr_broadcastable %arg0, %cs0 : !shape.shape, !shape.shape
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
@@ -473,7 +473,20 @@ func @f(%arg0 : !shape.shape) {
   // CHECK-NEXT: shape.const_witness true
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
-  %0 = shape.cstr_broadcastable %arg0, %arg0
+  %0 = shape.cstr_broadcastable %arg0, %arg0 : !shape.shape, !shape.shape
+  "consume.witness"(%0) : (!shape.witness) -> ()
+  return
+}
+
+// -----
+
+// Broadcastable canonicalization also works on extent tensors.
+// CHECK-LABEL: func @broadcastable_on_extent_tensors
+func @broadcastable_on_extent_tensors(%arg : tensor<?xindex>) {
+  // CHECK-NEXT: shape.const_witness true
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %0 = shape.cstr_broadcastable %arg, %arg : tensor<?xindex>, tensor<?xindex>
   "consume.witness"(%0) : (!shape.witness) -> ()
   return
 }
@@ -560,7 +573,7 @@ func @cstr_broadcastable_scalar(%arg0 : tensor<?xf32>) {
   // CHECK-NEXT: return
   %0 = shape.const_shape []
   %1 = shape.shape_of %arg0 : tensor<?xf32>
-  %2 = shape.cstr_broadcastable %0, %1
+  %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }
@@ -577,7 +590,7 @@ func @cstr_broadcastable_unknown(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>) {
   // CHECK-NEXT: return
   %0 = shape.shape_of %arg0 : tensor<?xf32>
   %1 = shape.shape_of %arg1 : tensor<?xf32>
-  %2 = shape.cstr_broadcastable %0, %1
+  %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }
@@ -592,7 +605,7 @@ func @cstr_broadcastable_scalar_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<i
   // CHECK-NEXT: return
   %0 = shape.shape_of %arg1 : tensor<index>
   %1 = shape.shape_of %arg0 : tensor<*xf32>
-  %2 = shape.cstr_broadcastable %0, %1
+  %2 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }
index c6f5251..30cf29a 100644 (file)
@@ -86,7 +86,7 @@ func @test_shape_of(%arg0: tensor<?xf32>) -> !shape.shape {
 func @test_constraints() {
   %0 = shape.const_shape []
   %1 = shape.const_shape [1, 2, 3]
-  %w0 = shape.cstr_broadcastable %0, %1
+  %w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
   %w1 = shape.cstr_eq %0, %1
   %w2 = shape.const_witness true
   %w3 = shape.const_witness false
@@ -98,6 +98,12 @@ func @test_constraints() {
   return
 }
 
+func @broadcastable_on_extent_tensors(%lhs : tensor<?xindex>,
+                                      %rhs : tensor<?xindex>) {
+  %w0 = shape.cstr_broadcastable %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
+  return
+}
+
 func @test_mul(%lhs: !shape.size, %rhs: !shape.size) -> !shape.size {
   %product = shape.mul %lhs, %rhs
   return %product: !shape.size
index 69887c6..31bb7bd 100644 (file)
@@ -11,7 +11,7 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
   // REPLACE: shape.assuming %[[WITNESS]]
   // CANON-NEXT: test.source
   // CANON-NEXT: return
-  %0 = shape.cstr_broadcastable %arg0, %arg1
+  %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
   %1 = shape.assuming %0 -> index {
     %2 = "test.source"() : () -> (index)
     shape.assuming_yield %2 : index
@@ -45,7 +45,7 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
 func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
   // CANON-NEXT: test.source
   // CANON-NEXT: return
-  %0 = shape.cstr_broadcastable %arg0, %arg1
+  %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
   %1 = shape.cstr_eq %arg0, %arg1
   %2 = shape.assuming_all %0, %1
   %3 = shape.assuming %0 -> index {