From fdaa391e3df3c3a555d933122b0ef85eaf5eb63c Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Thu, 28 May 2020 14:04:39 +0000 Subject: [PATCH] [MLIR] Add `num_elements` to the shape dialect The operation `num_elements` determines the number of elements for a given shape. That is the product of its dimensions. Differential Revision: https://reviews.llvm.org/D80281 --- mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td | 18 ++++++++++++++++++ mlir/lib/Dialect/Shape/IR/Shape.cpp | 26 ++++++++++++++++++++++++++ mlir/test/Dialect/Shape/canonicalize.mlir | 22 ++++++++++++++++++++++ 3 files changed, 66 insertions(+) diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 57d1954a319..0d300d3c64c 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -278,6 +278,24 @@ def Shape_MulOp : Shape_Op<"mul", [SameOperandsAndResultType]> { let results = (outs Shape_SizeType:$result); } +def Shape_NumElementsOp : Shape_Op<"num_elements", [ + NoSideEffect, + DeclareOpInterfaceMethods]> { + + let summary = "Returns the number of elements for a given shape"; + let description = [{ + Returns the number of elements for a given shape which is the product of its + dimensions. + }]; + + let arguments = (ins Shape_ShapeType:$shape); + let results = (outs Shape_SizeType:$result); + + let assemblyFormat = "attr-dict $shape"; + + let hasFolder = 1; +} + def Shape_ReduceOp : Shape_Op<"reduce", []> { let summary = "Returns an expression reduced over a shape"; let description = [{ diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index a077948fdd3..b0103e15fa3 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -333,6 +333,32 @@ OpFoldResult GetExtentOp::fold(ArrayRef operands) { } //===----------------------------------------------------------------------===// +// NumElementsOp +//===----------------------------------------------------------------------===// + +OpFoldResult NumElementsOp::fold(ArrayRef operands) { + + // Fold only when argument constant. + Attribute shape = operands[0]; + if (!shape) + return {}; + + APInt product(64, 1); + for (auto value : shape.cast()) + product *= value; + Builder builder(getContext()); + return builder.getIndexAttr(product.getLimitedValue()); +} + +LogicalResult NumElementsOp::inferReturnTypes( + MLIRContext *context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.push_back(SizeType::get(context)); + return success(); +} + +//===----------------------------------------------------------------------===// // ShapeOfOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index 106171de608..69c312e6dad 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -161,6 +161,28 @@ func @nonfoldable_index_to_size(%ci : index) -> !shape.size { } // ----- +// Fold number of elements computation. +// CHECK-LABEL: func @num_elements +func @num_elements() -> !shape.size { + // CHECK-NOT: shape.const_shape + %shape = shape.const_shape [4, 5, 6] + // CHECK-NOT: shape.num_elements + %num_elements = shape.num_elements %shape + // CHECK: %[[NUM:.*]] = shape.const_size 120 + // CHECK-NEXT: return %[[NUM]] : !shape.size + return %num_elements : !shape.size +} + +// ----- +// No folding. +// CHECK-LABEL: func @nonfoldable_num_elements +func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size { + // CHECK-NOT: shape.const_{{.*}} + %num_elements = shape.num_elements %shape + return %num_elements : !shape.size +} + +// ----- // Canonicalization of shape.get_extent -- 2.11.0