OSDN Git Service

[mlir][EDSC] Allow conditionBuilder to capture the IfOp
authorNicolas Vasilache <ntv@google.com>
Fri, 17 Jul 2020 15:09:35 +0000 (11:09 -0400)
committerNicolas Vasilache <ntv@google.com>
Fri, 17 Jul 2020 15:16:26 +0000 (11:16 -0400)
When the IfOp returns values, it can easily be obtained from one of the Values.
However, when no values are returned, the information is lost.
This revision lets the caller specify a capture IfOp* to return the produced
IfOp.

Differential Revision: https://reviews.llvm.org/D84025

mlir/include/mlir/Dialect/SCF/EDSC/Builders.h
mlir/lib/Dialect/SCF/EDSC/Builders.cpp

index 1605f58..50adec2 100644 (file)
@@ -36,12 +36,15 @@ scf::ValueVector loopNestBuilder(
 /// Adapters for building if conditions using the builder and the location
 /// stored in ScopedContext. 'thenBody' is mandatory, 'elseBody' can be omitted
 /// if the condition should not have an 'else' part.
-ValueRange
-conditionBuilder(TypeRange results, Value condition,
-                 function_ref<scf::ValueVector()> thenBody,
-                 function_ref<scf::ValueVector()> elseBody = nullptr);
+/// When `ifOp` is specified, the scf::IfOp is captured. This is particularly
+/// convenient for 0-result conditions.
+ValueRange conditionBuilder(TypeRange results, Value condition,
+                            function_ref<scf::ValueVector()> thenBody,
+                            function_ref<scf::ValueVector()> elseBody = nullptr,
+                            scf::IfOp *ifOp = nullptr);
 ValueRange conditionBuilder(Value condition, function_ref<void()> thenBody,
-                            function_ref<void()> elseBody = nullptr);
+                            function_ref<void()> elseBody = nullptr,
+                            scf::IfOp *ifOp = nullptr);
 
 } // namespace edsc
 } // namespace mlir
index 082c8c3..2098ca1 100644 (file)
@@ -76,14 +76,17 @@ wrapIfBody(function_ref<scf::ValueVector()> body, TypeRange expectedTypes) {
 ValueRange
 mlir::edsc::conditionBuilder(TypeRange results, Value condition,
                              function_ref<scf::ValueVector()> thenBody,
-                             function_ref<scf::ValueVector()> elseBody) {
+                             function_ref<scf::ValueVector()> elseBody,
+                             scf::IfOp *ifOp) {
   assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
   assert(thenBody && "thenBody is mandatory");
 
-  auto ifOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
+  auto newOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
       ScopedContext::getLocation(), results, condition,
       wrapIfBody(thenBody, results), wrapIfBody(elseBody, results));
-  return ifOp.getResults();
+  if (ifOp)
+    *ifOp = newOp;
+  return newOp.getResults();
 }
 
 static std::function<void(OpBuilder &, Location)>
@@ -97,14 +100,17 @@ wrapZeroResultIfBody(function_ref<void()> body) {
 
 ValueRange mlir::edsc::conditionBuilder(Value condition,
                                         function_ref<void()> thenBody,
-                                        function_ref<void()> elseBody) {
+                                        function_ref<void()> elseBody,
+                                        scf::IfOp *ifOp) {
   assert(ScopedContext::getContext() && "EDSC ScopedContext not set up");
   assert(thenBody && "thenBody is mandatory");
 
-  ScopedContext::getBuilderRef().create<scf::IfOp>(
+  auto newOp = ScopedContext::getBuilderRef().create<scf::IfOp>(
       ScopedContext::getLocation(), condition, wrapZeroResultIfBody(thenBody),
       elseBody ? llvm::function_ref<void(OpBuilder &, Location)>(
                      wrapZeroResultIfBody(elseBody))
                : llvm::function_ref<void(OpBuilder &, Location)>(nullptr));
+  if (ifOp)
+    *ifOp = newOp;
   return {};
 }