From fefe4366c3bdd03552c448972930a0f7df328c24 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Thu, 28 May 2020 09:05:24 -0700 Subject: [PATCH] [mlir] Use ValueRange instead of ArrayRef This allows constructing operand adaptor from existing op (useful for commonalizing verification as I want to do in a follow up). I also add ability to use member initializers for the generated adaptor constructors for convenience. Differential Revision: https://reviews.llvm.org/D80667 --- .../StandardToLLVM/ConvertStandardToLLVM.h | 4 +-- mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h | 2 +- mlir/include/mlir/TableGen/OpClass.h | 30 +++++++++++++--- mlir/include/mlir/TableGen/Operator.h | 3 ++ .../Conversion/StandardToLLVM/StandardToLLVM.cpp | 9 +++-- mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp | 5 +-- mlir/lib/TableGen/OpClass.cpp | 39 ++++++++++++++++---- mlir/lib/TableGen/Operator.cpp | 4 +++ mlir/test/mlir-tblgen/op-decl.td | 16 ++++----- mlir/test/mlir-tblgen/op-operand.td | 6 ++-- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 41 ++++++++++++++-------- 11 files changed, 112 insertions(+), 47 deletions(-) diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h index 2eae578fc96..c241de6ff6f 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -438,12 +438,12 @@ public: // This is a strided getElementPtr variant that linearizes subscripts as: // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. Value getStridedElementPtr(Location loc, Type elementTypePtr, - Value descriptor, ArrayRef indices, + Value descriptor, ValueRange indices, ArrayRef strides, int64_t offset, ConversionPatternRewriter &rewriter) const; Value getDataPtr(Location loc, MemRefType type, Value memRefDesc, - ArrayRef indices, ConversionPatternRewriter &rewriter, + ValueRange indices, ConversionPatternRewriter &rewriter, llvm::Module &module) const; protected: diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h index 1fa668d7ddc..f0a429941fb 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h @@ -124,7 +124,7 @@ Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, // with AffineMap that has static strides. Extend to handle dynamic strides. spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, - ArrayRef indices, Location loc, + ValueRange indices, Location loc, OpBuilder &builder); /// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its diff --git a/mlir/include/mlir/TableGen/OpClass.h b/mlir/include/mlir/TableGen/OpClass.h index e8f73c605df..694fed767e3 100644 --- a/mlir/include/mlir/TableGen/OpClass.h +++ b/mlir/include/mlir/TableGen/OpClass.h @@ -86,6 +86,7 @@ public: OpMethod(StringRef retType, StringRef name, StringRef params, Property property, bool declOnly); + virtual ~OpMethod() = default; OpMethodBody &body(); @@ -96,13 +97,13 @@ public: bool isPrivate() const; // Writes the method as a declaration to the given `os`. - void writeDeclTo(raw_ostream &os) const; + virtual void writeDeclTo(raw_ostream &os) const; // Writes the method as a definition to the given `os`. `namePrefix` is the // prefix to be prepended to the method name (typically namespaces for // qualifying the method definition). - void writeDefTo(raw_ostream &os, StringRef namePrefix) const; + virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const; -private: +protected: Property properties; // Whether this method only contains a declaration. bool isDeclOnly; @@ -110,6 +111,26 @@ private: OpMethodBody methodBody; }; +// Class for holding an op's constructor method for C++ code emission. +class OpConstructor : public OpMethod { +public: + OpConstructor(StringRef retType, StringRef name, StringRef params, + Property property, bool declOnly) + : OpMethod(retType, name, params, property, declOnly){}; + + // Add member initializer to constructor initializing `name` with `value`. + void addMemberInitializer(StringRef name, StringRef value); + + // Writes the method as a definition to the given `os`. `namePrefix` is the + // prefix to be prepended to the method name (typically namespaces for + // qualifying the method definition). + void writeDefTo(raw_ostream &os, StringRef namePrefix) const override; + +private: + // Member initializers. + std::string memberInitializers; +}; + // A class used to emit C++ classes from Tablegen. Contains a list of public // methods and a list of private fields to be emitted. class Class { @@ -121,7 +142,7 @@ public: OpMethod::Property = OpMethod::MP_None, bool declOnly = false); - OpMethod &newConstructor(StringRef params = "", bool declOnly = false); + OpConstructor &newConstructor(StringRef params = "", bool declOnly = false); // Creates a new field in this class. void newField(StringRef type, StringRef name, StringRef defaultValue = ""); @@ -136,6 +157,7 @@ public: protected: std::string className; + SmallVector constructors; SmallVector methods; SmallVector fields; }; diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index 040f52314ce..cce754dd345 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -58,6 +58,9 @@ public: // Returns this op's C++ class name prefixed with namespaces. std::string getQualCppClassName() const; + // Returns the name of op's adaptor C++ class. + std::string getAdaptorName() const; + /// A class used to represent the decorators of an operator variable, i.e. /// argument or result. struct VariableDecorator { diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index cbe6da31add..8cc2315ddd1 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -795,8 +795,8 @@ Value ConvertToLLVMPattern::linearizeSubscripts( } Value ConvertToLLVMPattern::getStridedElementPtr( - Location loc, Type elementTypePtr, Value descriptor, - ArrayRef indices, ArrayRef strides, int64_t offset, + Location loc, Type elementTypePtr, Value descriptor, ValueRange indices, + ArrayRef strides, int64_t offset, ConversionPatternRewriter &rewriter) const { MemRefDescriptor memRefDescriptor(descriptor); @@ -818,8 +818,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr( } Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type, - Value memRefDesc, - ArrayRef indices, + Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter, llvm::Module &module) const { LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType(); @@ -2602,7 +2601,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern { // Build and return the value for the idx^th shape dimension, either by // returning the constant shape dimension or counting the proper dynamic size. Value getSize(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef shape, ArrayRef dynamicSizes, + ArrayRef shape, ValueRange dynamicSizes, unsigned idx) const { assert(idx < shape.size()); if (!ShapedType::isDynamic(shape[idx])) diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp index dfc2728ef71..6458756dec6 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -579,7 +579,7 @@ Value mlir::spirv::getBuiltinVariableValue(Operation *op, spirv::AccessChainOp mlir::spirv::getElementPtr( SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, - ArrayRef indices, Location loc, OpBuilder &builder) { + ValueRange indices, Location loc, OpBuilder &builder) { // Get base and offset of the MemRefType and verify they are static. int64_t offset; @@ -591,6 +591,7 @@ spirv::AccessChainOp mlir::spirv::getElementPtr( } auto indexType = typeConverter.getIndexType(builder.getContext()); + SmallVector linearizedIndices; // Add a '0' at the start to index into the struct. auto zero = spirv::ConstantOp::getZero(indexType, loc, builder); @@ -606,7 +607,7 @@ spirv::AccessChainOp mlir::spirv::getElementPtr( loc, indexType, IntegerAttr::get(indexType, offset)); assert(indices.size() == strides.size() && "must provide indices for all dimensions"); - for (auto index : enumerate(indices)) { + for (auto index : llvm::enumerate(indices)) { Value strideVal = builder.create( loc, indexType, IntegerAttr::get(indexType, strides[index.index()])); Value update = diff --git a/mlir/lib/TableGen/OpClass.cpp b/mlir/lib/TableGen/OpClass.cpp index bfdcbdc344a..43bbe2420a9 100644 --- a/mlir/lib/TableGen/OpClass.cpp +++ b/mlir/lib/TableGen/OpClass.cpp @@ -120,6 +120,27 @@ void tblgen::OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const { } //===----------------------------------------------------------------------===// +// OpConstructor definitions +//===----------------------------------------------------------------------===// + +void mlir::tblgen::OpConstructor::addMemberInitializer(StringRef name, + StringRef value) { + memberInitializers.append(std::string(llvm::formatv( + "{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value))); +} + +void mlir::tblgen::OpConstructor::writeDefTo(raw_ostream &os, + StringRef namePrefix) const { + if (isDeclOnly) + return; + + methodSignature.writeDefTo(os, namePrefix); + os << " " << memberInitializers << " {\n"; + methodBody.writeTo(os); + os << "}"; +} + +//===----------------------------------------------------------------------===// // Class definitions //===----------------------------------------------------------------------===// @@ -133,10 +154,11 @@ tblgen::OpMethod &tblgen::Class::newMethod(StringRef retType, StringRef name, return methods.back(); } -tblgen::OpMethod &tblgen::Class::newConstructor(StringRef params, - bool declOnly) { - return newMethod("", getClassName(), params, OpMethod::MP_Constructor, - declOnly); +tblgen::OpConstructor &tblgen::Class::newConstructor(StringRef params, + bool declOnly) { + constructors.emplace_back("", getClassName(), params, + OpMethod::MP_Constructor, declOnly); + return constructors.back(); } void tblgen::Class::newField(StringRef type, StringRef name, @@ -152,7 +174,8 @@ void tblgen::Class::writeDeclTo(raw_ostream &os) const { bool hasPrivateMethod = false; os << "class " << className << " {\n"; os << "public:\n"; - for (const auto &method : methods) { + for (const auto &method : + llvm::concat(constructors, methods)) { if (!method.isPrivate()) { method.writeDeclTo(os); os << '\n'; @@ -163,7 +186,8 @@ void tblgen::Class::writeDeclTo(raw_ostream &os) const { os << '\n'; os << "private:\n"; if (hasPrivateMethod) { - for (const auto &method : methods) { + for (const auto &method : + llvm::concat(constructors, methods)) { if (method.isPrivate()) { method.writeDeclTo(os); os << '\n'; @@ -177,7 +201,8 @@ void tblgen::Class::writeDeclTo(raw_ostream &os) const { } void tblgen::Class::writeDefTo(raw_ostream &os) const { - for (const auto &method : methods) { + for (const auto &method : + llvm::concat(constructors, methods)) { method.writeDefTo(os, className); os << "\n\n"; } diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 2f77184980e..f575fedc1f2 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -59,6 +59,10 @@ std::string tblgen::Operator::getOperationName() const { return std::string(llvm::formatv("{0}.{1}", prefix, opName)); } +std::string tblgen::Operator::getAdaptorName() const { + return std::string(llvm::formatv("{0}OperandAdaptor", getCppClassName())); +} + StringRef tblgen::Operator::getDialectName() const { return dialect.getName(); } StringRef tblgen::Operator::getCppClassName() const { return cppClassName; } diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td index 565f1921125..a101103b08f 100644 --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -49,14 +49,14 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> { // CHECK: class AOpOperandAdaptor { // CHECK: public: -// CHECK: AOpOperandAdaptor(ArrayRef values -// CHECK: ArrayRef getODSOperands(unsigned index); +// CHECK: AOpOperandAdaptor(ValueRange values +// CHECK: ValueRange getODSOperands(unsigned index); // CHECK: Value a(); -// CHECK: ArrayRef b(); +// CHECK: ValueRange b(); // CHECK: IntegerAttr attr1(); // CHECL: FloatAttr attr2(); // CHECK: private: -// CHECK: ArrayRef odsOperands; +// CHECK: ValueRange odsOperands; // CHECK: }; // CHECK: class AOp : public Op::Impl, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::IsIsolatedFromAbove @@ -106,12 +106,12 @@ def NS_AttrSizedOperandOp : NS_Op<"attr_sized_operands", } // CHECK-LABEL: AttrSizedOperandOpOperandAdaptor( -// CHECK-SAME: ArrayRef values +// CHECK-SAME: ValueRange values // CHECK-SAME: DictionaryAttr attrs -// CHECK: ArrayRef a(); -// CHECK: ArrayRef b(); +// CHECK: ValueRange a(); +// CHECK: ValueRange b(); // CHECK: Value c(); -// CHECK: ArrayRef d(); +// CHECK: ValueRange d(); // CHECK: DenseIntElementsAttr operand_segment_sizes(); // Check op trait for different number of operands diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td index 5f0bfae9281..a9b61c179be 100644 --- a/mlir/test/mlir-tblgen/op-operand.td +++ b/mlir/test/mlir-tblgen/op-operand.td @@ -15,7 +15,7 @@ def OpA : NS_Op<"one_normal_operand_op", []> { // CHECK-LABEL: OpA definitions // CHECK: OpAOperandAdaptor::OpAOperandAdaptor -// CHECK-NEXT: odsOperands = values +// CHECK-SAME: odsOperands(values), odsAttrs(attrs) // CHECK: void OpA::build // CHECK: Value input @@ -39,13 +39,13 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> let arguments = (ins Variadic:$input1, AnyTensor:$input2, Variadic:$input3); } -// CHECK-LABEL: ArrayRef OpDOperandAdaptor::input1 +// CHECK-LABEL: ValueRange OpDOperandAdaptor::input1 // CHECK-NEXT: return getODSOperands(0); // CHECK-LABEL: Value OpDOperandAdaptor::input2 // CHECK-NEXT: return *getODSOperands(1).begin(); -// CHECK-LABEL: ArrayRef OpDOperandAdaptor::input3 +// CHECK-LABEL: ValueRange OpDOperandAdaptor::input3 // CHECK-NEXT: return getODSOperands(2); // CHECK-LABEL: Operation::operand_range OpD::input1 diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 0b55825d1a4..7b0cd9d7a48 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1890,27 +1890,38 @@ public: private: explicit OpOperandAdaptorEmitter(const Operator &op); - Class adapterClass; + Class adaptor; }; } // end namespace OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) - : adapterClass(op.getCppClassName().str() + "OperandAdaptor") { - adapterClass.newField("ArrayRef", "odsOperands"); - adapterClass.newField("DictionaryAttr", "odsAttrs"); + : adaptor(op.getAdaptorName()) { + adaptor.newField("ValueRange", "odsOperands"); + adaptor.newField("DictionaryAttr", "odsAttrs"); const auto *attrSizedOperands = op.getTrait("OpTrait::AttrSizedOperandSegments"); - auto &constructor = adapterClass.newConstructor( - attrSizedOperands - ? "ArrayRef values, DictionaryAttr attrs" - : "ArrayRef values, DictionaryAttr attrs = nullptr"); - constructor.body() << " odsOperands = values;\n"; - constructor.body() << " odsAttrs = attrs;\n"; + { + auto &constructor = adaptor.newConstructor( + attrSizedOperands + ? "ValueRange values, DictionaryAttr attrs" + : "ValueRange values, DictionaryAttr attrs = nullptr"); + constructor.addMemberInitializer("odsOperands", "values"); + constructor.addMemberInitializer("odsAttrs", "attrs"); + } + + { + auto &constructor = adaptor.newConstructor( + llvm::formatv("{0}& op", op.getCppClassName()).str()); + constructor.addMemberInitializer("odsOperands", + "op.getOperation()->getOperands()"); + constructor.addMemberInitializer("odsAttrs", + "op.getOperation()->getAttrDictionary()"); + } std::string sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes"); - generateNamedOperandGetters(op, adapterClass, sizeAttrInit, - /*rangeType=*/"ArrayRef", + generateNamedOperandGetters(op, adaptor, sizeAttrInit, + /*rangeType=*/"ValueRange", /*rangeBeginCall=*/"odsOperands.begin()", /*rangeSizeCall=*/"odsOperands.size()", /*getOperandCallPattern=*/"odsOperands[{0}]"); @@ -1919,7 +1930,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) fctx.withBuilder("mlir::Builder(odsAttrs.getContext())"); auto emitAttr = [&](StringRef name, Attribute attr) { - auto &body = adapterClass.newMethod(attr.getStorageType(), name).body(); + auto &body = adaptor.newMethod(attr.getStorageType(), name).body(); body << " assert(odsAttrs && \"no attributes when constructing adapter\");" << "\n " << attr.getStorageType() << " attr = " << "odsAttrs.get(\"" << name << "\")."; @@ -1949,11 +1960,11 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) } void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) { - OpOperandAdaptorEmitter(op).adapterClass.writeDeclTo(os); + OpOperandAdaptorEmitter(op).adaptor.writeDeclTo(os); } void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) { - OpOperandAdaptorEmitter(op).adapterClass.writeDefTo(os); + OpOperandAdaptorEmitter(op).adaptor.writeDefTo(os); } // Emits the opcode enum and op classes. -- 2.11.0