OSDN Git Service

[mlir] Explicitly track branch instructions in translation to LLVM IR
authorAlex Zinenko <zinenko@google.com>
Tue, 8 Dec 2020 15:09:20 +0000 (16:09 +0100)
committerAlex Zinenko <zinenko@google.com>
Thu, 10 Dec 2020 10:08:58 +0000 (11:08 +0100)
The current implementation of the translation to LLVM IR relies on the
existence of a one-to-one mapping between MLIR blocks and LLVM IR basic blocks
in order to configure PHI nodes with appropriate source blocks. The one-to-one
mapping model is broken in presence of OpenMP operations that use LLVM's
OpenMPIRBuilder, which produces multiple blocks under the hood. This can lead
to invalid LLVM IR being emitted if OpenMPIRBuilder moved the branch operation
into a basic block different from the one it was originally created in;
specifically, a block that is not a direct predecessor could be used in the PHI
node. Instead, keep track of the mapping between MLIR LLVM dialect branch
operations and their LLVM IR counterparts and take the parent basic block of
the LLVM IR instruction at the moment of connecting the PHI nodes to
predecessors.

This behavior cannot be triggered as of now, but will be once we introduce the
conversion of OpenMP workshare loops.

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D92845

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

index 996c8de..d3d2894 100644 (file)
@@ -151,6 +151,11 @@ protected:
   llvm::StringMap<llvm::Function *> functionMapping;
   DenseMap<Value, llvm::Value *> valueMapping;
   DenseMap<Block *, llvm::BasicBlock *> blockMapping;
+
+  /// A mapping between MLIR LLVM dialect terminators and LLVM IR terminators
+  /// they are converted to. This allows for conneting PHI nodes to the source
+  /// values after all operations are converted.
+  DenseMap<Operation *, llvm::Instruction *> branchMapping;
 };
 
 } // namespace LLVM
index c20b19f..057f574 100644 (file)
@@ -340,9 +340,10 @@ static Value getPHISourceValue(Block *current, Block *pred,
 
 /// Connect the PHI nodes to the results of preceding blocks.
 template <typename T>
-static void
-connectPHINodes(T &func, const DenseMap<Value, llvm::Value *> &valueMapping,
-                const DenseMap<Block *, llvm::BasicBlock *> &blockMapping) {
+static void connectPHINodes(
+    T &func, const DenseMap<Value, llvm::Value *> &valueMapping,
+    const DenseMap<Block *, llvm::BasicBlock *> &blockMapping,
+    const DenseMap<Operation *, llvm::Instruction *> &branchMapping) {
   // Skip the first block, it cannot be branched to and its arguments correspond
   // to the arguments of the LLVM function.
   for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
@@ -355,9 +356,17 @@ connectPHINodes(T &func, const DenseMap<Value, llvm::Value *> &valueMapping,
       auto &phiNode = numberedPhiNode.value();
       unsigned index = numberedPhiNode.index();
       for (auto *pred : bb->getPredecessors()) {
+        // Find the LLVM IR block that contains the converted terminator
+        // instruction and use it in the PHI node. Note that this block is not
+        // necessarily the same as blockMapping.lookup(pred), some operations
+        // (in particular, OpenMP operations using OpenMPIRBuilder) may have
+        // split the blocks.
+        llvm::Instruction *terminator =
+            branchMapping.lookup(pred->getTerminator());
+        assert(terminator && "missing the mapping for a terminator");
         phiNode.addIncoming(valueMapping.lookup(getPHISourceValue(
                                 bb, pred, numArguments, index)),
-                            blockMapping.lookup(pred));
+                            terminator->getParent());
       }
     }
   }
@@ -476,7 +485,7 @@ void ModuleTranslation::convertOmpOpRegions(
   }
   // Finally, after all blocks have been traversed and values mapped,
   // connect the PHI nodes to the results of preceding blocks.
-  connectPHINodes(region, valueMapping, blockMapping);
+  connectPHINodes(region, valueMapping, blockMapping, branchMapping);
 }
 
 LogicalResult ModuleTranslation::convertOmpMaster(Operation &opInst,
@@ -682,7 +691,9 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
   // Emit branches.  We need to look up the remapped blocks and ignore the block
   // arguments that were transformed into PHI nodes.
   if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
-    builder.CreateBr(blockMapping[brOp.getSuccessor()]);
+    llvm::BranchInst *branch =
+        builder.CreateBr(blockMapping[brOp.getSuccessor()]);
+    branchMapping.try_emplace(&opInst, branch);
     return success();
   }
   if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
@@ -699,9 +710,11 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
               .createBranchWeights(static_cast<uint32_t>(trueWeight),
                                    static_cast<uint32_t>(falseWeight));
     }
-    builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)),
-                         blockMapping[condbrOp.getSuccessor(0)],
-                         blockMapping[condbrOp.getSuccessor(1)], branchWeights);
+    llvm::BranchInst *branch = builder.CreateCondBr(
+        valueMapping.lookup(condbrOp.getOperand(0)),
+        blockMapping[condbrOp.getSuccessor(0)],
+        blockMapping[condbrOp.getSuccessor(1)], branchWeights);
+    branchMapping.try_emplace(&opInst, branch);
     return success();
   }
 
@@ -893,10 +906,11 @@ forwardPassthroughAttributes(Location loc, Optional<ArrayAttr> attributes,
 }
 
 LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
-  // Clear the block and value mappings, they are only relevant within one
+  // Clear the block, branch value mappings, they are only relevant within one
   // function.
   blockMapping.clear();
   valueMapping.clear();
+  branchMapping.clear();
   llvm::Function *llvmFunc = functionMapping.lookup(func.getName());
 
   // Translate the debug information for this function.
@@ -964,7 +978,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
 
   // Finally, after all blocks have been traversed and values mapped, connect
   // the PHI nodes to the results of preceding blocks.
-  connectPHINodes(func, valueMapping, blockMapping);
+  connectPHINodes(func, valueMapping, blockMapping, branchMapping);
   return success();
 }