From 774f1d3ffd458d6cb82d5039758ef1cf6370957f Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Mon, 30 Nov 2020 15:20:30 -0800 Subject: [PATCH] [mlir] Small cleanups to func-bufferize/finalizing-bufferize - Address TODO in scf-bufferize: the argument materialization issue is now fixed and the code is now in Transforms/Bufferize.cpp - Tighten up finalizing-bufferize to avoid creating invalid IR when operand types potentially change - Tidy up the testing of func-bufferize, and move appropriate tests to a new finalizing-bufferize.mlir - The new stricter checking in finalizing-bufferize revealed that we needed a DimOp conversion pattern (found when integrating into npcomp). Previously, the converion infrastructure was blindly changing the operand type during finalization, which happened to work due to DimOp's tensor/memref polymorphism, but is generally not encouraged (the new pattern is the way to tell the conversion infrastructure that it is legal to change that type). --- mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp | 15 ------ .../Dialect/StandardOps/Transforms/Bufferize.cpp | 18 +++++++ mlir/lib/Transforms/Bufferize.cpp | 12 +++-- mlir/test/Dialect/Standard/bufferize.mlir | 14 ++++- .../Dialect/Standard/func-bufferize-partial.mlir | 59 -------------------- mlir/test/Dialect/Standard/func-bufferize.mlir | 63 ++++++++++++---------- mlir/test/Transforms/finalizing-bufferize.mlir | 28 ++++++++++ 7 files changed, 101 insertions(+), 108 deletions(-) delete mode 100644 mlir/test/Dialect/Standard/func-bufferize-partial.mlir create mode 100644 mlir/test/Transforms/finalizing-bufferize.mlir diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp index 57d605b3491..7cf0dfabd91 100644 --- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp @@ -27,21 +27,6 @@ struct SCFBufferizePass : public SCFBufferizeBase { OwningRewritePatternList patterns; ConversionTarget target(*context); - // TODO: Move this to BufferizeTypeConverter's constructor. - // - // This doesn't currently play well with "finalizing" bufferizations (ones - // that expect all materializations to be gone). In particular, there seems - // to at least be a double-free in the dialect conversion framework - // when this materialization gets inserted and then folded away because - // it is marked as illegal. - typeConverter.addArgumentMaterialization( - [](OpBuilder &builder, RankedTensorType type, ValueRange inputs, - Location loc) -> Value { - assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); - return builder.create(loc, type, inputs[0]); - }); - populateBufferizeMaterializationLegality(target); populateSCFStructuralTypeConversionsAndLegality(context, typeConverter, patterns, target); diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp index 9056fbc25e1..8b47e88677e 100644 --- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp @@ -21,6 +21,21 @@ using namespace mlir; namespace { +class BufferizeDimOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(DimOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + DimOp::Adaptor adaptor(operands); + rewriter.replaceOpWithNewOp(op, adaptor.memrefOrTensor(), + adaptor.index()); + return success(); + } +}; +} // namespace + +namespace { class BufferizeDynamicTensorFromElementsOp : public OpConversionPattern { public: @@ -148,6 +163,7 @@ void mlir::populateStdBufferizePatterns(MLIRContext *context, OwningRewritePatternList &patterns) { patterns.insert< // clang-format off + BufferizeDimOp, BufferizeDynamicTensorFromElementsOp, BufferizeExtractElementOp, BufferizeSelectOp, @@ -178,6 +194,8 @@ struct StdBufferizePass : public StdBufferizeBase { return typeConverter.isLegal(op.getType()) || !op.condition().getType().isa(); }); + target.addDynamicallyLegalOp( + [&](DimOp op) { return typeConverter.isLegal(op); }); if (failed( applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp index 1811ac8bdfb..66b1cc65646 100644 --- a/mlir/lib/Transforms/Bufferize.cpp +++ b/mlir/lib/Transforms/Bufferize.cpp @@ -105,13 +105,17 @@ struct FinalizingBufferizePass populateEliminateBufferizeMaterializationsPatterns(context, typeConverter, patterns); - target.addIllegalOp(); // If all result types are legal, and all block arguments are legal (ensured // by func conversion above), then all types in the program are legal. - target.markUnknownOpDynamicallyLegal([&](Operation *op) { - return typeConverter.isLegal(op->getResultTypes()); - }); + // + // We also check that the operand types are legal to avoid creating invalid + // IR. For example, this prevents + // populateEliminateBufferizeMaterializationsPatterns from updating the + // types of the operands to a return op without updating the enclosing + // function. + target.markUnknownOpDynamicallyLegal( + [&](Operation *op) { return typeConverter.isLegal(op); }); if (failed(applyFullConversion(func, target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir index 8cc05ff2064..27769c52d9e 100644 --- a/mlir/test/Dialect/Standard/bufferize.mlir +++ b/mlir/test/Dialect/Standard/bufferize.mlir @@ -1,5 +1,16 @@ // RUN: mlir-opt %s -std-bufferize | FileCheck %s +// CHECK-LABEL: func @dim( +// CHECK-SAME: %[[TENSOR:.*]]: tensor, +// CHECK-SAME: %[[INDEX:.*]]: index) -> index { +// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref +// CHECK: %[[EXTENT:.*]] = dim %[[MEMREF]], %[[INDEX]] : memref +// CHECK: return %[[EXTENT]] : index +func @dim(%arg0: tensor, %arg1: index) -> index { + %0 = dim %arg0, %arg1 : tensor + return %0 : index +} + // CHECK-LABEL: func @dynamic_tensor_from_elements( // CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>, // CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor { @@ -7,7 +18,8 @@ // CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[C1:.*]] = constant 1 : index // CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) { -// CHECK: %[[ELEM:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32> +// CHECK: %[[ARG_MEMREF:.*]] = tensor_to_memref %[[ARG]] : memref<*xf32> +// CHECK: %[[ELEM:.*]] = dim %[[ARG_MEMREF]], %[[I]] : memref<*xf32> // CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref // CHECK: scf.yield // CHECK: } diff --git a/mlir/test/Dialect/Standard/func-bufferize-partial.mlir b/mlir/test/Dialect/Standard/func-bufferize-partial.mlir deleted file mode 100644 index 43ea4591e4e..00000000000 --- a/mlir/test/Dialect/Standard/func-bufferize-partial.mlir +++ /dev/null @@ -1,59 +0,0 @@ -// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s - -// CHECK-LABEL: func @block_arguments( -// CHECK-SAME: %[[ARG:.*]]: memref) -> memref { -// CHECK: %[[T1:.*]] = tensor_load %[[ARG]] : memref -// CHECK: %[[M1:.*]] = tensor_to_memref %[[T1]] : memref -// CHECK: br ^bb1(%[[M1]] : memref) -// CHECK: ^bb1(%[[BBARG:.*]]: memref): -// CHECK: %[[T2:.*]] = tensor_load %[[BBARG]] : memref -// CHECK: %[[M2:.*]] = tensor_to_memref %[[T2]] : memref -// CHECK: return %[[M2]] : memref -func @block_arguments(%arg0: tensor) -> tensor { - br ^bb1(%arg0: tensor) -^bb1(%bbarg: tensor): - return %bbarg : tensor -} - -// CHECK-LABEL: func @partial() -// CHECK-SAME: memref -func @partial() -> tensor { - // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor - // CHECK-NEXT: %[[MEM:.*]] = tensor_to_memref %[[SRC]] : memref - %0 = "test.source"() : () -> tensor - // CHECK-NEXT: return %[[MEM]] : memref - return %0 : tensor -} - -// CHECK-LABEL: func @region_op -// CHECK-SAME: (%[[ARG0:.*]]: i1) -> memref -func @region_op(%arg0: i1) -> tensor { - // CHECK-NEXT: %[[IF:.*]] = scf.if %[[ARG0]] -> (tensor) - %0 = scf.if %arg0 -> (tensor) { - // CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor - %1 = "test.source"() : () -> tensor - // CHECK-NEXT: scf.yield %[[SRC]] : tensor - scf.yield %1 : tensor - // CHECK-NEXT: else - } else { - // CHECK-NEXT: %[[OSRC:.*]] = "test.other_source"() : () -> tensor - %1 = "test.other_source"() : () -> tensor - // CHECK-NEXT: scf.yield %[[OSRC]] : tensor - scf.yield %1 : tensor - } - // CHECK: %[[MEM:.*]] = tensor_to_memref %[[IF]] : memref - // CHECK: return %[[MEM]] : memref - return %0 : tensor -} - -// ----- - -func @failed_to_legalize(%arg0: tensor) -> tensor { - %0 = constant true - cond_br %0, ^bb1(%arg0: tensor), ^bb2(%arg0: tensor) - ^bb1(%bbarg0: tensor): - // expected-error @+1 {{failed to legalize operation 'test.terminator'}} - "test.terminator"() : () -> () - ^bb2(%bbarg1: tensor): - return %bbarg1 : tensor -} diff --git a/mlir/test/Dialect/Standard/func-bufferize.mlir b/mlir/test/Dialect/Standard/func-bufferize.mlir index d02db99aecd..de2f75c4a29 100644 --- a/mlir/test/Dialect/Standard/func-bufferize.mlir +++ b/mlir/test/Dialect/Standard/func-bufferize.mlir @@ -1,39 +1,29 @@ -// RUN: mlir-opt %s -func-bufferize -finalizing-bufferize -split-input-file -verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func @identity( -// CHECK-SAME: %[[ARG:.*]]: memref) -> memref { -// CHECK: return %[[ARG]] : memref +// CHECK-SAME: %[[ARG:.*]]: memref) -> memref { +// CHECK: %[[TENSOR:.*]] = tensor_load %[[ARG]] : memref +// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref +// CHECK: return %[[MEMREF]] : memref func @identity(%arg0: tensor) -> tensor { return %arg0 : tensor } // CHECK-LABEL: func @block_arguments( // CHECK-SAME: %[[ARG:.*]]: memref) -> memref { -// CHECK: br ^bb1(%[[ARG]] : memref) +// CHECK: %[[T1:.*]] = tensor_load %[[ARG]] : memref +// CHECK: %[[M1:.*]] = tensor_to_memref %[[T1]] : memref +// CHECK: br ^bb1(%[[M1]] : memref) // CHECK: ^bb1(%[[BBARG:.*]]: memref): -// CHECK: return %[[BBARG]] : memref +// CHECK: %[[T2:.*]] = tensor_load %[[BBARG]] : memref +// CHECK: %[[M2:.*]] = tensor_to_memref %[[T2]] : memref +// CHECK: return %[[M2]] : memref func @block_arguments(%arg0: tensor) -> tensor { br ^bb1(%arg0: tensor) ^bb1(%bbarg: tensor): return %bbarg : tensor } -// CHECK-LABEL: func @eliminate_target_materialization( -// CHECK-SAME: %[[ARG:.*]]: memref) -> memref { -// CHECK: return %[[ARG]] : memref -func @eliminate_target_materialization(%arg0: tensor) -> memref { - %0 = tensor_to_memref %arg0 : memref - return %0 : memref -} - -// CHECK-LABEL: func @eliminate_source_materialization( -// CHECK-SAME: %[[ARG:.*]]: memref) -> memref { -// CHECK: return %[[ARG]] : memref -func @eliminate_source_materialization(%arg0: memref) -> tensor { - %0 = tensor_load %arg0 : memref - return %0 : tensor -} - // CHECK-LABEL: func private @source() -> memref // CHECK-LABEL: func @call_source() -> memref { // CHECK: %[[RET:.*]] = call @source() : () -> memref @@ -43,11 +33,11 @@ func @call_source() -> tensor { %0 = call @source() : () -> tensor return %0 : tensor } - -// CHECK-LABEL: func private @sink(memref) // CHECK-LABEL: func @call_sink( -// CHECK-SAME: %[[ARG:.*]]: memref) { -// CHECK: call @sink(%[[ARG]]) : (memref) -> () +// CHECK-SAME: %[[ARG:.*]]: memref) { +// CHECK: %[[TENSOR:.*]] = tensor_load %[[ARG]] : memref +// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref +// CHECK: call @sink(%[[MEMREF]]) : (memref) -> () // CHECK: return func private @sink(tensor) func @call_sink(%arg0: tensor) { @@ -55,10 +45,25 @@ func @call_sink(%arg0: tensor) { return } +// CHECK-LABEL: func @unconverted_op_in_body() -> memref { +// CHECK: %[[TENSOR:.*]] = "test.source"() : () -> tensor +// CHECK: %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref +// CHECK: return %[[MEMREF]] : memref +func @unconverted_op_in_body() -> tensor { + %0 = "test.source"() : () -> tensor + return %0 : tensor +} + // ----- -func @failed_to_legalize() -> tensor { - // expected-error @+1 {{failed to legalize operation 'test.source'}} - %0 = "test.source"() : () -> (tensor) - return %0 : tensor +// Because this pass updates block arguments, it needs to also atomically +// update all terminators and issue an error if that is not possible. +func @unable_to_update_terminator(%arg0: tensor) -> tensor { + %0 = constant true + cond_br %0, ^bb1(%arg0: tensor), ^bb2(%arg0: tensor) + ^bb1(%bbarg0: tensor): + // expected-error @+1 {{failed to legalize operation 'test.terminator'}} + "test.terminator"() : () -> () + ^bb2(%bbarg1: tensor): + return %bbarg1 : tensor } diff --git a/mlir/test/Transforms/finalizing-bufferize.mlir b/mlir/test/Transforms/finalizing-bufferize.mlir new file mode 100644 index 00000000000..5c09664776e --- /dev/null +++ b/mlir/test/Transforms/finalizing-bufferize.mlir @@ -0,0 +1,28 @@ +// RUN: mlir-opt %s -finalizing-bufferize -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func @eliminate_materializations( +// CHECK-SAME: %[[ARG:.*]]: memref) -> memref { +// CHECK: return %[[ARG]] : memref +func @eliminate_materializations(%arg0: memref) -> memref { + %0 = tensor_load %arg0 : memref + %1 = tensor_to_memref %0 : memref + return %1 : memref +} + +// ----- + +func @unable_to_convert_lone_tensor_to_memref() -> memref { + // expected-error @+1 {{failed to legalize operation 'test.source'}} + %0 = "test.source"() : () -> tensor + %1 = tensor_to_memref %0 : memref + return %1 : memref +} + +// ----- + +func @unable_to_convert_lone_tensor_load(%arg0: memref) { + %0 = tensor_load %arg0 : memref + // expected-error @+1 {{failed to legalize operation 'test.sink'}} + "test.sink"(%0) : (tensor) -> () + return +} -- 2.11.0