#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#define SHAPE_OPS
include "mlir/Dialect/Shape/IR/ShapeBase.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
nothing else. They should not exist after a program is fully lowered and
ready to execute.
}];
- let arguments = (ins Shape_WitnessType);
- let regions = (region SizedRegion<1>:$thenRegion);
+ let arguments = (ins Shape_WitnessType:$witness);
+ let regions = (region SizedRegion<1>:$doRegion);
let results = (outs Variadic<AnyType>:$results);
+
+ let printer = [{ return ::print(p, *this); }];
+ let parser = [{ return ::parse$cppClass(parser, result); }];
}
-def Shape_AssumingYieldOp : Shape_Op<"assuming_yield", [Terminator]> {
+def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
+ [NoSideEffect, ReturnLike, Terminator]> {
let summary = "Yield operation";
let description = [{
This yield operation represents a return operation within the assert_and_exec
}];
let arguments = (ins Variadic<AnyType>:$operands);
+
+ let builders = [
+ OpBuilder<"OpBuilder &builder, OperationState &result",
+ [{ /* nothing to do */ }]>
+ ];
}
def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> {
MLIRShapeOpsIncGen
LINK_LIBS PUBLIC
+ MLIRControlFlowInterfaces
MLIRDialect
MLIRInferTypeOpInterface
MLIRIR
}
//===----------------------------------------------------------------------===//
+// AssumingOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseAssumingOp(OpAsmParser &parser,
+ OperationState &result) {
+ result.regions.reserve(1);
+ Region *doRegion = result.addRegion();
+
+ auto &builder = parser.getBuilder();
+ OpAsmParser::OperandType cond;
+ if (parser.parseOperand(cond) ||
+ parser.resolveOperand(cond, builder.getType<WitnessType>(),
+ result.operands))
+ return failure();
+
+ // Parse optional results type list.
+ if (parser.parseOptionalArrowTypeList(result.types))
+ return failure();
+
+ // Parse the region and add a terminator if elided.
+ if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
+ return failure();
+ AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
+
+ // Parse the optional attribute list.
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+ return success();
+}
+
+static void print(OpAsmPrinter &p, AssumingOp op) {
+ bool yieldsResults = !op.results().empty();
+
+ p << AssumingOp::getOperationName() << " " << op.witness();
+ if (yieldsResults) {
+ p << " -> (" << op.getResultTypes() << ")";
+ }
+ p.printRegion(op.doRegion(),
+ /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/yieldsResults);
+ p.printOptionalAttrDict(op.getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
%w0 = "shape.cstr_broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.witness
%w1 = "shape.cstr_eq"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.witness
%w3 = "shape.assuming_all"(%w0, %w1) : (!shape.witness, !shape.witness) -> !shape.witness
- "shape.assuming"(%w3) ( {
+ shape.assuming %w3 -> !shape.shape {
%2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
"shape.assuming_yield"(%2) : (!shape.shape) -> ()
- }) : (!shape.witness) -> !shape.shape
+ }
return
}