OSDN Git Service

[MLIR] Add `index_to_size` and `size_to_index` to the shape dialect
authorFrederik Gossen <frgossen@google.com>
Thu, 28 May 2020 13:55:02 +0000 (13:55 +0000)
committerFrederik Gossen <frgossen@google.com>
Thu, 28 May 2020 13:57:20 +0000 (13:57 +0000)
Add the two conversion operations `index_to_size` and `size_to_index` to the
shape dialect.
This facilitates the conversion of index types between the shape and the
standard dialect.

Differential Revision: https://reviews.llvm.org/D80280

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir

index dddc4c3..57d1954 100644 (file)
@@ -214,6 +214,25 @@ def Shape_GetExtentOp : Shape_Op<"get_extent",
   let hasFolder = 1;
 }
 
+def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [
+    NoSideEffect,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Converts a standard index to a shape size";
+  let description = [{
+    Converts a standard index to a `shape.size`.
+    This operation and its inverse, `size_to_index`, facilitate index conversion
+    between the standard and the shape dialect.
+    The behavior is undefined for negative indices.
+  }];
+
+  let arguments = (ins Index:$arg);
+  let results = (outs Shape_SizeType:$result);
+
+  let assemblyFormat = "attr-dict $arg";
+
+  let hasFolder = 1;
+}
+
 def Shape_JoinOp : Shape_Op<"join", []> {
   let summary = "Returns the least general shape.size of its operands";
   let description = [{
@@ -312,6 +331,25 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of",
   let hasFolder = 1;
 }
 
+def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [
+    NoSideEffect,
+    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Casts between index types of the shape and standard dialect";
+  let description = [{
+    Converts a `shape.size` to a standard index.
+    This operation and its inverse, `index_to_size`, facilitate index conversion
+    between the standard and the shape dialect.
+    The behavior is undefined for unknown and invalid arguments.
+  }];
+
+  let arguments = (ins Shape_SizeType:$arg);
+  let results = (outs Index:$result);
+
+  let assemblyFormat = "attr-dict $arg";
+
+  let hasFolder = 1;
+}
+
 def Shape_YieldOp : Shape_Op<"yield", [NoSideEffect, Terminator]> {
   let summary = "Returns the value to parent op";
 
@@ -523,7 +561,6 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
   let assemblyFormat = "$inputs attr-dict";
 }
 
-
 // Canonicalization patterns.
 
 #endif // SHAPE_OPS
index fc8f9b2..a077948 100644 (file)
@@ -249,7 +249,7 @@ static ParseResult parseConstShapeOp(OpAsmParser &parser,
   return success();
 }
 
-OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shape(); }
+OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
 
 //===----------------------------------------------------------------------===//
 // ConstSizeOp
@@ -267,6 +267,26 @@ ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
 OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
 
 //===----------------------------------------------------------------------===//
+// IndexToSizeOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult IndexToSizeOp::fold(ArrayRef<Attribute> operands) {
+  // Constant values of both types, `shape.size` and `index`, are represented as
+  // `IntegerAttr`s which makes constant folding simple.
+  if (Attribute arg = operands[0])
+    return arg;
+  return {};
+}
+
+LogicalResult IndexToSizeOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(SizeType::get(context));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // FromExtentsOp
 //===----------------------------------------------------------------------===//
 
@@ -334,6 +354,26 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
 }
 
 //===----------------------------------------------------------------------===//
+// SizeToIndexOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
+  // Constant values of both types, `shape.size` and `index`, are represented as
+  // `IntegerAttr`s which makes constant folding simple.
+  if (Attribute arg = operands[0])
+    return arg;
+  return {};
+}
+
+LogicalResult SizeToIndexOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(IndexType::get(context));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // SplitAtOp
 //===----------------------------------------------------------------------===//
 
index 23147e5..106171d 100644 (file)
@@ -108,6 +108,60 @@ func @no_fold(%arg0: index) -> !shape.shape {
 }
 
 // -----
+// Cast constant size to index and fold it away.
+// CHECK-LABEL: func @const_size_to_index
+func @const_size_to_index() -> index {
+  // CHECK-NOT: shape.index_cast
+  %cs = shape.const_size 123
+  // CHECK: constant 123 : index
+  %ci = shape.size_to_index %cs
+  return %ci : index
+}
+
+// -----
+// Cast constant index to size and fold it away.
+// CHECK-LABEL: func @const_index_to_size
+func @const_index_to_size() -> !shape.size {
+  // CHECK-NOT: index_cast
+  %ci = constant 123 : index
+  // CHECK: shape.const_size 123
+  %cs = shape.index_to_size %ci
+  return %cs : !shape.size
+}
+
+// -----
+// Cast constant index to size, then back, and fold it away.
+// CHECK-LABEL: func @const_index_to_size_to_index
+func @const_index_to_size_to_index() -> index {
+  // CHECK-NOT: shape.index_cast
+  %ci0 = constant 123 : index
+  %cs0 = shape.index_to_size %ci0
+  // CHECK: %[[CI:.*]] = constant 123 : index
+  // CHECK-NEXT: return %[[CI]] : index
+  %ci1 = shape.size_to_index %cs0
+  return %ci1 : index
+}
+
+// -----
+// No folding.
+// CHECK-LABEL: func @nonfoldable_size_to_index
+func @nonfoldable_size_to_index(%cs : !shape.size) -> index {
+  // CHECK: shape.size_to_index
+  %ci = shape.size_to_index %cs
+  return %ci : index
+}
+
+// -----
+// No folding.
+// CHECK-LABEL: func @nonfoldable_index_to_size
+func @nonfoldable_index_to_size(%ci : index) -> !shape.size {
+  // CHECK: shape.index_to_size
+  %cs = shape.index_to_size %ci
+  return %cs : !shape.size
+}
+
+// -----
+
 // Canonicalization of shape.get_extent
 
 // Basic folding.