OSDN Git Service

[mlir][vector] generalized masked l/s and compressed l/s with indices
authorAart Bik <ajcbik@google.com>
Fri, 8 Jan 2021 18:26:57 +0000 (10:26 -0800)
committerAart Bik <ajcbik@google.com>
Fri, 8 Jan 2021 21:59:34 +0000 (13:59 -0800)
Adding the ability to index the base address brings these operations closer
to the transfer read and write semantics (with lowering advantages), ensures
more consistent use in vector MLIR code (easier to read), and reduces the
amount of code duplication to lower memrefs into base addresses considerably
(making codegen less error-prone).

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D94278

13 files changed:
mlir/include/mlir/Dialect/Vector/VectorOps.td
mlir/integration_test/Dialect/Vector/CPU/test-compress.mlir
mlir/integration_test/Dialect/Vector/CPU/test-expand.mlir
mlir/integration_test/Dialect/Vector/CPU/test-maskedload.mlir
mlir/integration_test/Dialect/Vector/CPU/test-maskedstore.mlir
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/invalid.mlir
mlir/test/Dialect/Vector/ops.mlir
mlir/test/Dialect/Vector/vector-mem-transforms.mlir
mlir/test/Dialect/Vector/vector-transforms.mlir
mlir/test/lib/Transforms/TestVectorTransforms.cpp

index 0a98b9f..0aa4950 100644 (file)
@@ -1317,6 +1317,7 @@ def Vector_TransferWriteOp :
 def Vector_MaskedLoadOp :
   Vector_Op<"maskedload">,
     Arguments<(ins AnyMemRef:$base,
+               Variadic<Index>:$indices,
                VectorOfRankAndType<[1], [I1]>:$mask,
                VectorOfRank<[1]>:$pass_thru)>,
     Results<(outs VectorOfRank<[1]>:$result)> {
@@ -1325,12 +1326,12 @@ def Vector_MaskedLoadOp :
 
   let description = [{
     The masked load reads elements from memory into a 1-D vector as defined
-    by a base and a 1-D mask vector. When the mask is set, the element is read
-    from memory. Otherwise, the corresponding element is taken from a 1-D
-    pass-through vector. Informally the semantics are:
+    by a base with indices and a 1-D mask vector. When the mask is set, the
+    element is read from memory. Otherwise, the corresponding element is taken
+    from a 1-D pass-through vector. Informally the semantics are:
     ```
-    result[0] := mask[0] ? MEM[base+0] : pass_thru[0]
-    result[1] := mask[1] ? MEM[base+1] : pass_thru[1]
+    result[0] := mask[0] ? base[i+0] : pass_thru[0]
+    result[1] := mask[1] ? base[i+1] : pass_thru[1]
     etc.
     ```
     The masked load can be used directly where applicable, or can be used
@@ -1342,7 +1343,7 @@ def Vector_MaskedLoadOp :
     Example:
 
     ```mlir
-    %0 = vector.maskedload %base, %mask, %pass_thru
+    %0 = vector.maskedload %base[%i], %mask, %pass_thru
        : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
     ```
   }];
@@ -1360,7 +1361,7 @@ def Vector_MaskedLoadOp :
       return result().getType().cast<VectorType>();
     }
   }];
-  let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
+  let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
     "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
   let hasCanonicalizer = 1;
 }
@@ -1368,6 +1369,7 @@ def Vector_MaskedLoadOp :
 def Vector_MaskedStoreOp :
   Vector_Op<"maskedstore">,
     Arguments<(ins AnyMemRef:$base,
+               Variadic<Index>:$indices,
                VectorOfRankAndType<[1], [I1]>:$mask,
                VectorOfRank<[1]>:$value)> {
 
@@ -1375,12 +1377,12 @@ def Vector_MaskedStoreOp :
 
   let description = [{
     The masked store operation writes elements from a 1-D vector into memory
-    as defined by a base and a 1-D mask vector. When the mask is set, the
-    corresponding element from the vector is written to memory. Otherwise,
+    as defined by a base with indices and a 1-D mask vector. When the mask is
+    set, the corresponding element from the vector is written to memory. Otherwise,
     no action is taken for the element. Informally the semantics are:
     ```
-    if (mask[0]) MEM[base+0] = value[0]
-    if (mask[1]) MEM[base+1] = value[1]
+    if (mask[0]) base[i+0] = value[0]
+    if (mask[1]) base[i+1] = value[1]
     etc.
     ```
     The masked store can be used directly where applicable, or can be used
@@ -1392,7 +1394,7 @@ def Vector_MaskedStoreOp :
     Example:
 
     ```mlir
-    vector.maskedstore %base, %mask, %value
+    vector.maskedstore %base[%i], %mask, %value
       : memref<?xf32>, vector<8xi1>, vector<8xf32>
     ```
   }];
@@ -1407,8 +1409,8 @@ def Vector_MaskedStoreOp :
       return value().getType().cast<VectorType>();
     }
   }];
-  let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
-    "type($mask) `,` type($value) `into` type($base)";
+  let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
+    "type($base) `,` type($mask) `,` type($value)";
   let hasCanonicalizer = 1;
 }
 
@@ -1430,8 +1432,8 @@ def Vector_GatherOp :
     semantics are:
     ```
     if (!defined(pass_thru)) pass_thru = [undef, .., undef]
-    result[0] := mask[0] ? MEM[base + index[0]] : pass_thru[0]
-    result[1] := mask[1] ? MEM[base + index[1]] : pass_thru[1]
+    result[0] := mask[0] ? base[index[0]] : pass_thru[0]
+    result[1] := mask[1] ? base[index[1]] : pass_thru[1]
     etc.
     ```
     The vector dialect leaves out-of-bounds behavior undefined.
@@ -1487,8 +1489,8 @@ def Vector_ScatterOp :
     bit in a 1-D mask vector is set. Otherwise, no action is taken for that
     element. Informally the semantics are:
     ```
-    if (mask[0]) MEM[base + index[0]] = value[0]
-    if (mask[1]) MEM[base + index[1]] = value[1]
+    if (mask[0]) base[index[0]] = value[0]
+    if (mask[1]) base[index[1]] = value[1]
     etc.
     ```
     The vector dialect leaves out-of-bounds and repeated index behavior
@@ -1531,6 +1533,7 @@ def Vector_ScatterOp :
 def Vector_ExpandLoadOp :
   Vector_Op<"expandload">,
     Arguments<(ins AnyMemRef:$base,
+               Variadic<Index>:$indices,
                VectorOfRankAndType<[1], [I1]>:$mask,
                VectorOfRank<[1]>:$pass_thru)>,
     Results<(outs VectorOfRank<[1]>:$result)> {
@@ -1539,13 +1542,13 @@ def Vector_ExpandLoadOp :
 
   let description = [{
     The expand load reads elements from memory into a 1-D vector as defined
-    by a base and a 1-D mask vector. When the mask is set, the next element
-    is read from memory. Otherwise, the corresponding element is taken from
-    a 1-D pass-through vector. Informally the semantics are:
+    by a base with indices and a 1-D mask vector. When the mask is set, the
+    next element is read from memory. Otherwise, the corresponding element
+    is taken from a 1-D pass-through vector. Informally the semantics are:
     ```
-    index = base
-    result[0] := mask[0] ? MEM[index++] : pass_thru[0]
-    result[1] := mask[1] ? MEM[index++] : pass_thru[1]
+    index = i
+    result[0] := mask[0] ? base[index++] : pass_thru[0]
+    result[1] := mask[1] ? base[index++] : pass_thru[1]
     etc.
     ```
     Note that the index increment is done conditionally.
@@ -1559,7 +1562,7 @@ def Vector_ExpandLoadOp :
     Example:
 
     ```mlir
-    %0 = vector.expandload %base, %mask, %pass_thru
+    %0 = vector.expandload %base[%i], %mask, %pass_thru
        : memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
     ```
   }];
@@ -1577,7 +1580,7 @@ def Vector_ExpandLoadOp :
       return result().getType().cast<VectorType>();
     }
   }];
-  let assemblyFormat = "$base `,` $mask `,` $pass_thru attr-dict `:` "
+  let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
     "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
   let hasCanonicalizer = 1;
 }
@@ -1585,6 +1588,7 @@ def Vector_ExpandLoadOp :
 def Vector_CompressStoreOp :
   Vector_Op<"compressstore">,
     Arguments<(ins AnyMemRef:$base,
+               Variadic<Index>:$indices,
                VectorOfRankAndType<[1], [I1]>:$mask,
                VectorOfRank<[1]>:$value)> {
 
@@ -1592,13 +1596,13 @@ def Vector_CompressStoreOp :
 
   let description = [{
     The compress store operation writes elements from a 1-D vector into memory
-    as defined by a base and a 1-D mask vector. When the mask is set, the
-    corresponding element from the vector is written next to memory. Otherwise,
-    no action is taken for the element. Informally the semantics are:
+    as defined by a base with indices and a 1-D mask vector. When the mask is
+    set, the corresponding element from the vector is written next to memory.
+    Otherwise, no action is taken for the element. Informally the semantics are:
     ```
-    index = base
-    if (mask[0]) MEM[index++] = value[0]
-    if (mask[1]) MEM[index++] = value[1]
+    index = i
+    if (mask[0]) base[index++] = value[0]
+    if (mask[1]) base[index++] = value[1]
     etc.
     ```
     Note that the index increment is done conditionally.
@@ -1612,7 +1616,7 @@ def Vector_CompressStoreOp :
     Example:
 
     ```mlir
-    vector.compressstore %base, %mask, %value
+    vector.compressstore %base[%i], %mask, %value
       : memref<?xf32>, vector<8xi1>, vector<8xf32>
     ```
   }];
@@ -1627,7 +1631,7 @@ def Vector_CompressStoreOp :
       return value().getType().cast<VectorType>();
     }
   }];
-  let assemblyFormat = "$base `,` $mask `,` $value attr-dict `:` "
+  let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
     "type($base) `,` type($mask) `,` type($value)";
   let hasCanonicalizer = 1;
 }
index 6310d6e..7602220 100644 (file)
@@ -5,7 +5,16 @@
 
 func @compress16(%base: memref<?xf32>,
                  %mask: vector<16xi1>, %value: vector<16xf32>) {
-  vector.compressstore %base, %mask, %value
+  %c0 = constant 0: index
+  vector.compressstore %base[%c0], %mask, %value
+    : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  return
+}
+
+func @compress16_at8(%base: memref<?xf32>,
+                     %mask: vector<16xi1>, %value: vector<16xf32>) {
+  %c8 = constant 8: index
+  vector.compressstore %base[%c8], %mask, %value
     : memref<?xf32>, vector<16xi1>, vector<16xf32>
   return
 }
@@ -86,5 +95,10 @@ func @entry() {
   call @printmem16(%A) : (memref<?xf32>) -> ()
   // CHECK-NEXT: ( 0, 1, 2, 3, 11, 13, 15, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
 
+  call @compress16_at8(%A, %some1, %value)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK-NEXT: ( 0, 1, 2, 3, 11, 13, 15, 7, 0, 1, 2, 3, 12, 13, 14, 15 )
+
   return
 }
index 74118fc..b63294f 100644 (file)
@@ -5,8 +5,18 @@
 
 func @expand16(%base: memref<?xf32>,
                %mask: vector<16xi1>,
-              %pass_thru: vector<16xf32>) -> vector<16xf32> {
-  %e = vector.expandload %base, %mask, %pass_thru
+               %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = constant 0: index
+  %e = vector.expandload %base[%c0], %mask, %pass_thru
+    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %e : vector<16xf32>
+}
+
+func @expand16_at8(%base: memref<?xf32>,
+                   %mask: vector<16xi1>,
+                   %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %c8 = constant 8: index
+  %e = vector.expandload %base[%c8], %mask, %pass_thru
     : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %e : vector<16xf32>
 }
@@ -78,5 +88,10 @@ func @entry() {
   vector.print %e6 : vector<16xf32>
   // CHECK-NEXT: ( -7, 0, 7.7, 1, -7, -7, -7, 2, -7, -7, -7, 3, -7, 4, 7.7, 5 )
 
+  %e7 = call @expand16_at8(%A, %some1, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %e7 : vector<16xf32>
+  // CHECK-NEXT: ( 8, 9, 10, 11, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7, -7 )
+
   return
 }
index 6c6f6ea..d5353af 100644 (file)
@@ -5,7 +5,16 @@
 
 func @maskedload16(%base: memref<?xf32>, %mask: vector<16xi1>,
                    %pass_thru: vector<16xf32>) -> vector<16xf32> {
-  %ld = vector.maskedload %base, %mask, %pass_thru
+  %c0 = constant 0: index
+  %ld = vector.maskedload %base[%c0], %mask, %pass_thru
+    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %ld : vector<16xf32>
+}
+
+func @maskedload16_at8(%base: memref<?xf32>, %mask: vector<16xi1>,
+                       %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %c8 = constant 8: index
+  %ld = vector.maskedload %base[%c8], %mask, %pass_thru
     : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %ld : vector<16xf32>
 }
@@ -61,6 +70,11 @@ func @entry() {
   vector.print %l4 : vector<16xf32>
   // CHECK: ( -7, 1, 2, 3, 4, 5, 6, 7, -7, -7, -7, -7, -7, 13, 14, -7 )
 
+  %l5 = call @maskedload16_at8(%A, %some, %pass)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> (vector<16xf32>)
+  vector.print %l5 : vector<16xf32>
+  // CHECK: ( 8, 9, 10, 11, 12, 13, 14, 15, -7, -7, -7, -7, -7, -7, -7, -7 )
+
   return
 }
 
index d0132f6..a987898 100644 (file)
@@ -5,8 +5,17 @@
 
 func @maskedstore16(%base: memref<?xf32>,
                     %mask: vector<16xi1>, %value: vector<16xf32>) {
-  vector.maskedstore %base, %mask, %value
-    : vector<16xi1>, vector<16xf32> into memref<?xf32>
+  %c0 = constant 0: index
+  vector.maskedstore %base[%c0], %mask, %value
+    : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  return
+}
+
+func @maskedstore16_at8(%base: memref<?xf32>,
+                        %mask: vector<16xi1>, %value: vector<16xf32>) {
+  %c8 = constant 8: index
+  vector.maskedstore %base[%c8], %mask, %value
+    : memref<?xf32>, vector<16xi1>, vector<16xf32>
   return
 }
 
@@ -85,5 +94,10 @@ func @entry() {
   call @printmem16(%A) : (memref<?xf32>) -> ()
   // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 )
 
+  call @maskedstore16_at8(%A, %some, %val)
+    : (memref<?xf32>, vector<16xi1>, vector<16xf32>) -> ()
+  call @printmem16(%A) : (memref<?xf32>) -> ()
+  // CHECK: ( 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7 )
+
   return
 }
index 5ad266c..5dd0b02 100644 (file)
@@ -173,33 +173,7 @@ static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
   return success();
 }
 
-// Helper that returns a pointer given a memref base.
-static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
-                                Location loc, Value memref,
-                                MemRefType memRefType, Value &ptr) {
-  Value base;
-  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
-    return failure();
-  auto pType = MemRefDescriptor(memref).getElementPtrType();
-  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
-  return success();
-}
-
-// Helper that returns a bit-casted pointer given a memref base.
-static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
-                                Location loc, Value memref,
-                                MemRefType memRefType, Type type, Value &ptr) {
-  Value base;
-  if (failed(getBase(rewriter, loc, memref, memRefType, base)))
-    return failure();
-  auto pType = LLVM::LLVMPointerType::get(type);
-  base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
-  ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
-  return success();
-}
-
-// Helper that returns vector of pointers given a memref base and an index
-// vector.
+// Helper that returns vector of pointers given a memref base with index vector.
 static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
                                     Location loc, Value memref, Value indices,
                                     MemRefType memRefType, VectorType vType,
@@ -213,6 +187,18 @@ static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
   return success();
 }
 
+// Casts a strided element pointer to a vector pointer. The vector pointer
+// would always be on address space 0, therefore addrspacecast shall be
+// used when source/dst memrefs are not on address space 0.
+static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc,
+                         Value ptr, MemRefType memRefType, Type vt) {
+  auto pType =
+      LLVM::LLVMPointerType::get(vt.template cast<LLVM::LLVMFixedVectorType>());
+  if (memRefType.getMemorySpace() == 0)
+    return rewriter.create<LLVM::BitcastOp>(loc, pType, ptr);
+  return rewriter.create<LLVM::AddrSpaceCastOp>(loc, pType, ptr);
+}
+
 static LogicalResult
 replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter,
                                  LLVMTypeConverter &typeConverter, Location loc,
@@ -343,18 +329,18 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = load->getLoc();
     auto adaptor = vector::MaskedLoadOpAdaptor(operands);
+    MemRefType memRefType = load.getMemRefType();
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), load.getMemRefType(),
-                                  align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
+    // Resolve address.
     auto vtype = typeConverter->convertType(load.getResultVectorType());
-    Value ptr;
-    if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(),
-                          vtype, ptr)))
-      return failure();
+    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
+                                               adaptor.indices(), rewriter);
+    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
 
     rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
         load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(),
@@ -374,18 +360,18 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = store->getLoc();
     auto adaptor = vector::MaskedStoreOpAdaptor(operands);
+    MemRefType memRefType = store.getMemRefType();
 
     // Resolve alignment.
     unsigned align;
-    if (failed(getMemRefAlignment(*getTypeConverter(), store.getMemRefType(),
-                                  align)))
+    if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
       return failure();
 
+    // Resolve address.
     auto vtype = typeConverter->convertType(store.getValueVectorType());
-    Value ptr;
-    if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(),
-                          vtype, ptr)))
-      return failure();
+    Value dataPtr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
+                                               adaptor.indices(), rewriter);
+    Value ptr = castDataPtr(rewriter, loc, dataPtr, memRefType, vtype);
 
     rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
         store, adaptor.value(), ptr, adaptor.mask(),
@@ -473,16 +459,15 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = expand->getLoc();
     auto adaptor = vector::ExpandLoadOpAdaptor(operands);
+    MemRefType memRefType = expand.getMemRefType();
 
-    Value ptr;
-    if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(),
-                          ptr)))
-      return failure();
+    // Resolve address.
+    auto vtype = typeConverter->convertType(expand.getResultVectorType());
+    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
+                                           adaptor.indices(), rewriter);
 
-    auto vType = expand.getResultVectorType();
     rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
-        expand, typeConverter->convertType(vType), ptr, adaptor.mask(),
-        adaptor.pass_thru());
+        expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
     return success();
   }
 };
@@ -498,11 +483,11 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = compress->getLoc();
     auto adaptor = vector::CompressStoreOpAdaptor(operands);
+    MemRefType memRefType = compress.getMemRefType();
 
-    Value ptr;
-    if (failed(getBasePtr(rewriter, loc, adaptor.base(),
-                          compress.getMemRefType(), ptr)))
-      return failure();
+    // Resolve address.
+    Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
+                                           adaptor.indices(), rewriter);
 
     rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
         compress, adaptor.value(), ptr, adaptor.mask());
@@ -1223,21 +1208,11 @@ public:
     }
 
     // 1. Get the source/dst address as an LLVM vector pointer.
-    //    The vector pointer would always be on address space 0, therefore
-    //    addrspacecast shall be used when source/dst memrefs are not on
-    //    address space 0.
-    // TODO: support alignment when possible.
+    VectorType vtp = xferOp.getVectorType();
     Value dataPtr = this->getStridedElementPtr(
         loc, memRefType, adaptor.source(), adaptor.indices(), rewriter);
-    auto vecTy = toLLVMTy(xferOp.getVectorType())
-                     .template cast<LLVM::LLVMFixedVectorType>();
-    Value vectorDataPtr;
-    if (memRefType.getMemorySpace() == 0)
-      vectorDataPtr = rewriter.create<LLVM::BitcastOp>(
-          loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
-    else
-      vectorDataPtr = rewriter.create<LLVM::AddrSpaceCastOp>(
-          loc, LLVM::LLVMPointerType::get(vecTy), dataPtr);
+    Value vectorDataPtr =
+        castDataPtr(rewriter, loc, dataPtr, memRefType, toLLVMTy(vtp));
 
     if (!xferOp.isMaskedDim(0))
       return replaceTransferOpWithLoadOrStore(rewriter,
@@ -1251,7 +1226,7 @@ public:
     //
     // TODO: when the leaf transfer rank is k > 1, we need the last `k`
     //       dimensions here.
-    unsigned vecWidth = vecTy.getNumElements();
+    unsigned vecWidth = vtp.getNumElements();
     unsigned lastIndex = llvm::size(xferOp.indices()) - 1;
     Value off = xferOp.indices()[lastIndex];
     Value dim = rewriter.create<DimOp>(loc, xferOp.source(), lastIndex);
index 5d9c2e6..318ca1e 100644 (file)
@@ -76,20 +76,6 @@ static MaskFormat get1DMaskFormat(Value mask) {
   return MaskFormat::Unknown;
 }
 
-/// Helper method to cast a 1-D memref<10xf32> "base" into a
-/// memref<vector<10xf32>> in the output parameter "newBase",
-/// using the 'element' vector type "vt". Returns true on success.
-static bool castedToMemRef(Location loc, Value base, MemRefType mt,
-                           VectorType vt, PatternRewriter &rewriter,
-                           Value &newBase) {
-  // The vector.type_cast operation does not accept unknown memref<?xf32>.
-  // TODO: generalize the cast and accept this case too
-  if (!mt.hasStaticShape())
-    return false;
-  newBase = rewriter.create<TypeCastOp>(loc, MemRefType::get({}, vt), base);
-  return true;
-}
-
 //===----------------------------------------------------------------------===//
 // VectorDialect
 //===----------------------------------------------------------------------===//
@@ -2380,13 +2366,10 @@ public:
   using OpRewritePattern<MaskedLoadOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(MaskedLoadOp load,
                                 PatternRewriter &rewriter) const override {
-    Value newBase;
     switch (get1DMaskFormat(load.mask())) {
     case MaskFormat::AllTrue:
-      if (!castedToMemRef(load.getLoc(), load.base(), load.getMemRefType(),
-                          load.getResultVectorType(), rewriter, newBase))
-        return failure();
-      rewriter.replaceOpWithNewOp<LoadOp>(load, newBase);
+      rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+          load, load.getType(), load.base(), load.indices(), false);
       return success();
     case MaskFormat::AllFalse:
       rewriter.replaceOp(load, load.pass_thru());
@@ -2426,13 +2409,10 @@ public:
   using OpRewritePattern<MaskedStoreOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(MaskedStoreOp store,
                                 PatternRewriter &rewriter) const override {
-    Value newBase;
     switch (get1DMaskFormat(store.mask())) {
     case MaskFormat::AllTrue:
-      if (!castedToMemRef(store.getLoc(), store.base(), store.getMemRefType(),
-                          store.getValueVectorType(), rewriter, newBase))
-        return failure();
-      rewriter.replaceOpWithNewOp<StoreOp>(store, store.value(), newBase);
+      rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+          store, store.value(), store.base(), store.indices(), false);
       return success();
     case MaskFormat::AllFalse:
       rewriter.eraseOp(store);
@@ -2568,14 +2548,10 @@ public:
   using OpRewritePattern<ExpandLoadOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                 PatternRewriter &rewriter) const override {
-    Value newBase;
     switch (get1DMaskFormat(expand.mask())) {
     case MaskFormat::AllTrue:
-      if (!castedToMemRef(expand.getLoc(), expand.base(),
-                          expand.getMemRefType(), expand.getResultVectorType(),
-                          rewriter, newBase))
-        return failure();
-      rewriter.replaceOpWithNewOp<LoadOp>(expand, newBase);
+      rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+          expand, expand.getType(), expand.base(), expand.indices(), false);
       return success();
     case MaskFormat::AllFalse:
       rewriter.replaceOp(expand, expand.pass_thru());
@@ -2615,14 +2591,11 @@ public:
   using OpRewritePattern<CompressStoreOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(CompressStoreOp compress,
                                 PatternRewriter &rewriter) const override {
-    Value newBase;
     switch (get1DMaskFormat(compress.mask())) {
     case MaskFormat::AllTrue:
-      if (!castedToMemRef(compress.getLoc(), compress.base(),
-                          compress.getMemRefType(),
-                          compress.getValueVectorType(), rewriter, newBase))
-        return failure();
-      rewriter.replaceOpWithNewOp<StoreOp>(compress, compress.value(), newBase);
+      rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+          compress, compress.value(), compress.base(), compress.indices(),
+          false);
       return success();
     case MaskFormat::AllFalse:
       rewriter.eraseOp(compress);
index c1e4e08..5c0c965 100644 (file)
@@ -1070,23 +1070,29 @@ func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> {
 // CHECK:       llvm.return %[[T]] : !llvm.vec<16 x f32>
 
 func @masked_load_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
-  %0 = vector.maskedload %arg0, %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %c0 = constant 0: index
+  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %0 : vector<16xf32>
 }
 
 // CHECK-LABEL: func @masked_load_op
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<vec<16 x f32>>) -> !llvm.ptr<vec<16 x f32>>
-// CHECK: %[[L:.*]] = llvm.intr.masked.load %[[P]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr<vec<16 x f32>>, !llvm.vec<16 x i1>, !llvm.vec<16 x f32>) -> !llvm.vec<16 x f32>
+// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[B:.*]] = llvm.bitcast %[[P]] : !llvm.ptr<f32> to !llvm.ptr<vec<16 x f32>>
+// CHECK: %[[L:.*]] = llvm.intr.masked.load %[[B]], %{{.*}}, %{{.*}} {alignment = 4 : i32} : (!llvm.ptr<vec<16 x f32>>, !llvm.vec<16 x i1>, !llvm.vec<16 x f32>) -> !llvm.vec<16 x f32>
 // CHECK: llvm.return %[[L]] : !llvm.vec<16 x f32>
 
 func @masked_store_op(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
-  vector.maskedstore %arg0, %arg1, %arg2 : vector<16xi1>, vector<16xf32> into memref<?xf32>
+  %c0 = constant 0: index
+  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xi1>, vector<16xf32>
   return
 }
 
 // CHECK-LABEL: func @masked_store_op
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<vec<16 x f32>>) -> !llvm.ptr<vec<16 x f32>>
-// CHECK: llvm.intr.masked.store %{{.*}}, %[[P]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<16 x f32>, !llvm.vec<16 x i1> into !llvm.ptr<vec<16 x f32>>
+// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK: %[[B:.*]] = llvm.bitcast %[[P]] : !llvm.ptr<f32> to !llvm.ptr<vec<16 x f32>>
+// CHECK: llvm.intr.masked.store %{{.*}}, %[[B]], %{{.*}} {alignment = 4 : i32} : !llvm.vec<16 x f32>, !llvm.vec<16 x i1> into !llvm.ptr<vec<16 x f32>>
 // CHECK: llvm.return
 
 func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
@@ -1110,21 +1116,25 @@ func @scatter_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>
 // CHECK: llvm.return
 
 func @expand_load_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) -> vector<11xf32> {
-  %0 = vector.expandload %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32> into vector<11xf32>
+  %c0 = constant 0: index
+  %0 = vector.expandload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32> into vector<11xf32>
   return %0 : vector<11xf32>
 }
 
 // CHECK-LABEL: func @expand_load_op
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<f32>) -> !llvm.ptr<f32>
+// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
 // CHECK: %[[E:.*]] = "llvm.intr.masked.expandload"(%[[P]], %{{.*}}, %{{.*}}) : (!llvm.ptr<f32>, !llvm.vec<11 x i1>, !llvm.vec<11 x f32>) -> !llvm.vec<11 x f32>
 // CHECK: llvm.return %[[E]] : !llvm.vec<11 x f32>
 
 func @compress_store_op(%arg0: memref<?xf32>, %arg1: vector<11xi1>, %arg2: vector<11xf32>) {
-  vector.compressstore %arg0, %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32>
+  %c0 = constant 0: index
+  vector.compressstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<11xi1>, vector<11xf32>
   return
 }
 
 // CHECK-LABEL: func @compress_store_op
-// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[] : (!llvm.ptr<f32>) -> !llvm.ptr<f32>
+// CHECK: %[[C:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
 // CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %[[P]], %{{.*}}) : (!llvm.vec<11 x f32>, !llvm.ptr<f32>, !llvm.vec<11 x i1>) -> ()
 // CHECK: llvm.return
index 62eaa4e..8cadafa 100644 (file)
@@ -1199,36 +1199,41 @@ func @type_cast_layout(%arg0: memref<4x3xf32, affine_map<(d0, d1)[s0, s1, s2] ->
 // -----
 
 func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
+  %c0 = constant 0 : index
   // expected-error@+1 {{'vector.maskedload' op base and result element type should match}}
-  %0 = vector.maskedload %base, %mask, %pass : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 }
 
 // -----
 
 func @maskedload_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %pass: vector<16xf32>) {
+  %c0 = constant 0 : index
   // expected-error@+1 {{'vector.maskedload' op expected result dim to match mask dim}}
-  %0 = vector.maskedload %base, %mask, %pass : memref<?xf32>, vector<15xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf32>, vector<15xi1>, vector<16xf32> into vector<16xf32>
 }
 
 // -----
 
 func @maskedload_pass_thru_type_mask_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass: vector<16xi32>) {
+  %c0 = constant 0 : index
   // expected-error@+1 {{'vector.maskedload' op expected pass_thru of same type as result type}}
-  %0 = vector.maskedload %base, %mask, %pass : memref<?xf32>, vector<16xi1>, vector<16xi32> into vector<16xf32>
+  %0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf32>, vector<16xi1>, vector<16xi32> into vector<16xf32>
 }
 
 // -----
 
 func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+  %c0 = constant 0 : index
   // expected-error@+1 {{'vector.maskedstore' op base and value element type should match}}
-  vector.maskedstore %base, %mask, %value : vector<16xi1>, vector<16xf32> into memref<?xf64>
+  vector.maskedstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
 }
 
 // -----
 
 func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>, %value: vector<16xf32>) {
+  %c0 = constant 0 : index
   // expected-error@+1 {{'vector.maskedstore' op expected value dim to match mask dim}}
-  vector.maskedstore %base, %mask, %value : vector<15xi1>, vector<16xf32> into memref<?xf32>
+  vector.maskedstore %base[%c0], %mask, %value : memref<?xf32>, vector<15xi1>, vector<16xf32>
 }
 
 // -----
@@ -1297,36 +1302,41 @@ func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
 // -----
 
 func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = constant 0 : index
   // expected-error@+1 {{'vector.expandload' op base and result element type should match}}
-  %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 }
 
 // -----
 
 func @expand_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = constant 0 : index
   // expected-error@+1 {{'vector.expandload' op expected result dim to match mask dim}}
-  %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<17xi1>, vector<16xf32> into vector<16xf32>
 }
 
 // -----
 
 func @expand_pass_thru_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass_thru: vector<17xf32>) {
+  %c0 = constant 0 : index
   // expected-error@+1 {{'vector.expandload' op expected pass_thru of same type as result type}}
-  %0 = vector.expandload %base, %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<17xf32> into vector<16xf32>
+  %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf32>, vector<16xi1>, vector<17xf32> into vector<16xf32>
 }
 
 // -----
 
 func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+  %c0 = constant 0 : index
   // expected-error@+1 {{'vector.compressstore' op base and value element type should match}}
-  vector.compressstore %base, %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
+  vector.compressstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
 }
 
 // -----
 
 func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
+  %c0 = constant 0 : index
   // expected-error@+1 {{'vector.compressstore' op expected value dim to match mask dim}}
-  vector.compressstore %base, %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
+  vector.compressstore %base[%c0], %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
 }
 
 // -----
index 07e9d8d..60890e5 100644 (file)
@@ -452,10 +452,11 @@ func @flat_transpose_int(%arg0: vector<16xi32>) -> vector<16xi32> {
 
 // CHECK-LABEL: @masked_load_and_store
 func @masked_load_and_store(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
-  // CHECK: %[[X:.*]] = vector.maskedload %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-  %0 = vector.maskedload %base, %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-  // CHECK: vector.maskedstore %{{.*}}, %{{.*}}, %[[X]] : vector<16xi1>, vector<16xf32> into memref<?xf32>
-  vector.maskedstore %base, %mask, %0 : vector<16xi1>, vector<16xf32> into memref<?xf32>
+  %c0 = constant 0 : index
+  // CHECK: %[[X:.*]] = vector.maskedload %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.maskedload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  // CHECK: vector.maskedstore %{{.*}}[%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  vector.maskedstore %base[%c0], %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
   return
 }
 
@@ -472,10 +473,11 @@ func @gather_and_scatter(%base: memref<?xf32>, %indices: vector<16xi32>, %mask:
 
 // CHECK-LABEL: @expand_and_compress
 func @expand_and_compress(%base: memref<?xf32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) {
-  // CHECK: %[[X:.*]] = vector.expandload %{{.*}}, %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-  %0 = vector.expandload %base, %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-  // CHECK: vector.compressstore %{{.*}}, %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
-  vector.compressstore %base, %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  %c0 = constant 0 : index
+  // CHECK: %[[X:.*]] = vector.expandload %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  %0 = vector.expandload %base[%c0], %mask, %passthru : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  // CHECK: vector.compressstore %{{.*}}[{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
+  vector.compressstore %base[%c0], %mask, %0 : memref<?xf32>, vector<16xi1>, vector<16xf32>
   return
 }
 
index 7d79d8b..f9d7903 100644 (file)
@@ -1,82 +1,93 @@
 // RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
 
-//
-// TODO: optimize this one too!
-//
-// CHECK-LABEL: func @maskedload0(
-// CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
-// CHECK-NEXT: %[[M:.*]] = vector.constant_mask
-// CHECK-NEXT: %[[T:.*]] = vector.maskedload %[[A0]], %[[M]], %[[A1]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
-// CHECK-NEXT: return %[[T]] : vector<16xf32>
-
+// CHECK-LABEL:   func @maskedload0(
+// CHECK-SAME:                      %[[A0:.*]]: memref<?xf32>,
+// CHECK-SAME:                      %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-DAG:       %[[C:.*]] = constant 0 : index
+// CHECK-DAG:       %[[D:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT:      %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<?xf32>, vector<16xf32>
+// CHECK-NEXT:      return %[[T]] : vector<16xf32>
 func @maskedload0(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = constant 0 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
-  %ld = vector.maskedload %base, %mask, %pass_thru
+  %ld = vector.maskedload %base[%c0], %mask, %pass_thru
     : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %ld : vector<16xf32>
 }
 
-// CHECK-LABEL: func @maskedload1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
-// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
-// CHECK-NEXT: %[[T1:.*]] = load %[[T0]][] : memref<vector<16xf32>>
-// CHECK-NEXT: return %[[T1]] : vector<16xf32>
-
+// CHECK-LABEL:   func @maskedload1(
+// CHECK-SAME:                      %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME:                      %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-DAG:       %[[C:.*]] = constant 0 : index
+// CHECK-DAG:       %[[D:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT:      %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<16xf32>, vector<16xf32>
+// CHECK-NEXT:      return %[[T]] : vector<16xf32>
 func @maskedload1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = constant 0 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
-  %ld = vector.maskedload %base, %mask, %pass_thru
+  %ld = vector.maskedload %base[%c0], %mask, %pass_thru
     : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %ld : vector<16xf32>
 }
 
-// CHECK-LABEL: func @maskedload2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
-// CHECK-NEXT: return %[[A1]] : vector<16xf32>
-
+// CHECK-LABEL:   func @maskedload2(
+// CHECK-SAME:                      %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME:                      %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-NEXT:      return %[[A1]] : vector<16xf32>
 func @maskedload2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = constant 0 : index
   %mask = vector.constant_mask [0] : vector<16xi1>
-  %ld = vector.maskedload %base, %mask, %pass_thru
+  %ld = vector.maskedload %base[%c0], %mask, %pass_thru
     : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %ld : vector<16xf32>
 }
 
-// CHECK-LABEL: func @maskedstore1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
-// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
-// CHECK-NEXT: store %[[A1]], %[[T0]][] : memref<vector<16xf32>>
-// CHECK-NEXT: return
+// CHECK-LABEL:   func @maskedload3(
+// CHECK-SAME:                      %[[A0:.*]]: memref<?xf32>,
+// CHECK-SAME:                      %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-DAG:       %[[C:.*]] = constant 8 : index
+// CHECK-DAG:       %[[D:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT:      %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<?xf32>, vector<16xf32>
+// CHECK-NEXT:      return %[[T]] : vector<16xf32>
+func @maskedload3(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %c8 = constant 8 : index
+  %mask = vector.constant_mask [16] : vector<16xi1>
+  %ld = vector.maskedload %base[%c8], %mask, %pass_thru
+    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %ld : vector<16xf32>
+}
 
+// CHECK-LABEL:   func @maskedstore1(
+// CHECK-SAME:                       %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME:                       %[[A1:.*]]: vector<16xf32>) {
+// CHECK-NEXT:      %[[C:.*]] = constant 0 : index
+// CHECK-NEXT:      vector.transfer_write %[[A1]], %[[A0]][%[[C]]] {masked = [false]} : vector<16xf32>, memref<16xf32>
+// CHECK-NEXT:      return
 func @maskedstore1(%base: memref<16xf32>, %value: vector<16xf32>) {
+  %c0 = constant 0 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
-  vector.maskedstore %base, %mask, %value
-    : vector<16xi1>, vector<16xf32> into memref<16xf32>
+  vector.maskedstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
   return
 }
 
-// CHECK-LABEL: func @maskedstore2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
-// CHECK-NEXT: return
-
+// CHECK-LABEL:   func @maskedstore2(
+// CHECK-SAME:                       %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME:                       %[[A1:.*]]: vector<16xf32>) {
+// CHECK-NEXT:      return
 func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>)  {
+  %c0 = constant 0 : index
   %mask = vector.constant_mask [0] : vector<16xi1>
-  vector.maskedstore %base, %mask, %value
-    : vector<16xi1>, vector<16xf32> into memref<16xf32>
+  vector.maskedstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
   return
 }
 
-// CHECK-LABEL: func @gather1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
-// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
-// CHECK-NEXT: %[[T0:.*]] = vector.constant_mask [16] : vector<16xi1>
-// CHECK-NEXT: %[[T1:.*]] = vector.gather %[[A0]], %[[A1]], %[[T0]], %[[A2]] : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
-// CHECK-NEXT: return %1 : vector<16xf32>
-
+// CHECK-LABEL:   func @gather1(
+// CHECK-SAME:                  %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME:                  %[[A1:.*]]: vector<16xi32>,
+// CHECK-SAME:                  %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-NEXT:      %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT:      %[[G:.*]] = vector.gather %[[A0]], %[[A1]], %[[M]], %[[A2]] : (memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>) -> vector<16xf32>
+// CHECK-NEXT:      return %[[G]] : vector<16xf32>
 func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %mask = vector.constant_mask [16] : vector<16xi1>
   %ld = vector.gather %base, %indices, %mask, %pass_thru
@@ -84,12 +95,11 @@ func @gather1(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vecto
   return %ld : vector<16xf32>
 }
 
-// CHECK-LABEL: func @gather2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
-// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
-// CHECK-NEXT: return %[[A2]] : vector<16xf32>
-
+// CHECK-LABEL:   func @gather2(
+// CHECK-SAME:                  %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME:                  %[[A1:.*]]: vector<16xi32>,
+// CHECK-SAME:                  %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-NEXT:      return %[[A2]] : vector<16xf32>
 func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
   %mask = vector.constant_mask [0] : vector<16xi1>
   %ld = vector.gather %base, %indices, %mask, %pass_thru
@@ -97,14 +107,13 @@ func @gather2(%base: memref<16xf32>, %indices: vector<16xi32>, %pass_thru: vecto
   return %ld : vector<16xf32>
 }
 
-// CHECK-LABEL: func @scatter1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
-// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
-// CHECK-NEXT: %[[T0:.*]] = vector.constant_mask [16] : vector<16xi1>
-// CHECK-NEXT: vector.scatter %[[A0]], %[[A1]], %[[T0]], %[[A2]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
-// CHECK-NEXT: return
-
+// CHECK-LABEL:   func @scatter1(
+// CHECK-SAME:                   %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME:                   %[[A1:.*]]: vector<16xi32>,
+// CHECK-SAME:                   %[[A2:.*]]: vector<16xf32>) {
+// CHECK-NEXT:      %[[M:.*]] = vector.constant_mask [16] : vector<16xi1>
+// CHECK-NEXT:      vector.scatter %[[A0]], %[[A1]], %[[M]], %[[A2]] : vector<16xi32>, vector<16xi1>, vector<16xf32> into memref<16xf32>
+// CHECK-NEXT:      return
 func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
   %mask = vector.constant_mask [16] : vector<16xi1>
   vector.scatter %base, %indices, %mask, %value
@@ -112,12 +121,11 @@ func @scatter1(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<1
   return
 }
 
-// CHECK-LABEL: func @scatter2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
-// CHECK-SAME: %[[A2:.*]]: vector<16xf32>)
-// CHECK-NEXT: return
-
+// CHECK-LABEL:   func @scatter2(
+// CHECK-SAME:                   %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME:                   %[[A1:.*]]: vector<16xi32>,
+// CHECK-SAME:                   %[[A2:.*]]: vector<16xf32>) {
+// CHECK-NEXT:      return
 func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<16xf32>) {
   %0 = vector.type_cast %base : memref<16xf32> to memref<vector<16xf32>>
   %mask = vector.constant_mask [0] : vector<16xi1>
@@ -126,52 +134,53 @@ func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vector<1
   return
 }
 
-// CHECK-LABEL: func @expand1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
-// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
-// CHECK-NEXT: %[[T1:.*]] = load %[[T0]][] : memref<vector<16xf32>>
-// CHECK-NEXT: return %[[T1]] : vector<16xf32>
-
+// CHECK-LABEL:   func @expand1(
+// CHECK-SAME:                  %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME:                  %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-DAG:       %[[C:.*]] = constant 0 : index
+// CHECK-DAG:       %[[D:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT:      %[[T:.*]] = vector.transfer_read %[[A0]][%[[C]]], %[[D]] {masked = [false]} : memref<16xf32>, vector<16xf32>
+// CHECK-NEXT:      return %[[T]] : vector<16xf32>
 func @expand1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = constant 0 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
-  %ld = vector.expandload %base, %mask, %pass_thru
+  %ld = vector.expandload %base[%c0], %mask, %pass_thru
     : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %ld : vector<16xf32>
 }
 
-// CHECK-LABEL: func @expand2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
-// CHECK-NEXT: return %[[A1]] : vector<16xf32>
-
+// CHECK-LABEL:   func @expand2(
+// CHECK-SAME:                  %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME:                  %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
+// CHECK-NEXT:      return %[[A1]] : vector<16xf32>
 func @expand2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = constant 0 : index
   %mask = vector.constant_mask [0] : vector<16xi1>
-  %ld = vector.expandload %base, %mask, %pass_thru
+  %ld = vector.expandload %base[%c0], %mask, %pass_thru
     : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
   return %ld : vector<16xf32>
 }
 
-// CHECK-LABEL: func @compress1(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
-// CHECK-NEXT: %[[T0:.*]] = vector.type_cast %[[A0]] : memref<16xf32> to memref<vector<16xf32>>
-// CHECK-NEXT: store %[[A1]], %[[T0]][] : memref<vector<16xf32>>
-// CHECK-NEXT: return
-
+// CHECK-LABEL:   func @compress1(
+// CHECK-SAME:                    %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME:                    %[[A1:.*]]: vector<16xf32>) {
+// CHECK-NEXT:      %[[C:.*]] = constant 0 : index
+// CHECK-NEXT:      vector.transfer_write %[[A1]], %[[A0]][%[[C]]] {masked = [false]} : vector<16xf32>, memref<16xf32>
+// CHECK-NEXT:      return
 func @compress1(%base: memref<16xf32>, %value: vector<16xf32>) {
+  %c0 = constant 0 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
-  vector.compressstore %base, %mask, %value  : memref<16xf32>, vector<16xi1>, vector<16xf32>
+  vector.compressstore %base[%c0], %mask, %value  : memref<16xf32>, vector<16xi1>, vector<16xf32>
   return
 }
 
-// CHECK-LABEL: func @compress2(
-// CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
-// CHECK-SAME: %[[A1:.*]]: vector<16xf32>)
-// CHECK-NEXT: return
-
+// CHECK-LABEL:   func @compress2(
+// CHECK-SAME:                    %[[A0:.*]]: memref<16xf32>,
+// CHECK-SAME:                    %[[A1:.*]]: vector<16xf32>) {
+// CHECK-NEXT:      return
 func @compress2(%base: memref<16xf32>, %value: vector<16xf32>) {
+  %c0 = constant 0 : index
   %mask = vector.constant_mask [0] : vector<16xi1>
-  vector.compressstore %base, %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
+  vector.compressstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
   return
 }
index 4a58261..754c7cc 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
+// RUN: mlir-opt %s -test-vector-to-vector-conversion="unroll" | FileCheck %s
 
 // CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
 
index 572cd1c..11e4ee2 100644 (file)
@@ -24,13 +24,25 @@ namespace {
 
 struct TestVectorToVectorConversion
     : public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
+  TestVectorToVectorConversion() = default;
+  TestVectorToVectorConversion(const TestVectorToVectorConversion &pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect>();
+  }
+
+  Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
+                      llvm::cl::init(false)};
+
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     auto *ctx = &getContext();
-    patterns.insert<UnrollVectorPattern>(
-        ctx,
-        UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
-            filter));
+    if (unroll) {
+      patterns.insert<UnrollVectorPattern>(
+          ctx,
+          UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
+              filter));
+    }
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));