From 89e19e8eddd6dd0dc38d595b6784fb9ce65d9972 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Tue, 21 Jan 2020 19:37:18 -0500 Subject: [PATCH] [mlir][Linalg] Add tensor support to Linalg EDSC Builders Summary: This diff extends the Linalg EDSC builders so we can easily create mixed tensor/buffer linalg.generic ops. This is expected to be useful for HLO -> Linalg lowering. The `StructuredIndexed` struct is made to derive from `ValueHandle` and can now capture a type + indexing expressions. This is used to represent return tensors. Pointwise unary and binary builders are extended to allow both output buffers and return tensors. This has implications on the number of region arguments. Reviewers: ftynse, herhut, hanchung, asaadaldien, stellaraccident Reviewed By: asaadaldien Subscribers: merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72863 --- mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h | 134 ++++++++++++++++++++--- mlir/lib/Dialect/Linalg/EDSC/Builders.cpp | 132 +++++++++++++++------- mlir/test/EDSC/builder-api-test.cpp | 44 +++++++- 3 files changed, 246 insertions(+), 64 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h index fa813503103..4ee30c0e4a2 100644 --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h @@ -110,11 +110,14 @@ struct StructuredIndexed { operator Value() const /* implicit */ { return value; } ArrayRef getExprs() { return exprs; } + Type getType() { return value.getType(); } private: StructuredIndexed(Value v, ArrayRef indexings) : value(v), exprs(indexings.begin(), indexings.end()) { - assert(v.getType().isa() && "MemRefType expected"); + assert((v.getType().isa() || + v.getType().isa()) && + "MemRef or RankedTensor expected"); } StructuredIndexed(ValueHandle v, ArrayRef indexings) : StructuredIndexed(v.getValue(), indexings) {} @@ -125,9 +128,21 @@ private: inline void defaultRegionBuilder(ArrayRef args) {} +/// Build a `linalg.generic` op with the specified inputs, outputs and region. +/// +/// `otherValues` and `otherAttributes` may be passed and will be appended as +/// operands and attributes respectively. +/// +/// This accepts both buffers and tensors as `inputs` but only buffers as +/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in +/// which case, the canonical identity indexing_map is assumed. +// +// TODO(ntv) In the future we may want to relax this identity assumption (e.g. +// for automatic differentiation purposes). In that case we will want to make +// StructuredIndexed work with ValueHandle to encode type or value. Operation *makeGenericLinalgOp( ArrayRef iteratorTypes, ArrayRef inputs, - ArrayRef outputs, + ArrayRef outputs, ArrayRef resultTensorTypes = {}, function_ref)> regionBuilder = defaultRegionBuilder, ArrayRef otherValues = {}, ArrayRef otherAttributes = {}); @@ -167,32 +182,77 @@ void macRegionBuilder(ArrayRef args); /// with in-place semantics and parallelism. /// Unary pointwise operation (with broadcast) entry point. +/// +/// This accepts both buffers and tensors as `inputs` but only buffers as +/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in +/// which case, the canonical identity indexing_map is assumed. +// +// TODO(ntv) In the future we may want to relax this identity assumption (e.g. 
+// for automatic differentiation purposes). In that case we will want to make +// StructuredIndexed work with ValueHandle to encode type or value. using UnaryPointwiseOpBuilder = function_ref; Operation *linalg_pointwise(UnaryPointwiseOpBuilder unaryOp, - StructuredIndexed I, StructuredIndexed O); + StructuredIndexed I, StructuredIndexed O, + ArrayRef resultTensorTypes = {}); /// Build a linalg.pointwise with all `parallel` iterators and a region that /// computes `O = tanh(I)`. The client is responsible for specifying the proper /// indexings when creating the StructuredIndexed. -Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O); +/// +/// This accepts both buffers and tensors as `inputs` but only buffers as +/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in +/// which case, the canonical identity indexing_map is assumed. +// +// TODO(ntv) In the future we may want to relax this identity assumption (e.g. +// for automatic differentiation purposes). In that case we will want to make +// StructuredIndexed work with ValueHandle to encode type or value. +Operation *linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O, + ArrayRef resultTensorTypes = {}); /// Binary pointwise operation (with broadcast) entry point. +/// +/// This accepts both buffers and tensors as `inputs` but only buffers as +/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in +/// which case, the canonical identity indexing_map is assumed. +// +// TODO(ntv) In the future we may want to relax this identity assumption (e.g. +// for automatic differentiation purposes). In that case we will want to make +// StructuredIndexed work with ValueHandle to encode type or value. using BinaryPointwiseOpBuilder = function_ref; Operation *linalg_pointwise(BinaryPointwiseOpBuilder binaryOp, StructuredIndexed I1, StructuredIndexed I2, - StructuredIndexed O); + StructuredIndexed O, + ArrayRef resultTensorTypes = {}); /// Build a linalg.pointwise with all `parallel` iterators and a region that /// computes `O = I1 + I2`. The client is responsible for specifying the proper /// indexings when creating the StructuredIndexed. +/// +/// This accepts both buffers and tensors as `inputs` but only buffers as +/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in +/// which case, the canonical identity indexing_map is assumed. +// +// TODO(ntv) In the future we may want to relax this identity assumption (e.g. +// for automatic differentiation purposes). In that case we will want to make +// StructuredIndexed work with ValueHandle to encode type or value. Operation *linalg_pointwise_add(StructuredIndexed I1, StructuredIndexed I2, - StructuredIndexed O); + StructuredIndexed O, + ArrayRef resultTensorTypes = {}); /// Build a linalg.pointwise with all `parallel` iterators and a region that /// computes `O = max(I!, I2)`. The client is responsible for specifying the /// proper indexings when creating the StructuredIndexed. +/// +/// This accepts both buffers and tensors as `inputs` but only buffers as +/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in +/// which case, the canonical identity indexing_map is assumed. +// +// TODO(ntv) In the future we may want to relax this identity assumption (e.g. +// for automatic differentiation purposes). In that case we will want to make +// StructuredIndexed work with ValueHandle to encode type or value. 
Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2, - StructuredIndexed O); + StructuredIndexed O, + ArrayRef resultTensorTypes = {}); // TODO(ntv): Implement more useful pointwise operations on a per-need basis. @@ -203,11 +263,23 @@ Operation *linalg_pointwise_max(StructuredIndexed I1, StructuredIndexed I2, /// | /// | C(m, n) += A(m, k) * B(k, n) /// ``` -Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC); +/// +/// This accepts both buffers and tensors as `inputs` but only buffers as +/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in +/// which case, the canonical identity indexing_map is assumed. +// +// TODO(ntv) In the future we may want to relax this identity assumption (e.g. +// for automatic differentiation purposes). In that case we will want to make +// StructuredIndexed work with ValueHandle to encode type or value. +Operation *linalg_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC, + ArrayRef resultTensorTypes = {}); -template Operation *linalg_matmul(Container values) { +template +Operation *linalg_matmul(Container values, + ArrayRef resultTensorTypes = {}) { assert(values.size() == 3 && "Expected exactly 3 values"); - return linalg_matmul(values[0], values[1], values[2]); + assert(resultTensorTypes.size() <= 1 && "Expected at most 1 result tensor"); + return linalg_matmul(values[0], values[1], values[2], resultTensorTypes); } /// Build a linalg.generic, under the current ScopedContext, at the current @@ -231,16 +303,28 @@ template Operation *linalg_matmul(Container values) { /// /// For now `...` must be empty (i.e. only 2-D convolutions are supported). /// +/// This accepts both buffers and tensors as `inputs` but only buffers as +/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in +/// which case, the canonical identity indexing_map is assumed. +// +// TODO(ntv) In the future we may want to relax this identity assumption (e.g. +// for automatic differentiation purposes). In that case we will want to make +// StructuredIndexed work with ValueHandle to encode type or value. +// // TODO(ntv) Extend convolution rank with some template magic. 
Operation *linalg_conv_nhwc(ValueHandle vI, ValueHandle vW, ValueHandle vO, + ArrayRef resultTensorTypes = {}, ArrayRef strides = {}, ArrayRef dilations = {}); template -Operation *linalg_conv_nhwc(Container values, ArrayRef strides = {}, - ArrayRef dilations = {}) { +Operation * +linalg_conv_nhwc(Container values, ArrayRef resultTensorTypes = {}, + ArrayRef strides = {}, ArrayRef dilations = {}) { assert(values.size() == 3 && "Expected exactly 3 values"); - return linalg_conv_nhwc(values[0], values[1], values[2], strides, dilations); + assert(resultTensorTypes.size() <= 1 && "Expected at most 1 result tensor"); + return linalg_conv_nhwc(values[0], values[1], values[2], resultTensorTypes, + strides, dilations); } /// Build a linalg.generic, under the current ScopedContext, at the current @@ -249,7 +333,7 @@ Operation *linalg_conv_nhwc(Container values, ArrayRef strides = {}, /// (batch, dm, c, [h, w, ...], [kh, kw, ...]) = /// | (par, par, par, [par, par, ...], [red, red, ...]) /// | -/// | O(batch, [h, w, ...], c * depth_multiplier) += +/// | O(batch, [h, w, ...], c * depthMultiplier) += /// | I(batch, /// | [ /// | stride[0] * h + dilations[0] * kh, @@ -257,26 +341,40 @@ Operation *linalg_conv_nhwc(Container values, ArrayRef strides = {}, /// ], /// | c) /// | * -/// | W([kh, kw, ...], c, depth_multiplier) +/// | W([kh, kw, ...], c, depthMultiplier) /// ``` /// If `dilations` or `strides` are left empty, the default value of `1` is used /// along each relevant dimension. /// /// For now `...` must be empty (i.e. only 2-D convolutions are supported). /// +/// This accepts both buffers and tensors as `inputs` but only buffers as +/// `outputs`. Output tensors can be specified with `resultTensorTypes`, in +/// which case, the canonical identity indexing_map is assumed. +// +// TODO(ntv) In the future we may want to relax this identity assumption (e.g. +// for automatic differentiation purposes). In that case we will want to make +// StructuredIndexed work with ValueHandle to encode type or value. +// // TODO(ntv) Extend convolution rank with some template magic. 
Operation *linalg_dilated_conv_nhwc(ValueHandle vI, ValueHandle vW, - ValueHandle vO, int depth_multiplier = 1, + ValueHandle vO, + ArrayRef resultTensorTypes = {}, + int depthMultiplier = 1, ArrayRef strides = {}, ArrayRef dilations = {}); template -Operation *linalg_dilated_conv_nhwc(Container values, int depth_multiplier, +Operation *linalg_dilated_conv_nhwc(Container values, + ArrayRef resultTensorTypes = {}, + int depthMultiplier = 1, ArrayRef strides = {}, ArrayRef dilations = {}) { assert(values.size() == 3 && "Expected exactly 3 values"); + assert(resultTensorTypes.size() <= 1 && "Expected at most 1 result tensor"); return linalg_dilated_conv_nhwc(values[0], values[1], values[2], - depth_multiplier, strides, dilations); + resultTensorTypes, depthMultiplier, strides, + dilations); } } // namespace ops diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp index 0940f564b2e..395a409c00e 100644 --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -128,16 +128,20 @@ static void getMaxDimIndex(ArrayRef structuredIndices, Operation *mlir::edsc::makeGenericLinalgOp( ArrayRef iteratorTypes, ArrayRef inputs, - ArrayRef outputs, + ArrayRef outputBuffers, ArrayRef resultTensorTypes, function_ref)> regionBuilder, ArrayRef otherValues, ArrayRef otherAttributes) { + assert( + llvm::all_of(llvm::make_range(outputBuffers.begin(), outputBuffers.end()), + [](Value v) { return v.getType().isa(); }) && + "output operands must all be buffers."); auto &builder = edsc::ScopedContext::getBuilder(); auto *ctx = builder.getContext(); unsigned nInputs = inputs.size(); - unsigned nOutputs = outputs.size(); + unsigned nOutputs = outputBuffers.size() + resultTensorTypes.size(); unsigned maxPos = 0; getMaxDimIndex(inputs, maxPos); - getMaxDimIndex(outputs, maxPos); + getMaxDimIndex(outputBuffers, maxPos); // maxPos is 0 indexed, need to turn this into a count (i.e. 
+1) unsigned nDims = maxPos + 1; @@ -146,7 +150,7 @@ Operation *mlir::edsc::makeGenericLinalgOp( for (auto in : inputs) maps.push_back( AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, in.getExprs())); - for (auto out : outputs) + for (auto out : outputBuffers) maps.push_back( AffineMap::get(/*dimCount=*/nDims, /*symbolCount=*/0, out.getExprs())); @@ -154,7 +158,7 @@ Operation *mlir::edsc::makeGenericLinalgOp( SmallVector values; values.reserve(nViews); values.append(inputs.begin(), inputs.end()); - values.append(outputs.begin(), outputs.end()); + values.append(outputBuffers.begin(), outputBuffers.end()); auto iteratorStrTypes = functional::map(toString, iteratorTypes); // clang-format off @@ -162,7 +166,7 @@ Operation *mlir::edsc::makeGenericLinalgOp( edsc::ScopedContext::getBuilder() .create( edsc::ScopedContext::getLocation(), - ArrayRef{}, // TODO(ntv): support tensors + resultTensorTypes, values, IntegerAttr::get(IntegerType::get(64, ctx), nInputs), IntegerAttr::get(IntegerType::get(64, ctx), nOutputs), @@ -207,7 +211,8 @@ void mlir::edsc::ops::macRegionBuilder(ArrayRef args) { Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp, StructuredIndexed I, - StructuredIndexed O) { + StructuredIndexed O, + ArrayRef resultTensorTypes) { SmallVector iterTypes(O.getExprs().size(), edsc::IterType::Parallel); auto fun = [&unaryOp](ArrayRef args) { @@ -215,22 +220,30 @@ Operation *mlir::edsc::ops::linalg_pointwise(UnaryPointwiseOpBuilder unaryOp, ValueHandle a(args[0]); linalg_yield(unaryOp(a)); }; - return makeGenericLinalgOp(iterTypes, {I}, {O}, fun); + + // Distinguish between tensor and buffer semantics. + if (O.getType().isa()) { + assert(resultTensorTypes.empty()); + return makeGenericLinalgOp(iterTypes, {I}, {O}, {}, fun); + } + return makeGenericLinalgOp(iterTypes, {I, O}, {}, resultTensorTypes, fun); } -Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I, - StructuredIndexed O) { +Operation * +mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O, + ArrayRef resultTensorTypes) { ; using edsc::intrinsics::tanh; UnaryPointwiseOpBuilder unOp([](ValueHandle a) -> Value { return tanh(a); }); - return linalg_pointwise(unOp, I, O); + return linalg_pointwise(unOp, I, O, resultTensorTypes); } /// Binary pointwise operation (with broadcast) entry point. Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp, StructuredIndexed I1, StructuredIndexed I2, - StructuredIndexed O) { + StructuredIndexed O, + ArrayRef resultTensorTypes) { SmallVector iterTypes(O.getExprs().size(), edsc::IterType::Parallel); auto fun = [&binaryOp](ArrayRef args) { @@ -238,45 +251,62 @@ Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp, ValueHandle a(args[0]), b(args[1]); linalg_yield(binaryOp(a, b)); }; - return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun); + // Distinguish between tensor and buffer semantics. 
+ if (O.getType().isa()) { + assert(resultTensorTypes.empty()); + return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, {}, fun); + } + return makeGenericLinalgOp(iterTypes, {I1, I2, O}, {}, resultTensorTypes, + fun); } -Operation *mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1, - StructuredIndexed I2, - StructuredIndexed O) { +Operation * +mlir::edsc::ops::linalg_pointwise_add(StructuredIndexed I1, + StructuredIndexed I2, StructuredIndexed O, + ArrayRef resultTensorTypes) { using edsc::op::operator+; BinaryPointwiseOpBuilder binOp( [](ValueHandle a, ValueHandle b) -> Value { return a + b; }); - return linalg_pointwise(binOp, I1, I2, O); + return linalg_pointwise(binOp, I1, I2, O, resultTensorTypes); } -Operation *mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1, - StructuredIndexed I2, - StructuredIndexed O) { +Operation * +mlir::edsc::ops::linalg_pointwise_max(StructuredIndexed I1, + StructuredIndexed I2, StructuredIndexed O, + ArrayRef resultTensorTypes) { BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value { using edsc::intrinsics::select; using edsc::op::operator>; return select(a > b, a, b).getValue(); }); - return linalg_pointwise(binOp, I1, I2, O); + return linalg_pointwise(binOp, I1, I2, O, resultTensorTypes); } Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB, - ValueHandle vC) { - // clang-format off + ValueHandle vC, + ArrayRef resultTensorTypes) { AffineExpr m, n, k; bindDims(ScopedContext::getContext(), m, n, k); StructuredIndexed A(vA), B(vB), C(vC); + + assert(!C.getType().isa() || resultTensorTypes.empty()); + StructuredIndexed allIndexed[3]{A({m, k}), B({k, n}), C({m, n})}; + ArrayRef inputs = + (C.getType().isa()) + ? ArrayRef{allIndexed, allIndexed + 2} + : ArrayRef{allIndexed, allIndexed + 3}; + ArrayRef outputs = + (C.getType().isa()) + ? ArrayRef{allIndexed + 2, allIndexed + 3} + : ArrayRef{}; return makeGenericLinalgOp( - {IterType::Parallel, IterType::Parallel, IterType::Reduction}, - {A({m, k}), B({k, n})}, - {C({m, n})}, - macRegionBuilder); - // clang-format on + {IterType::Parallel, IterType::Parallel, IterType::Reduction}, inputs, + outputs, resultTensorTypes, macRegionBuilder); } Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW, ValueHandle vO, + ArrayRef resultTensorTypes, ArrayRef strides, ArrayRef dilations) { MLIRContext *ctx = ScopedContext::getContext(); @@ -294,23 +324,33 @@ Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW, bindDims(ctx, b, f, h, w, kh, kw, c); unsigned numDims = c.cast().getPosition() + 1; StructuredIndexed I(vI), W(vW), O(vO); + + assert(!O.getType().isa() || resultTensorTypes.empty()); + // Roundtrip to flattened form to serve as canonicalization and ensure + // consistent ordering of subexpressions. // clang-format off - return makeGenericLinalgOp( - {par, par, par, par, red, red, red}, { + StructuredIndexed allIndexed[3] = { I({b, - // Roundtrip to flattened form to serve as canonicalization and ensure - // consistent ordering of subexpressions. simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0), simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0), c}), - W({kh, kw, c, f})}, { - O({b, h, w, f})}, - macRegionBuilder); + W({kh, kw, c, f}), + O({b, h, w, f})}; // clang-format on + auto inputs = (O.getType().isa()) + ? ArrayRef{allIndexed, allIndexed + 2} + : ArrayRef{allIndexed, allIndexed + 3}; + ArrayRef outputs = + (O.getType().isa()) + ? 
ArrayRef{allIndexed + 2, allIndexed + 3} + : ArrayRef{}; + return makeGenericLinalgOp({par, par, par, par, red, red, red}, inputs, + outputs, resultTensorTypes, macRegionBuilder); } Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc( - ValueHandle vI, ValueHandle vW, ValueHandle vO, int depth_multiplier, + ValueHandle vI, ValueHandle vW, ValueHandle vO, + ArrayRef resultTensorTypes, int depthMultiplier, ArrayRef strides, ArrayRef dilations) { MLIRContext *ctx = ScopedContext::getContext(); // TODO(ntv) some template magic to make everything rank-polymorphic. @@ -328,16 +368,26 @@ Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc( bindDims(ctx, b, dm, c, h, w, kh, kw); unsigned numDims = kw.cast().getPosition() + 1; StructuredIndexed I(vI), W(vW), O(vO); - return makeGenericLinalgOp( - {par, par, par, par, par, red, red}, { + // Roundtrip to flattened form to serve as canonicalization and ensure + // consistent ordering of subexpressions. + // clang-format off + StructuredIndexed allIndexed[3] = { I({b, // Roundtrip to flattened form to serve as canonicalization and ensure // consistent ordering of subexpressions. simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0), simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0), c}), - W({kh, kw, c, dm})}, { - O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})}, - macRegionBuilder); + W({kh, kw, c, dm}), + O({b, h, w, simplifyAffineExpr(c * depthMultiplier + dm, numDims, 0)})}; // clang-format on + auto inputs = (O.getType().isa()) + ? ArrayRef{allIndexed, allIndexed + 2} + : ArrayRef{allIndexed, allIndexed + 3}; + ArrayRef outputs = + (O.getType().isa()) + ? ArrayRef{allIndexed + 2, allIndexed + 3} + : ArrayRef{}; + return makeGenericLinalgOp({par, par, par, par, par, red, red}, inputs, + outputs, resultTensorTypes, macRegionBuilder); } diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index 5388446b253..f3ca26176b1 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -467,7 +467,8 @@ TEST_FUNC(zero_and_sign_extendi_op_i1_to_i8) { auto i1Type = IntegerType::get(1, &globalContext()); auto i8Type = IntegerType::get(8, &globalContext()); auto memrefType = MemRefType::get({}, i1Type, {}, 0); - auto f = makeFunction("zero_and_sign_extendi_op", {}, {memrefType, memrefType}); + auto f = + makeFunction("zero_and_sign_extendi_op", {}, {memrefType, memrefType}); OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); @@ -795,10 +796,12 @@ TEST_FUNC(empty_map_load_store) { } // CHECK-LABEL: func @affine_if_op -// CHECK: affine.if affine_set<([[d0:.*]], [[d1:.*]]){{\[}}[[s0:.*]], [[s1:.*]]{{\]}} +// CHECK: affine.if affine_set<([[d0:.*]], [[d1:.*]]){{\[}} +// CHECK-SAME: [[s0:.*]], [[s1:.*]]{{\]}} // CHECK-NOT: else -// CHECK: affine.if affine_set<([[d0:.*]], [[d1:.*]]){{\[}}[[s0:.*]], [[s1:.*]]{{\]}} -// CHECK-NEXT: } else { +// CHECK: affine.if affine_set<([[d0:.*]], [[d1:.*]]){{\[}} +// CHECK-SAME: [[s0:.*]], [[s1:.*]]{{\]}} +// CHECK-NEXT: } else { TEST_FUNC(affine_if_op) { using namespace edsc; using namespace edsc::intrinsics; @@ -900,6 +903,36 @@ TEST_FUNC(linalg_matmul_test) { } // clang-format off +// CHECK-LABEL: func @linalg_matmul_mixed_tensors +// CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64, +// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} 
+/// CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32): +// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32 +// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32 +// CHECK: linalg.yield %[[a4]] : f32 +// CHECK: }: tensor, memref, tensor -> tensor +// clang-format on +TEST_FUNC(linalg_matmul_mixed_tensors_test) { + using namespace edsc; + using namespace edsc::ops; + + auto f32Type = FloatType::getF32(&globalContext()); + auto memrefType = MemRefType::get({-1, -1}, f32Type, {}, 0); + auto tensorType = RankedTensorType::get({-1, -1}, f32Type); + auto f = makeFunction("linalg_matmul_mixed_tensors", {}, + {tensorType, memrefType, tensorType}); + + OpBuilder builder(f.getBody()); + ScopedContext scope(builder, f.getLoc()); + linalg_matmul(makeValueHandles(llvm::to_vector<3>(f.getArguments())), + tensorType); + + f.print(llvm::outs()); + f.erase(); +} + +// clang-format off // CHECK-LABEL: func @linalg_conv_nhwc // CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64, // CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2 * 3 + d4 * 5, d3 * 4 + d5 * 6, d6)>, @@ -923,7 +956,7 @@ TEST_FUNC(linalg_conv_nhwc) { OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - linalg_conv_nhwc(makeValueHandles(llvm::to_vector<3>(f.getArguments())), + linalg_conv_nhwc(makeValueHandles(llvm::to_vector<3>(f.getArguments())), {}, /*strides=*/{3, 4}, /*dilations=*/{5, 6}); f.print(llvm::outs()); @@ -956,6 +989,7 @@ TEST_FUNC(linalg_dilated_conv_nhwc) { ScopedContext scope(builder, f.getLoc()); linalg_dilated_conv_nhwc( makeValueHandles(llvm::to_vector<3>(f.getArguments())), + /*outputTensorTypes=*/{}, /*depth_multiplier=*/7, /*strides=*/{3, 4}, /*dilations=*/{5, 6}); -- 2.11.0
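
Usage note (illustrative, not part of the patch): the new `resultTensorTypes`
parameter can also be exercised with all-tensor operands, which is the shape an
HLO -> Linalg lowering would typically produce. The sketch below is modeled on
`linalg_matmul_mixed_tensors_test` above; it assumes the helpers from
mlir/test/EDSC/builder-api-test.cpp (`TEST_FUNC`, `makeFunction`,
`makeValueHandles`, `globalContext`) and the EDSC API as of this revision.

// Illustrative sketch only: an all-tensor variant of the mixed-tensor matmul
// test added by this patch. When C is a RankedTensorType, linalg_matmul routes
// all three operands as linalg.generic inputs and takes the single result type
// from `resultTensorTypes`.
TEST_FUNC(linalg_matmul_tensors_test) {
  using namespace edsc;
  using namespace edsc::ops;

  auto f32Type = FloatType::getF32(&globalContext());
  auto tensorType = RankedTensorType::get({-1, -1}, f32Type);
  // Three tensor-typed function arguments: A, B and the accumulator C.
  auto f = makeFunction("linalg_matmul_tensors", {},
                        {tensorType, tensorType, tensorType});

  OpBuilder builder(f.getBody());
  ScopedContext scope(builder, f.getLoc());
  // The trailing argument requests one tensor result; per the Builders.h
  // documentation, the canonical identity indexing_map is assumed for it.
  linalg_matmul(makeValueHandles(llvm::to_vector<3>(f.getArguments())),
                tensorType);

  f.print(llvm::outs());
  f.erase();
}

By analogy with the CHECK lines of linalg_matmul_mixed_tensors above, this
should print a linalg.generic with args_in = 3 and args_out = 1 operating on
tensor operands and returning a single tensor result.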