Summary: Vector contract patterns were only parameterized by a `vectorTransformsOptions`. As a result, even if an mlir file was containing several occurrences of `vector.contract`, all of them would be lowered in the same way. More granularity might be required . This Diff adds a `constraint` argument to each of these patterns which allows the user to specify with more precision on which `vector.contract` should each of the lowering apply.
Differential Revision: https://reviews.llvm.org/D83960
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+ using FilterConstraintType =
+ std::function<LogicalResult(vector::ContractionOp op)>;
+
+ static LogicalResult defaultFilter(vector::ContractionOp op) {
+ return success();
+ }
ContractionOpToMatmulOpLowering(
vector::VectorTransformsOptions vectorTransformsOptions,
- MLIRContext *context)
+ MLIRContext *context, FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
- vectorTransformsOptions(vectorTransformsOptions) {}
+ vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
LogicalResult match(vector::ContractionOp op) const override;
void rewrite(vector::ContractionOp op,
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformsOptions;
+ FilterConstraintType filter;
};
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
: public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+ using FilterConstraintType =
+ std::function<LogicalResult(vector::ContractionOp op)>;
+
+ static LogicalResult defaultFilter(vector::ContractionOp op) {
+ return success();
+ }
+
ContractionOpToOuterProductOpLowering(
vector::VectorTransformsOptions vectorTransformsOptions,
- MLIRContext *context)
+ MLIRContext *context, FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
- vectorTransformsOptions(vectorTransformsOptions) {}
+ vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
LogicalResult match(vector::ContractionOp op) const override;
void rewrite(vector::ContractionOp op,
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformsOptions;
+ FilterConstraintType filter;
};
/// Progressive lowering of ContractionOp.
class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
public:
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+ using FilterConstraintType =
+ std::function<LogicalResult(vector::ContractionOp op)>;
+
+ static LogicalResult defaultFilter(vector::ContractionOp op) {
+ return success();
+ }
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
- MLIRContext *context)
+ MLIRContext *context,
+ FilterConstraintType constraint = defaultFilter)
: OpRewritePattern<vector::ContractionOp>(context),
- vectorTransformsOptions(vectorTransformsOptions) {}
+ vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override;
private:
/// Options to control the vector patterns.
vector::VectorTransformsOptions vectorTransformsOptions;
+ FilterConstraintType filter;
// Lower one parallel dimension.
Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
int64_t rhsIndex, PatternRewriter &rewriter) const;
vector::VectorContractLowering::Matmul)
return failure();
+ if (failed(filter(op)))
+ return failure();
+
auto iteratorTypes = op.iterator_types().getValue();
if (!isParallelIterator(iteratorTypes[0]) ||
!isParallelIterator(iteratorTypes[1]) ||
vector::VectorContractLowering::OuterProduct)
return failure();
+ if (failed(filter(op)))
+ return failure();
+
// Determine if the parallel/reduction structure matches something
// that can be expressed a reduction_size unrolled sequence.
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
// TODO: implement masks.
if (llvm::size(op.masks()) != 0)
return failure();
+
+ if (failed(filter(op)))
+ return failure();
+
// TODO: support mixed mode contract lowering.
if (op.getLhsType().getElementType() !=
getElementTypeOrSelf(op.getAccType()) ||
// RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s
// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
+// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
#dotp_accesses = [
affine_map<(i) -> (i)>,
: vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
return %0 : vector<3x2xf32>
}
+
+// FILTEROUTERPRODUCT-LABEL: func @matmul_4_filtered
+// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<4x4xf32>,
+// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>,
+// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<4x4xf32>
+// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]]
+func @matmul_4_filtered(%arg0: vector<4x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<4x4xf32>)
+-> vector<4x4xf32>
+{
+ %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+ : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
+ return %0 : vector<4x4xf32>
+}
+
+// FILTEROUTERPRODUCT-LABEL: func @matmul_4_not_filtered
+// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4xf32>,
+// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>,
+// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x4xf32>
+// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]]
+func @matmul_4_not_filtered(%arg0: vector<3x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<3x4xf32>)
+-> vector<3x4xf32>
+{
+ %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+ : vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32>
+ return %0 : vector<3x4xf32>
+}
+
+
+
+
*this, "vector-outerproduct",
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
llvm::cl::init(false)};
+ Option<bool> lowerToFilterOuterProduct{
+ *this, "vector-filter-outerproduct",
+ llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
+ "vectors of size 4."),
+ llvm::cl::init(false)};
void runOnFunction() override {
OwningRewritePatternList patterns;
return;
}
+ // Test on one pattern in isolation.
+ if (lowerToFilterOuterProduct) {
+ VectorContractLowering lowering = VectorContractLowering::OuterProduct;
+ VectorTransformsOptions options{lowering};
+ patterns.insert<ContractionOpToOuterProductOpLowering>(
+ options, &getContext(), [](vector::ContractionOp op) {
+ // Only lowers vector.contract where the lhs as a type vector<MxNx?>
+ // where M is not 4.
+ if (op.getRhsType().getShape()[0] == 4)
+ return failure();
+ return success();
+ });
+ applyPatternsAndFoldGreedily(getFunction(), patterns);
+ return;
+ }
+
// Test on all contract lowering patterns.
VectorContractLowering contractLowering = VectorContractLowering::Dot;
if (lowerToFlatMatrix)