From: Stella Laurenzo Date: Mon, 18 Jan 2021 19:27:19 +0000 (-0800) Subject: [mlir][python] Factor out standalone OpView._ods_build_default class method. X-Git-Url: http://git.osdn.net/view?a=commitdiff_plain;h=71b6b010e6bc49caaec511195e33ac1f43f07c64;p=android-x86%2Fexternal-llvm-project.git [mlir][python] Factor out standalone OpView._ods_build_default class method. * This allows us to hoist trait level information for regions and sized-variadic to class level attributes (_ODS_REGIONS, _ODS_OPERAND_SEGMENTS, _ODS_RESULT_SEGMENTS). * Eliminates some splicey python generated code in favor of a native helper for it. * Makes it possible to implement custom, variadic and region based builders with one line of python, without needing to manually code access to the segment attributes. * Needs follow-on work for region based callbacks and support for SingleBlockImplicitTerminator. * A follow-up will actually add ODS support for generating custom Python builders that delegate to this new method. * Also includes the start of an e2e sample for constructing linalg ops where this limitation was discovered (working progressively through this example and cleaning up as I go). Differential Revision: https://reviews.llvm.org/D94738 --- diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md index b5595bc7010..6bb9e7ebe2f 100644 --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -365,7 +365,7 @@ for the canonical way to use this facility. Each dialect with a mapping to python requires that an appropriate `{DIALECT_NAMESPACE}.py` wrapper module is created. This is done by invoking -`mlir-tablegen` on a python-bindings specific tablegen wrapper that includes +`mlir-tblgen` on a python-bindings specific tablegen wrapper that includes the boilerplate and actual dialect specific `td` file. An example, for the `StandardOps` (which is assigned the namespace `std` as a special case): @@ -383,7 +383,7 @@ In the main repository, building the wrapper is done via the CMake function `add_mlir_dialect_python_bindings`, which invokes: ``` -mlir-tablegen -gen-python-op-bindings -bind-dialect={DIALECT_NAMESPACE} \ +mlir-tblgen -gen-python-op-bindings -bind-dialect={DIALECT_NAMESPACE} \ {PYTHON_BINDING_TD_FILE} ``` @@ -411,7 +411,8 @@ The wrapper module tablegen emitter outputs: Note: In order to avoid naming conflicts, all internal names used by the wrapper module are prefixed by `_ods_`. -Each concrete `OpView` subclass further defines several attributes: +Each concrete `OpView` subclass further defines several public-intended +attributes: * `OPERATION_NAME` attribute with the `str` fully qualified operation name (i.e. `std.absf`). @@ -421,6 +422,20 @@ Each concrete `OpView` subclass further defines several attributes: for unnamed of each). * `@property` getter, setter and deleter for each declared attribute. +It further emits additional private-intended attributes meant for subclassing +and customization (default cases omit these attributes in favor of the +defaults on `OpView`): + +* `_ODS_REGIONS`: A specification on the number and types of regions. + Currently a tuple of (min_region_count, has_no_variadic_regions). Note that + the API does some light validation on this but the primary purpose is to + capture sufficient information to perform other default building and region + accessor generation. +* `_ODS_OPERAND_SEGMENTS` and `_ODS_RESULT_SEGMENTS`: Black-box value which + indicates the structure of either the operand or results with respect to + variadics. Used by `OpView._ods_build_default` to decode operand and result + lists that contain lists. + #### Builders Presently, only a single, default builder is mapped to the `__init__` method. diff --git a/mlir/examples/python/linalg_matmul.py b/mlir/examples/python/linalg_matmul.py new file mode 100644 index 00000000000..83dc15eda9b --- /dev/null +++ b/mlir/examples/python/linalg_matmul.py @@ -0,0 +1,73 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# This is a work in progress example to do end2end build and code generation +# of a small linalg program with configuration options. It is currently non +# functional and is being used to elaborate the APIs. + +from typing import Tuple + +from mlir.ir import * +from mlir.dialects import linalg +from mlir.dialects import std + + +# TODO: This should be in the core API. +def FuncOp(name: str, func_type: Type) -> Tuple[Operation, Block]: + """Creates a |func| op. + TODO: This should really be in the MLIR API. + Returns: + (operation, entry_block) + """ + attrs = { + "type": TypeAttr.get(func_type), + "sym_name": StringAttr.get(name), + } + op = Operation.create("func", regions=1, attributes=attrs) + body_region = op.regions[0] + entry_block = body_region.blocks.append(*func_type.inputs) + return op, entry_block + + +# TODO: Generate customs builder vs patching one in. +def PatchMatmulOpInit(self, lhs, rhs, result, loc=None, ip=None): + super(linalg.MatmulOp, self).__init__( + self._ods_build_default(operands=[[lhs, rhs], [result]], + results=[], + loc=loc, + ip=ip)) + # TODO: Implement support for SingleBlockImplicitTerminator + block = self.regions[0].blocks.append() + with InsertionPoint(block): + linalg.YieldOp(values=[]) + +linalg.MatmulOp.__init__ = PatchMatmulOpInit + + +def build_matmul_func(func_name, m, k, n, dtype): + lhs_type = MemRefType.get(dtype, [m, k]) + rhs_type = MemRefType.get(dtype, [k, n]) + result_type = MemRefType.get(dtype, [m, n]) + # TODO: There should be a one-liner for this. + func_type = FunctionType.get([lhs_type, rhs_type, result_type], []) + _, entry = FuncOp(func_name, func_type) + lhs, rhs, result = entry.arguments + with InsertionPoint(entry): + linalg.MatmulOp(lhs, rhs, result) + std.ReturnOp([]) + + +def run(): + with Context() as c, Location.unknown(): + module = Module.create() + # TODO: This at_block_terminator vs default construct distinction feels + # wrong and is error-prone. + with InsertionPoint.at_block_terminator(module.body): + build_matmul_func('main', 18, 32, 96, F32Type.get()) + + print(module) + print(module.operation.get_asm(print_generic_op_form=True)) + + +if __name__ == '__main__': run() diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp index 493ea5c1e47..63bdd0c7a18 100644 --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -130,6 +130,13 @@ equivalent to printing the operation that produced it. // Utilities. //------------------------------------------------------------------------------ +// Helper for creating an @classmethod. +template +py::object classmethod(Func f, Args... args) { + py::object cf = py::cpp_function(f, args...); + return py::reinterpret_borrow((PyClassMethod_New(cf.ptr()))); +} + /// Checks whether the given type is an integer or float type. static int mlirTypeIsAIntegerOrFloat(MlirType type) { return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) || @@ -1027,6 +1034,267 @@ py::object PyOperation::createOpView() { return py::cast(PyOpView(getRef().getObject())); } +//------------------------------------------------------------------------------ +// PyOpView +//------------------------------------------------------------------------------ + +py::object +PyOpView::odsBuildDefault(py::object cls, py::list operandList, + py::list resultTypeList, + llvm::Optional attributes, + llvm::Optional> successors, + llvm::Optional regions, + DefaultingPyLocation location, py::object maybeIp) { + PyMlirContextRef context = location->getContext(); + // Class level operation construction metadata. + std::string name = py::cast(cls.attr("OPERATION_NAME")); + // Operand and result segment specs are either none, which does no + // variadic unpacking, or a list of ints with segment sizes, where each + // element is either a positive number (typically 1 for a scalar) or -1 to + // indicate that it is derived from the length of the same-indexed operand + // or result (implying that it is a list at that position). + py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); + py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); + + std::vector operandSegmentLengths; + std::vector resultSegmentLengths; + + // Validate/determine region count. + auto opRegionSpec = py::cast>(cls.attr("_ODS_REGIONS")); + int opMinRegionCount = std::get<0>(opRegionSpec); + bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); + if (!regions) { + regions = opMinRegionCount; + } + if (*regions < opMinRegionCount) { + throw py::value_error( + (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + + llvm::Twine(opMinRegionCount) + + " regions but was built with regions=" + llvm::Twine(*regions)) + .str()); + } + if (opHasNoVariadicRegions && *regions > opMinRegionCount) { + throw py::value_error( + (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + + llvm::Twine(opMinRegionCount) + + " regions but was built with regions=" + llvm::Twine(*regions)) + .str()); + } + + // Unpack results. + std::vector resultTypes; + resultTypes.reserve(resultTypeList.size()); + if (resultSegmentSpecObj.is_none()) { + // Non-variadic result unpacking. + for (auto it : llvm::enumerate(resultTypeList)) { + try { + resultTypes.push_back(py::cast(it.value())); + if (!resultTypes.back()) + throw py::cast_error(); + } catch (py::cast_error &err) { + throw py::value_error((llvm::Twine("Result ") + + llvm::Twine(it.index()) + " of operation \"" + + name + "\" must be a Type (" + err.what() + ")") + .str()); + } + } + } else { + // Sized result unpacking. + auto resultSegmentSpec = py::cast>(resultSegmentSpecObj); + if (resultSegmentSpec.size() != resultTypeList.size()) { + throw py::value_error((llvm::Twine("Operation \"") + name + + "\" requires " + + llvm::Twine(resultSegmentSpec.size()) + + "result segments but was provided " + + llvm::Twine(resultTypeList.size())) + .str()); + } + resultSegmentLengths.reserve(resultTypeList.size()); + for (auto it : + llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { + int segmentSpec = std::get<1>(it.value()); + if (segmentSpec == 1 || segmentSpec == 0) { + // Unpack unary element. + try { + auto resultType = py::cast(std::get<0>(it.value())); + if (resultType) { + resultTypes.push_back(resultType); + resultSegmentLengths.push_back(1); + } else if (segmentSpec == 0) { + // Allowed to be optional. + resultSegmentLengths.push_back(0); + } else { + throw py::cast_error("was None and result is not optional"); + } + } catch (py::cast_error &err) { + throw py::value_error((llvm::Twine("Result ") + + llvm::Twine(it.index()) + " of operation \"" + + name + "\" must be a Type (" + err.what() + + ")") + .str()); + } + } else if (segmentSpec == -1) { + // Unpack sequence by appending. + try { + if (std::get<0>(it.value()).is_none()) { + // Treat it as an empty list. + resultSegmentLengths.push_back(0); + } else { + // Unpack the list. + auto segment = py::cast(std::get<0>(it.value())); + for (py::object segmentItem : segment) { + resultTypes.push_back(py::cast(segmentItem)); + if (!resultTypes.back()) { + throw py::cast_error("contained a None item"); + } + } + resultSegmentLengths.push_back(segment.size()); + } + } catch (std::exception &err) { + // NOTE: Sloppy to be using a catch-all here, but there are at least + // three different unrelated exceptions that can be thrown in the + // above "casts". Just keep the scope above small and catch them all. + throw py::value_error((llvm::Twine("Result ") + + llvm::Twine(it.index()) + " of operation \"" + + name + "\" must be a Sequence of Types (" + + err.what() + ")") + .str()); + } + } else { + throw py::value_error("Unexpected segment spec"); + } + } + } + + // Unpack operands. + std::vector operands; + operands.reserve(operands.size()); + if (operandSegmentSpecObj.is_none()) { + // Non-sized operand unpacking. + for (auto it : llvm::enumerate(operandList)) { + try { + operands.push_back(py::cast(it.value())); + if (!operands.back()) + throw py::cast_error(); + } catch (py::cast_error &err) { + throw py::value_error((llvm::Twine("Operand ") + + llvm::Twine(it.index()) + " of operation \"" + + name + "\" must be a Value (" + err.what() + ")") + .str()); + } + } + } else { + // Sized operand unpacking. + auto operandSegmentSpec = py::cast>(operandSegmentSpecObj); + if (operandSegmentSpec.size() != operandList.size()) { + throw py::value_error((llvm::Twine("Operation \"") + name + + "\" requires " + + llvm::Twine(operandSegmentSpec.size()) + + "operand segments but was provided " + + llvm::Twine(operandList.size())) + .str()); + } + operandSegmentLengths.reserve(operandList.size()); + for (auto it : + llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { + int segmentSpec = std::get<1>(it.value()); + if (segmentSpec == 1 || segmentSpec == 0) { + // Unpack unary element. + try { + auto operandValue = py::cast(std::get<0>(it.value())); + if (operandValue) { + operands.push_back(operandValue); + operandSegmentLengths.push_back(1); + } else if (segmentSpec == 0) { + // Allowed to be optional. + operandSegmentLengths.push_back(0); + } else { + throw py::cast_error("was None and operand is not optional"); + } + } catch (py::cast_error &err) { + throw py::value_error((llvm::Twine("Operand ") + + llvm::Twine(it.index()) + " of operation \"" + + name + "\" must be a Value (" + err.what() + + ")") + .str()); + } + } else if (segmentSpec == -1) { + // Unpack sequence by appending. + try { + if (std::get<0>(it.value()).is_none()) { + // Treat it as an empty list. + operandSegmentLengths.push_back(0); + } else { + // Unpack the list. + auto segment = py::cast(std::get<0>(it.value())); + for (py::object segmentItem : segment) { + operands.push_back(py::cast(segmentItem)); + if (!operands.back()) { + throw py::cast_error("contained a None item"); + } + } + operandSegmentLengths.push_back(segment.size()); + } + } catch (std::exception &err) { + // NOTE: Sloppy to be using a catch-all here, but there are at least + // three different unrelated exceptions that can be thrown in the + // above "casts". Just keep the scope above small and catch them all. + throw py::value_error((llvm::Twine("Operand ") + + llvm::Twine(it.index()) + " of operation \"" + + name + "\" must be a Sequence of Values (" + + err.what() + ")") + .str()); + } + } else { + throw py::value_error("Unexpected segment spec"); + } + } + } + + // Merge operand/result segment lengths into attributes if needed. + if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { + // Dup. + if (attributes) { + attributes = py::dict(*attributes); + } else { + attributes = py::dict(); + } + if (attributes->contains("result_segment_sizes") || + attributes->contains("operand_segment_sizes")) { + throw py::value_error("Manually setting a 'result_segment_sizes' or " + "'operand_segment_sizes' attribute is unsupported. " + "Use Operation.create for such low-level access."); + } + + // Add result_segment_sizes attribute. + if (!resultSegmentLengths.empty()) { + int64_t size = resultSegmentLengths.size(); + MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get( + mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)), + resultSegmentLengths.size(), resultSegmentLengths.data()); + (*attributes)["result_segment_sizes"] = + PyAttribute(context, segmentLengthAttr); + } + + // Add operand_segment_sizes attribute. + if (!operandSegmentLengths.empty()) { + int64_t size = operandSegmentLengths.size(); + MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get( + mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)), + operandSegmentLengths.size(), operandSegmentLengths.data()); + (*attributes)["operand_segment_sizes"] = + PyAttribute(context, segmentLengthAttr); + } + } + + // Delegate to create. + return PyOperation::create(std::move(name), /*operands=*/std::move(operands), + /*results=*/std::move(resultTypes), + /*attributes=*/std::move(attributes), + /*successors=*/std::move(successors), + /*regions=*/*regions, location, maybeIp); +} + PyOpView::PyOpView(py::object operationObject) // Casting through the PyOperationBase base-class and then back to the // Operation lets us accept any PyOperationBase subclass. @@ -3397,17 +3665,29 @@ void mlir::python::populateIRSubmodule(py::module &m) { "Context that owns the Operation") .def_property_readonly("opview", &PyOperation::createOpView); - py::class_(m, "OpView") - .def(py::init()) - .def_property_readonly("operation", &PyOpView::getOperationObject) - .def_property_readonly( - "context", - [](PyOpView &self) { - return self.getOperation().getContext().getObject(); - }, - "Context that owns the Operation") - .def("__str__", - [](PyOpView &self) { return py::str(self.getOperationObject()); }); + auto opViewClass = + py::class_(m, "OpView") + .def(py::init()) + .def_property_readonly("operation", &PyOpView::getOperationObject) + .def_property_readonly( + "context", + [](PyOpView &self) { + return self.getOperation().getContext().getObject(); + }, + "Context that owns the Operation") + .def("__str__", [](PyOpView &self) { + return py::str(self.getOperationObject()); + }); + opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); + opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); + opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); + opViewClass.attr("_ods_build_default") = classmethod( + &PyOpView::odsBuildDefault, py::arg("cls"), + py::arg("operands") = py::none(), py::arg("results") = py::none(), + py::arg("attributes") = py::none(), py::arg("successors") = py::none(), + py::arg("regions") = py::none(), py::arg("loc") = py::none(), + py::arg("ip") = py::none(), + "Builds a specific, generated OpView based on class level attributes."); //---------------------------------------------------------------------------- // Mapping of PyRegion. diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h index e789f536a82..443cdd69186 100644 --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -497,6 +497,14 @@ public: pybind11::object getOperationObject() { return operationObject; } + static pybind11::object + odsBuildDefault(pybind11::object cls, pybind11::list operandList, + pybind11::list resultTypeList, + llvm::Optional attributes, + llvm::Optional> successors, + llvm::Optional regions, DefaultingPyLocation location, + pybind11::object maybeIp); + private: PyOperation &operation; // For efficient, cast-free access from C++ pybind11::object operationObject; // Holds the reference. diff --git a/mlir/test/Bindings/Python/ods_helpers.py b/mlir/test/Bindings/Python/ods_helpers.py new file mode 100644 index 00000000000..1db1112c408 --- /dev/null +++ b/mlir/test/Bindings/Python/ods_helpers.py @@ -0,0 +1,210 @@ +# RUN: %PYTHON %s | FileCheck %s + +import gc +from mlir.ir import * + +def run(f): + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + + +def add_dummy_value(): + return Operation.create( + "custom.value", + results=[IntegerType.get_signless(32)]).result + + +def testOdsBuildDefaultImplicitRegions(): + + class TestFixedRegionsOp(OpView): + OPERATION_NAME = "custom.test_op" + _ODS_REGIONS = (2, True) + + class TestVariadicRegionsOp(OpView): + OPERATION_NAME = "custom.test_any_regions_op" + _ODS_REGIONS = (2, False) + + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + m = Module.create() + with InsertionPoint.at_block_terminator(m.body): + op = TestFixedRegionsOp._ods_build_default(operands=[], results=[]) + # CHECK: NUM_REGIONS: 2 + print(f"NUM_REGIONS: {len(op.regions)}") + # Including a regions= that matches should be fine. + op = TestFixedRegionsOp._ods_build_default(operands=[], results=[], regions=2) + print(f"NUM_REGIONS: {len(op.regions)}") + # Reject greater than. + try: + op = TestFixedRegionsOp._ods_build_default(operands=[], results=[], regions=3) + except ValueError as e: + # CHECK: ERROR:Operation "custom.test_op" requires a maximum of 2 regions but was built with regions=3 + print(f"ERROR:{e}") + # Reject less than. + try: + op = TestFixedRegionsOp._ods_build_default(operands=[], results=[], regions=1) + except ValueError as e: + # CHECK: ERROR:Operation "custom.test_op" requires a minimum of 2 regions but was built with regions=1 + print(f"ERROR:{e}") + + # If no regions specified for a variadic region op, build the minimum. + op = TestVariadicRegionsOp._ods_build_default(operands=[], results=[]) + # CHECK: DEFAULT_NUM_REGIONS: 2 + print(f"DEFAULT_NUM_REGIONS: {len(op.regions)}") + # Should also accept an explicit regions= that matches the minimum. + op = TestVariadicRegionsOp._ods_build_default( + operands=[], results=[], regions=2) + # CHECK: EQ_NUM_REGIONS: 2 + print(f"EQ_NUM_REGIONS: {len(op.regions)}") + # And accept greater than minimum. + # Should also accept an explicit regions= that matches the minimum. + op = TestVariadicRegionsOp._ods_build_default( + operands=[], results=[], regions=3) + # CHECK: GT_NUM_REGIONS: 3 + print(f"GT_NUM_REGIONS: {len(op.regions)}") + # Should reject less than minimum. + try: + op = TestVariadicRegionsOp._ods_build_default(operands=[], results=[], regions=1) + except ValueError as e: + # CHECK: ERROR:Operation "custom.test_any_regions_op" requires a minimum of 2 regions but was built with regions=1 + print(f"ERROR:{e}") + + + +run(testOdsBuildDefaultImplicitRegions) + + +def testOdsBuildDefaultNonVariadic(): + + class TestOp(OpView): + OPERATION_NAME = "custom.test_op" + + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + m = Module.create() + with InsertionPoint.at_block_terminator(m.body): + v0 = add_dummy_value() + v1 = add_dummy_value() + t0 = IntegerType.get_signless(8) + t1 = IntegerType.get_signless(16) + op = TestOp._ods_build_default(operands=[v0, v1], results=[t0, t1]) + # CHECK: %[[V0:.+]] = "custom.value" + # CHECK: %[[V1:.+]] = "custom.value" + # CHECK: "custom.test_op"(%[[V0]], %[[V1]]) + # CHECK-NOT: operand_segment_sizes + # CHECK-NOT: result_segment_sizes + # CHECK-SAME: : (i32, i32) -> (i8, i16) + print(m) + +run(testOdsBuildDefaultNonVariadic) + + +def testOdsBuildDefaultSizedVariadic(): + + class TestOp(OpView): + OPERATION_NAME = "custom.test_op" + _ODS_OPERAND_SEGMENTS = [1, -1, 0] + _ODS_RESULT_SEGMENTS = [-1, 0, 1] + + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + m = Module.create() + with InsertionPoint.at_block_terminator(m.body): + v0 = add_dummy_value() + v1 = add_dummy_value() + v2 = add_dummy_value() + v3 = add_dummy_value() + t0 = IntegerType.get_signless(8) + t1 = IntegerType.get_signless(16) + t2 = IntegerType.get_signless(32) + t3 = IntegerType.get_signless(64) + # CHECK: %[[V0:.+]] = "custom.value" + # CHECK: %[[V1:.+]] = "custom.value" + # CHECK: %[[V2:.+]] = "custom.value" + # CHECK: %[[V3:.+]] = "custom.value" + # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]]) + # CHECK-SAME: operand_segment_sizes = dense<[1, 2, 1]> : vector<3xi64> + # CHECK-SAME: result_segment_sizes = dense<[2, 1, 1]> : vector<3xi64> + # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64) + op = TestOp._ods_build_default( + operands=[v0, [v1, v2], v3], + results=[[t0, t1], t2, t3]) + + # Now test with optional omitted. + # CHECK: "custom.test_op"(%[[V0]]) + # CHECK-SAME: operand_segment_sizes = dense<[1, 0, 0]> + # CHECK-SAME: result_segment_sizes = dense<[0, 0, 1]> + # CHECK-SAME: (i32) -> i64 + op = TestOp._ods_build_default( + operands=[v0, None, None], + results=[None, None, t3]) + print(m) + + # And verify that errors are raised for None in a required operand. + try: + op = TestOp._ods_build_default( + operands=[None, None, None], + results=[None, None, t3]) + except ValueError as e: + # CHECK: OPERAND_CAST_ERROR:Operand 0 of operation "custom.test_op" must be a Value (was None and operand is not optional) + print(f"OPERAND_CAST_ERROR:{e}") + + # And verify that errors are raised for None in a required result. + try: + op = TestOp._ods_build_default( + operands=[v0, None, None], + results=[None, None, None]) + except ValueError as e: + # CHECK: RESULT_CAST_ERROR:Result 2 of operation "custom.test_op" must be a Type (was None and result is not optional) + print(f"RESULT_CAST_ERROR:{e}") + + # Variadic lists with None elements should reject. + try: + op = TestOp._ods_build_default( + operands=[v0, [None], None], + results=[None, None, t3]) + except ValueError as e: + # CHECK: OPERAND_LIST_CAST_ERROR:Operand 1 of operation "custom.test_op" must be a Sequence of Values (contained a None item) + print(f"OPERAND_LIST_CAST_ERROR:{e}") + try: + op = TestOp._ods_build_default( + operands=[v0, None, None], + results=[[None], None, t3]) + except ValueError as e: + # CHECK: RESULT_LIST_CAST_ERROR:Result 0 of operation "custom.test_op" must be a Sequence of Types (contained a None item) + print(f"RESULT_LIST_CAST_ERROR:{e}") + +run(testOdsBuildDefaultSizedVariadic) + + +def testOdsBuildDefaultCastError(): + + class TestOp(OpView): + OPERATION_NAME = "custom.test_op" + + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + m = Module.create() + with InsertionPoint.at_block_terminator(m.body): + v0 = add_dummy_value() + v1 = add_dummy_value() + t0 = IntegerType.get_signless(8) + t1 = IntegerType.get_signless(16) + try: + op = TestOp._ods_build_default( + operands=[None, v1], + results=[t0, t1]) + except ValueError as e: + # CHECK: ERROR: Operand 0 of operation "custom.test_op" must be a Value + print(f"ERROR: {e}") + try: + op = TestOp._ods_build_default( + operands=[v0, v1], + results=[t0, None]) + except ValueError as e: + # CHECK: Result 1 of operation "custom.test_op" must be a Type + print(f"ERROR: {e}") + +run(testOdsBuildDefaultCastError) diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index 722cf9fb7e4..235cb4a1fa5 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -17,23 +17,18 @@ class TestOp traits = []> : // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttrSizedOperandsOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.attr_sized_operands" +// CHECK: _ODS_OPERAND_SEGMENTS = [-1,1,-1,] def AttrSizedOperandsOp : TestOp<"attr_sized_operands", [AttrSizedOperandSegments]> { // CHECK: def __init__(self, variadic1, non_variadic, variadic2, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} - // CHECK: operand_segment_sizes_ods = _ods_array.array('L') - // CHECK: operands += [*variadic1] - // CHECK: operand_segment_sizes_ods.append(len(variadic1)) + // CHECK: operands.append(variadic1) // CHECK: operands.append(non_variadic) - // CHECK: operand_segment_sizes_ods.append(1) // CHECK: if variadic2 is not None: operands.append(variadic2) - // CHECK: operand_segment_sizes_ods.append(0 if variadic2 is None else 1) - // CHECK: attributes["operand_segment_sizes"] = _ods_ir.DenseElementsAttr.get(operand_segment_sizes_ods, - // CHECK: context=_ods_get_default_loc_context(loc)) - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.attr_sized_operands", attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -63,23 +58,18 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands", // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttrSizedResultsOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results" +// CHECK: _ODS_RESULT_SEGMENTS = [-1,1,-1,] def AttrSizedResultsOp : TestOp<"attr_sized_results", [AttrSizedResultSegments]> { // CHECK: def __init__(self, variadic1, non_variadic, variadic2, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} - // CHECK: result_segment_sizes_ods = _ods_array.array('L') // CHECK: if variadic1 is not None: results.append(variadic1) - // CHECK: result_segment_sizes_ods.append(0 if variadic1 is None else 1) // CHECK: results.append(non_variadic) - // CHECK: result_segment_sizes_ods.append(1) # non_variadic // CHECK: if variadic2 is not None: results.append(variadic2) - // CHECK: result_segment_sizes_ods.append(0 if variadic2 is None else 1) - // CHECK: attributes["result_segment_sizes"] = _ods_ir.DenseElementsAttr.get(result_segment_sizes_ods, - // CHECK: context=_ods_get_default_loc_context(loc)) - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.attr_sized_results", attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -110,6 +100,8 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results", // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttributedOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.attributed_op" +// CHECK-NOT: _ODS_OPERAND_SEGMENTS +// CHECK-NOT: _ODS_RESULT_SEGMENTS def AttributedOp : TestOp<"attributed_op"> { // CHECK: def __init__(self, i32attr, optionalF32Attr, unitAttr, in_, loc=None, ip=None): // CHECK: operands = [] @@ -120,8 +112,8 @@ def AttributedOp : TestOp<"attributed_op"> { // CHECK: if bool(unitAttr): attributes["unitAttr"] = _ods_ir.UnitAttr.get( // CHECK: _ods_get_default_loc_context(loc)) // CHECK: attributes["in"] = in_ - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.attributed_op", attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -148,6 +140,8 @@ def AttributedOp : TestOp<"attributed_op"> { // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttributedOpWithOperands(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.attributed_op_with_operands" +// CHECK-NOT: _ODS_OPERAND_SEGMENTS +// CHECK-NOT: _ODS_RESULT_SEGMENTS def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> { // CHECK: def __init__(self, _gen_arg_0, in_, _gen_arg_2, is_, loc=None, ip=None): // CHECK: operands = [] @@ -158,8 +152,8 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> { // CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get( // CHECK: _ods_get_default_loc_context(loc)) // CHECK: if is_ is not None: attributes["is"] = is_ - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.attributed_op_with_operands", attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -183,8 +177,8 @@ def EmptyOp : TestOp<"empty">; // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.empty", attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) @@ -201,8 +195,8 @@ def MissingNamesOp : TestOp<"missing_names"> { // CHECK: operands.append(_gen_arg_0) // CHECK: operands.append(f32) // CHECK: operands.append(_gen_arg_2) - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.missing_names", attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -223,15 +217,17 @@ def MissingNamesOp : TestOp<"missing_names"> { // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class OneVariadicOperandOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand" +// CHECK-NOT: _ODS_OPERAND_SEGMENTS +// CHECK-NOT: _ODS_RESULT_SEGMENTS def OneVariadicOperandOp : TestOp<"one_variadic_operand"> { // CHECK: def __init__(self, non_variadic, variadic, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} // CHECK: operands.append(non_variadic) - // CHECK: operands += [*variadic] - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.one_variadic_operand", attributes=attributes, operands=operands, results=results, + // CHECK: operands.extend(variadic) + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -248,15 +244,17 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> { // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class OneVariadicResultOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.one_variadic_result" +// CHECK-NOT: _ODS_OPERAND_SEGMENTS +// CHECK-NOT: _ODS_RESULT_SEGMENTS def OneVariadicResultOp : TestOp<"one_variadic_result"> { // CHECK: def __init__(self, variadic, non_variadic, loc=None, ip=None): // CHECK: operands = [] // CHECK: results = [] // CHECK: attributes = {} - // CHECK: results += [*variadic] + // CHECK: results.extend(variadic) // CHECK: results.append(non_variadic) - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.one_variadic_result", attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -280,8 +278,8 @@ def PythonKeywordOp : TestOp<"python_keyword"> { // CHECK: results = [] // CHECK: attributes = {} // CHECK: operands.append(in_) - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.python_keyword", attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @property @@ -348,8 +346,8 @@ def SimpleOp : TestOp<"simple"> { // CHECK: results.append(f64) // CHECK: operands.append(i32) // CHECK: operands.append(f32) - // CHECK: super().__init__(_ods_ir.Operation.create( - // CHECK: "test.simple", attributes=attributes, operands=operands, results=results, + // CHECK: super().__init__(self._ods_build_default( + // CHECK: attributes=attributes, operands=operands, results=results, // CHECK: loc=loc, ip=ip)) // CHECK: @property diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 16bf6d1dc03..658ad75eea2 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -26,7 +26,6 @@ using namespace mlir::tblgen; constexpr const char *fileHeader = R"Py( # Autogenerated by mlir-tblgen; don't manually edit. -import array as _ods_array from . import _cext as _ods_cext from . import _segmented_accessor as _ods_segmented_accessor, _equally_sized_accessor as _ods_equally_sized_accessor, _get_default_loc_context as _ods_get_default_loc_context _ods_ir = _ods_cext.ir @@ -51,6 +50,25 @@ class {0}(_ods_ir.OpView): OPERATION_NAME = "{1}" )Py"; +/// Template for class level declarations of operand and result +/// segment specs. +/// {0} is either "OPERAND" or "RESULT" +/// {1} is the segment spec +/// Each segment spec is either None (default) or an array of integers +/// where: +/// 1 = single element (expect non sequence operand/result) +/// -1 = operand/result is a sequence corresponding to a variadic +constexpr const char *opClassSizedSegmentsTemplate = R"Py( + _ODS_{0}_SEGMENTS = {1} +)Py"; + +/// Template for class level declarations of the _ODS_REGIONS spec: +/// {0} is the minimum number of regions +/// {1} is the Python bool literal for hasNoVariadicRegions +constexpr const char *opClassRegionSpecTemplate = R"Py( + _ODS_REGIONS = ({0}, {1}) +)Py"; + /// Template for single-element accessor: /// {0} is the name of the accessor; /// {1} is either 'operand' or 'result'; @@ -446,18 +464,17 @@ static void emitAttributeAccessors(const Operator &op, } /// Template for the default auto-generated builder. -/// {0} is the operation name; -/// {1} is a comma-separated list of builder arguments, including the trailing +/// {0} is a comma-separated list of builder arguments, including the trailing /// `loc` and `ip`; -/// {2} is the code populating `operands`, `results` and `attributes` fields. +/// {1} is the code populating `operands`, `results` and `attributes` fields. constexpr const char *initTemplate = R"Py( - def __init__(self, {1}): + def __init__(self, {0}): operands = [] results = [] attributes = {{} - {2} - super().__init__(_ods_ir.Operation.create( - "{0}", attributes=attributes, operands=operands, results=results, + {1} + super().__init__(self._ods_build_default( + attributes=attributes, operands=operands, results=results, loc=loc, ip=ip)) )Py"; @@ -472,37 +489,10 @@ constexpr const char *singleElementAppendTemplate = "{0}s.append({1})"; constexpr const char *optionalAppendTemplate = "if {1} is not None: {0}s.append({1})"; -/// Template for appending a variadic element to the operand/result list. -/// {0} is either 'operand' or 'result'; -/// {1} is the field name. -constexpr const char *variadicAppendTemplate = "{0}s += [*{1}]"; - -/// Template for setting up the segment sizes buffer. -constexpr const char *segmentDeclarationTemplate = - "{0}_segment_sizes_ods = _ods_array.array('L')"; - -/// Template for attaching segment sizes to the attribute list. -constexpr const char *segmentAttributeTemplate = - R"Py(attributes["{0}_segment_sizes"] = _ods_ir.DenseElementsAttr.get({0}_segment_sizes_ods, - context=_ods_get_default_loc_context(loc)))Py"; - -/// Template for appending the unit size to the segment sizes. +/// Template for appending a a list of elements to the operand/result list. /// {0} is either 'operand' or 'result'; /// {1} is the field name. -constexpr const char *singleElementSegmentTemplate = - "{0}_segment_sizes_ods.append(1) # {1}"; - -/// Template for appending 0/1 for an optional element to the segment sizes. -/// {0} is either 'operand' or 'result'; -/// {1} is the field name. -constexpr const char *optionalSegmentTemplate = - "{0}_segment_sizes_ods.append(0 if {1} is None else 1)"; - -/// Template for appending the length of a variadic group to the segment sizes. -/// {0} is either 'operand' or 'result'; -/// {1} is the field name. -constexpr const char *variadicSegmentTemplate = - "{0}_segment_sizes_ods.append(len({1}))"; +constexpr const char *multiElementAppendTemplate = "{0}s.extend({1})"; /// Template for setting an attribute in the operation builder. /// {0} is the attribute name; @@ -584,11 +574,7 @@ static void populateBuilderLines( llvm::function_ref getNumElements, llvm::function_ref getElement) { - // The segment sizes buffer only has to be populated if there attr-sized - // segments trait is present. - bool includeSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr; - if (includeSegments) - builderLines.push_back(llvm::formatv(segmentDeclarationTemplate, kind)); + bool sizedSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr; // For each element, find or generate a name. for (int i = 0, e = getNumElements(op); i < e; ++i) { @@ -596,28 +582,28 @@ static void populateBuilderLines( std::string name = names[i]; // Choose the formatting string based on the element kind. - llvm::StringRef formatString, segmentFormatString; + llvm::StringRef formatString; if (!element.isVariableLength()) { formatString = singleElementAppendTemplate; - segmentFormatString = singleElementSegmentTemplate; } else if (element.isOptional()) { formatString = optionalAppendTemplate; - segmentFormatString = optionalSegmentTemplate; } else { assert(element.isVariadic() && "unhandled element group type"); - formatString = variadicAppendTemplate; - segmentFormatString = variadicSegmentTemplate; + // If emitting with sizedSegments, then we add the actual list typed + // element using the singleElementAppendTemplate. Otherwise, we extend + // the actual operands. + if (sizedSegments) { + // Append the list as is. + formatString = singleElementAppendTemplate; + } else { + // Append the list elements. + formatString = multiElementAppendTemplate; + } } // Add the lines. builderLines.push_back(llvm::formatv(formatString.data(), kind, name)); - if (includeSegments) - builderLines.push_back( - llvm::formatv(segmentFormatString.data(), kind, name)); } - - if (includeSegments) - builderLines.push_back(llvm::formatv(segmentAttributeTemplate, kind)); } /// Emits a default builder constructing an operation from the list of its @@ -645,8 +631,7 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { builderArgs.push_back("loc=None"); builderArgs.push_back("ip=None"); - os << llvm::formatv(initTemplate, op.getOperationName(), - llvm::join(builderArgs, ", "), + os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "), llvm::join(builderLines, "\n ")); } @@ -659,12 +644,52 @@ static void constructAttributeMapping(const llvm::RecordKeeper &records, } } +static void emitSegmentSpec( + const Operator &op, const char *kind, + llvm::function_ref getNumElements, + llvm::function_ref + getElement, + raw_ostream &os) { + std::string segmentSpec("["); + for (int i = 0, e = getNumElements(op); i < e; ++i) { + const NamedTypeConstraint &element = getElement(op, i); + if (element.isVariableLength()) { + segmentSpec.append("-1,"); + } else if (element.isOptional()) { + segmentSpec.append("0,"); + } else { + segmentSpec.append("1,"); + } + } + segmentSpec.append("]"); + + os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec); +} + +static void emitRegionAttributes(const Operator &op, raw_ostream &os) { + // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions). + // Note that the base OpView class defines this as (0, True). + unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions(); + os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount, + op.hasNoVariadicRegions() ? "True" : "False"); +} + /// Emits bindings for a specific Op to the given output stream. static void emitOpBindings(const Operator &op, const AttributeClasses &attributeClasses, raw_ostream &os) { os << llvm::formatv(opClassTemplate, op.getCppClassName(), op.getOperationName()); + + // Sized segments. + if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) { + emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os); + } + if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) { + emitSegmentSpec(op, "RESULT", getNumResults, getResult, os); + } + + emitRegionAttributes(op, os); emitDefaultOpBuilder(op, os); emitOperandAccessors(op, os); emitAttributeAccessors(op, attributeClasses, os);