From: Sean Silva Date: Tue, 26 May 2020 23:45:32 +0000 (-0700) Subject: [mlir][shape] Use IndexElementsAttr in Shape dialect. X-Git-Url: http://git.osdn.net/view?a=commitdiff_plain;h=25132b36a8b39e7c2b0b28aa73772e57191b6df4;p=android-x86%2Fexternal-llvm-project.git [mlir][shape] Use IndexElementsAttr in Shape dialect. Summary: Index is the proper type for storing shapes when constant folding, so this fixes the previous code (which was using i64). Differential Revision: https://reviews.llvm.org/D80600 --- diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index a9759fc6a73..406aac2db99 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -102,7 +102,7 @@ def Shape_ConstShapeOp : Shape_Op<"const_shape", [ConstantLike, NoSideEffect]> { %1 = shape.const_shape [1, 2, 3] ``` }]; - let arguments = (ins I64ElementsAttr:$shape); + let arguments = (ins IndexElementsAttr:$shape); let results = (outs Shape_ShapeType:$result); // TODO: Move this to main so that all shape ops implement these. @@ -206,13 +206,8 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", let builders = [ // Builder that allows passing a simple integer instead of an IntegerAttr. OpBuilder< - [{ - OpBuilder &builder, OperationState &result, - Value shape, int64_t dim - }], - [{ - build(builder, result, shape, builder.getI64IntegerAttr(dim)); - }] + [{OpBuilder &builder, OperationState &result, Value shape, int64_t dim}], + [{build(builder, result, shape, builder.getI64IntegerAttr(dim));}] > ]; diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index fa9552fc869..c4a8b152981 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -177,7 +177,7 @@ OpFoldResult BroadcastOp::fold(ArrayRef operands) { if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape)) return nullptr; Builder builder(getContext()); - return builder.getI64TensorAttr(resultShape); + return builder.getIndexTensorAttr(resultShape); } //===----------------------------------------------------------------------===// @@ -215,7 +215,7 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser, ints.push_back(attr.getInt()); } Builder &builder = parser.getBuilder(); - result.addAttribute("shape", builder.getI64TensorAttr(ints)); + result.addAttribute("shape", builder.getIndexTensorAttr(ints)); result.types.push_back(ShapeType::get(builder.getContext())); return success(); @@ -257,7 +257,7 @@ OpFoldResult FromExtentsOp::fold(ArrayRef operands) { for (auto attr : operands) extents.push_back(attr.cast().getInt()); Builder builder(getContext()); - return builder.getI64TensorAttr(extents); + return builder.getIndexTensorAttr(extents); } //===----------------------------------------------------------------------===// @@ -281,14 +281,7 @@ OpFoldResult GetExtentOp::fold(ArrayRef operands) { // TODO: Constant fold this to some kind of constant error. if (dimToGet >= (uint64_t)elements.getNumElements()) return nullptr; - // This is a little inconvenient because getValue returns an IntegerAttr - // that is not of IndexType, but the result here needs to be of - // IndexType. - // TODO: Make ConstShapeOp hold an tensor of index instead of i64. - Builder builder(getContext()); - return builder.getIntegerAttr( - builder.getIndexType(), - elements.getValue({dimToGet}).getInt()); + return elements.getValue({dimToGet}); } //===----------------------------------------------------------------------===// @@ -309,7 +302,7 @@ OpFoldResult ShapeOfOp::fold(ArrayRef) { if (!type || !type.hasStaticShape()) return nullptr; Builder builder(getContext()); - return builder.getI64TensorAttr(type.getShape()); + return builder.getIndexTensorAttr(type.getShape()); } //===----------------------------------------------------------------------===// @@ -343,8 +336,8 @@ LogicalResult SplitAtOp::fold(ArrayRef operands, if (splitPoint < 0) splitPoint += shape.size(); Builder builder(operands[0].getContext()); - results.push_back(builder.getI64TensorAttr(shape.take_front(splitPoint))); - results.push_back(builder.getI64TensorAttr(shape.drop_front(splitPoint))); + results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint))); + results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint))); return success(); } @@ -373,7 +366,7 @@ OpFoldResult ConcatOp::fold(ArrayRef operands) { resultShape.append(lhsShape.begin(), lhsShape.end()); resultShape.append(rhsShape.begin(), rhsShape.end()); Builder builder(getContext()); - return builder.getI64TensorAttr(resultShape); + return builder.getIndexTensorAttr(resultShape); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index 018f5b212b4..23147e557a1 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -15,7 +15,7 @@ func @f() -> (!shape.shape, !shape.shape) { // CHECK: shape.const_shape [2, 3] // CHECK: shape.const_shape [4, 5] %c2 = constant 2 : i32 - %0 = "shape.const_shape"() {shape = dense<[2, 3, 4, 5]> : tensor<4xi64>} : () -> !shape.shape + %0 = shape.const_shape [2, 3, 4, 5] %head, %tail = "shape.split_at"(%0, %c2) : (!shape.shape, i32) -> (!shape.shape, !shape.shape) return %head, %tail : !shape.shape, !shape.shape @@ -28,7 +28,7 @@ func @f() -> (!shape.shape, !shape.shape) { // CHECK: shape.const_shape [2, 3, 4] // CHECK: shape.const_shape [5] %c-1 = constant -1 : i32 - %0 = "shape.const_shape"() {shape = dense<[2, 3, 4, 5]> : tensor<4xi64>} : () -> !shape.shape + %0 = shape.const_shape [2, 3, 4, 5] %head, %tail = "shape.split_at"(%0, %c-1) : (!shape.shape, i32) -> (!shape.shape, !shape.shape) return %head, %tail : !shape.shape, !shape.shape } @@ -39,7 +39,7 @@ func @f() -> (!shape.shape, !shape.shape) { func @f() -> (!shape.shape, !shape.shape) { // CHECK: shape.split_at %c5 = constant 5 : i32 - %0 = "shape.const_shape"() {shape = dense<[2, 3, 4, 5]> : tensor<4xi64>} : () -> !shape.shape + %0 = shape.const_shape [2, 3, 4, 5] %head, %tail = "shape.split_at"(%0, %c5) : (!shape.shape, i32) -> (!shape.shape, !shape.shape) return %head, %tail : !shape.shape, !shape.shape }