// Vectorize other ops as vector contraction (currently only matmul).
LLVM_DEBUG(dbgs() << dbgPref
<< "Rewrite linalg op as vector.contract: " << *op);
+ auto extractVectorTypeFromScalarView = [](Value v) {
+ MemRefType mt = v.getType().cast<MemRefType>();
+ return VectorType::get(mt.getShape(), mt.getElementType());
+ };
auto linalgOp = cast<linalg::LinalgOp>(op);
- Value a = std_load(vector_type_cast(linalgOp.getInput(0)));
- Value b = std_load(vector_type_cast(linalgOp.getInput(1)));
- Value memref = vector_type_cast(linalgOp.getOutputBuffer(0));
- Value c = std_load(memref);
+ Value viewA = linalgOp.getInput(0);
+ Value viewB = linalgOp.getInput(1);
+ Value viewC = linalgOp.getOutputBuffer(0);
+ Value zero = std_constant_index(0);
+ SmallVector<Value, 4> indicesA(linalgOp.getInputShapedType(0).getRank(),
+ zero);
+ SmallVector<Value, 4> indicesB(linalgOp.getInputShapedType(1).getRank(),
+ zero);
+ SmallVector<Value, 4> indicesC(linalgOp.getOutputShapedType(0).getRank(),
+ zero);
+ Value a = vector_transfer_read(extractVectorTypeFromScalarView(viewA), viewA,
+ indicesA);
+ Value b = vector_transfer_read(extractVectorTypeFromScalarView(viewB), viewB,
+ indicesB);
+ Value c = vector_transfer_read(extractVectorTypeFromScalarView(viewC), viewC,
+ indicesC);
Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
linalgOp.iterator_types());
- std_store(res, memref);
+ vector_transfer_write(res, viewC, indicesC);
}
/// Check whether there is any interleaved use of any `values` between `firstOp`
return
}
// CHECK-LABEL: func @vectorization_test
-// CHECK: vector.type_cast %{{.*}} : memref<8x16xf32> to memref<vector<8x16xf32>>
-// CHECK: load %{{.*}}[] : memref<vector<8x16xf32>>
-// CHECK: vector.type_cast %{{.*}} : memref<16x32xf32> to memref<vector<16x32xf32>>
-// CHECK: load %{{.*}}[] : memref<vector<16x32xf32>>
-// CHECK: vector.type_cast %{{.*}} : memref<8x32xf32> to memref<vector<8x32xf32>>
-// CHECK: load %{{.*}}[] : memref<vector<8x32xf32>>
+// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
+// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32>
+// CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
// CHECK: vector.contract {indexing_maps = [#[[mk]], #[[kn]], #[[mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
-// CHECK: store %{{.*}}, %{{.*}}[] : memref<vector<8x32xf32>>
+// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<8x32xf32>) {