From 3ccf4a5bd1099bfba544bec7ddbe610cc9531bb2 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 20 May 2020 16:00:57 +0200 Subject: [PATCH] [mlir] ensureRegionTerminator: take OpBuilder The SingleBlockImplicitTerminator op trait provides a function `ensureRegionTerminator` that injects an appropriate terminator into the block if necessary, which is used during operation constructing and parsing. Currently, this function directly modifies the IR using low-level APIs on Operation and Block. If this function is called from a conversion pattern, these manipulations are not reflected in the ConversionPatternRewriter and thus cannot be undone or, worse, lead to tricky memory errors and malformed IR. Change `ensureRegionTerminator` to take an instance of `OpBuilder` instead of `Builder`, and use it to construct the block and the terminator when required. Maintain overloads taking an instance of `Builder` and creating a simple `OpBuilder` to use in parsers, which don't have an `OpBuilder` and cannot interact with the dialect conversion mechanism. This change was one of the reasons to make `::build` accept an `OpBuilder`. Differential Revision: https://reviews.llvm.org/D80138 --- mlir/include/mlir/IR/OpDefinition.h | 39 ++++++++++++++++++++++++------------- mlir/lib/IR/Operation.cpp | 20 ++++++++++++++----- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index d35efbb2d31..bf5bd70c2b7 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -79,17 +79,12 @@ namespace impl { /// is empty, insert a new block first. `buildTerminatorOp` should return the /// terminator operation to insert. void ensureRegionTerminator( - Region ®ion, Location loc, - function_ref buildTerminatorOp); -/// Templated version that fills the generates the provided operation type. -template -void ensureRegionTerminator(Region ®ion, Builder &builder, Location loc) { - ensureRegionTerminator(region, loc, [&](OpBuilder &b) { - OperationState state(loc, OpTy::getOperationName()); - OpTy::build(b, state); - return Operation::create(state); - }); -} + Region ®ion, OpBuilder &builder, Location loc, + function_ref buildTerminatorOp); +void ensureRegionTerminator( + Region ®ion, Builder &builder, Location loc, + function_ref buildTerminatorOp); + } // namespace impl /// This is the concrete base class that holds the operation pointer and has @@ -1077,6 +1072,15 @@ public: template struct SingleBlockImplicitTerminator { template class Impl : public TraitBase { + private: + /// Builds a terminator operation without relying on OpBuilder APIs to avoid + /// cyclic header inclusion. + static Operation *buildTerminator(OpBuilder &builder, Location loc) { + OperationState state(loc, TerminatorOpType::getOperationName()); + TerminatorOpType::build(builder, state); + return Operation::create(state); + } + public: static LogicalResult verifyTrait(Operation *op) { for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { @@ -1112,10 +1116,19 @@ template struct SingleBlockImplicitTerminator { } /// Ensure that the given region has the terminator required by this trait. + /// If OpBuilder is provided, use it to build the terminator and notify the + /// OpBuilder litsteners accoridngly. If only a Builder is provided, locally + /// construct an OpBuilder with no listeners; this should only be used if no + /// OpBuilder is available at the call site, e.g., in the parser. static void ensureTerminator(Region ®ion, Builder &builder, Location loc) { - ::mlir::impl::template ensureRegionTerminator( - region, builder, loc); + ::mlir::impl::ensureRegionTerminator(region, builder, loc, + buildTerminator); + } + static void ensureTerminator(Region ®ion, OpBuilder &builder, + Location loc) { + ::mlir::impl::ensureRegionTerminator(region, builder, loc, + buildTerminator); } Block *getBody(unsigned idx = 0) { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 9539a8e31c3..f83bc0a3b97 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1099,17 +1099,27 @@ Value impl::foldCastOp(Operation *op) { /// is empty, insert a new block first. `buildTerminatorOp` should return the /// terminator operation to insert. void impl::ensureRegionTerminator( - Region ®ion, Location loc, - function_ref buildTerminatorOp) { + Region ®ion, OpBuilder &builder, Location loc, + function_ref buildTerminatorOp) { + OpBuilder::InsertionGuard guard(builder); if (region.empty()) - region.push_back(new Block); + builder.createBlock(®ion); Block &block = region.back(); if (!block.empty() && block.back().isKnownTerminator()) return; - OpBuilder builder(loc.getContext()); - block.push_back(buildTerminatorOp(builder)); + builder.setInsertionPointToEnd(&block); + builder.insert(buildTerminatorOp(builder, loc)); +} + +/// Create a simple OpBuilder and forward to the OpBuilder version of this +/// function. +void impl::ensureRegionTerminator( + Region ®ion, Builder &builder, Location loc, + function_ref buildTerminatorOp) { + OpBuilder opBuilder(builder.getContext()); + ensureRegionTerminator(region, opBuilder, loc, buildTerminatorOp); } //===----------------------------------------------------------------------===// -- 2.11.0