OSDN Git Service

[mlir] Support unranked types in func signature conversion in BufferPlacement.
authorAlexander Belyaev <pifon@google.com>
Tue, 7 Jul 2020 17:36:48 +0000 (19:36 +0200)
committerAlexander Belyaev <pifon@google.com>
Tue, 7 Jul 2020 17:43:48 +0000 (19:43 +0200)
Currently, only ranked tensor args and results can be converted to memref types.

Differential Revision: https://reviews.llvm.org/D83324

mlir/lib/Transforms/BufferPlacement.cpp
mlir/test/Transforms/buffer-placement-preparation-allowed-memref-results.mlir
mlir/test/Transforms/buffer-placement-preparation.mlir

index 71d397b..87b2687 100644 (file)
@@ -700,15 +700,19 @@ BufferAssignmentPlacer::computeAllocPosition(OpResult result) {
 BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() {
   // Keep all types unchanged.
   addConversion([](Type type) { return type; });
-  // A type conversion that converts ranked-tensor type to memref type.
+  // Convert RankedTensorType to MemRefType.
   addConversion([](RankedTensorType type) {
     return (Type)MemRefType::get(type.getShape(), type.getElementType());
   });
+  // Convert UnrankedTensorType to UnrankedMemRefType.
+  addConversion([](UnrankedTensorType type) {
+    return (Type)UnrankedMemRefType::get(type.getElementType(), 0);
+  });
 }
 
 /// Checks if `type` has been converted from non-memref type to memref.
 bool BufferAssignmentTypeConverter::isConvertedMemref(Type type, Type before) {
-  return type.isa<MemRefType>() && !before.isa<MemRefType>();
+  return type.isa<BaseMemRefType>() && !before.isa<BaseMemRefType>();
 }
 
 //===----------------------------------------------------------------------===//
index 97c9600..084ac38 100644 (file)
@@ -64,6 +64,15 @@ func @simple_signature_conversion(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
 
 // -----
 
+// CHECK-LABEL: func @func_with_unranked_arg_and_result
+func @func_with_unranked_arg_and_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  return %arg0 : tensor<*xf32>
+}
+// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32>
+// CHECK-NEXT: return [[ARG]] : memref<*xf32>
+
+// -----
+
 // CHECK-LABEL: func @func_and_block_signature_conversion
 func @func_and_block_signature_conversion(%arg0 : tensor<2xf32>, %cond : i1, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32>{
     cond_br %cond, ^bb1, ^bb2
index 9b0755a..064b0fd 100644 (file)
@@ -284,3 +284,9 @@ func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
 // CHECK: %[[Y1:.*]] = call @callee(%[[X0]], %[[Y0]])
 // CHECK: linalg.copy(%[[Y0]], %[[CALLER_RESULT]])
 // CHECK: return
+
+// CHECK-LABEL: func @func_with_unranked_arg
+func @func_with_unranked_arg(%arg0: tensor<*xf32>) {
+  return
+}
+// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)