From: Alex Zinenko Date: Fri, 26 Jun 2020 12:34:00 +0000 (+0200) Subject: [mlir] support returning unranked memrefs X-Git-Url: http://git.osdn.net/view?a=commitdiff_plain;h=6323065fd6026de926b15bb609f4601e366a300c;p=android-x86%2Fexternal-llvm-project.git [mlir] support returning unranked memrefs Initially, unranked memref descriptors in the LLVM dialect were designed only to be passed into functions. An assertion was guarding against returning unranked memrefs from functions in the standard-to-LLVM conversion. This is insufficient for functions that wish to return an unranked memref such that the caller does not know the rank in advance, and hence cannot allocate the descriptor and pass it in as an argument. Introduce a calling convention for returning unranked memref descriptors as follows. An unranked memref descriptor always points to a ranked memref descriptor stored on stack of the current function. When an unranked memref descriptor is returned from a function, the ranked memref descriptor it points to is copied to dynamically allocated memory, the ownership of which is transferred to the caller. The caller is responsible for deallocating the dynamically allocated memory and for copying the pointed-to ranked memref descriptor onto its stack. Provide default lowerings for std.return, std.call and std.indirect_call that maintain the conversion defined above. This convention is additionally exercised by a runtime test to guard against memory errors. Differential Revision: https://reviews.llvm.org/D82647 --- diff --git a/mlir/docs/ConversionToLLVMDialect.md b/mlir/docs/ConversionToLLVMDialect.md index 15d09cb4b87..e65df4444b8 100644 --- a/mlir/docs/ConversionToLLVMDialect.md +++ b/mlir/docs/ConversionToLLVMDialect.md @@ -246,7 +246,7 @@ func @bar() { } ``` -### Calling Convention for `memref` +### Calling Convention for Ranked `memref` Function _arguments_ of `memref` type, ranked or unranked, are _expanded_ into a list of arguments of non-aggregate types that the memref descriptor defined @@ -317,7 +317,9 @@ llvm.func @bar() { ``` -For **unranked** memrefs, the list of function arguments always contains two +### Calling Convention for Unranked `memref` + +For unranked memrefs, the list of function arguments always contains two elements, same as the unranked memref descriptor: an integer rank, and a type-erased (`!llvm<"i8*">`) pointer to the ranked memref descriptor. Note that while the _calling convention_ does not require stack allocation, _casting_ to @@ -369,6 +371,20 @@ llvm.func @bar() { } ``` +**Lifetime.** The second element of the unranked memref descriptor points to +some memory in which the ranked memref descriptor is stored. By convention, this +memory is allocated on stack and has the lifetime of the function. (*Note:* due +to function-length lifetime, creation of multiple unranked memref descriptors, +e.g., in a loop, may lead to stack overflows.) If an unranked descriptor has to +be returned from a function, the ranked descriptor it points to is copied into +dynamically allocated memory, and the pointer in the unranked descriptor is +updated accodingly. The allocation happens immediately before returning. It is +the responsibility of the caller to free the dynamically allocated memory. The +default conversion of `std.call` and `std.call_indirect` copies the ranked +descriptor to newly allocated memory on the caller's stack. Thus, the convention +of the ranked memref descriptor pointed to by an unranked memref descriptor +being stored on stack is respected. + *This convention may or may not apply if the conversion of MemRef types is overridden by the user.* diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h index a7e4ff2f52c..c96341094af 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -129,6 +129,9 @@ public: /// Gets the bitwidth of the index type when converted to LLVM. unsigned getIndexTypeBitwidth() { return customizations.indexBitwidth; } + /// Gets the pointer bitwidth. + unsigned getPointerBitwidth(unsigned addressSpace = 0); + protected: /// LLVM IR module used to parse/create types. llvm::Module *module; @@ -386,6 +389,13 @@ public: /// Returns the number of non-aggregate values that would be produced by /// `unpack`. static unsigned getNumUnpackedValues() { return 2; } + + /// Builds IR computing the sizes in bytes (suitable for opaque allocation) + /// and appends the corresponding values into `sizes`. + static void computeSizes(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + ArrayRef values, + SmallVectorImpl &sizes); }; /// Base class for operation conversions targeting the LLVM IR dialect. Provides diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 30e34440c2d..3d8e52cecf9 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -794,6 +794,13 @@ def LLVM_PowOp : LLVM_BinarySameArgsIntrinsicOp<"pow">; def LLVM_BitReverseOp : LLVM_UnaryIntrinsicOp<"bitreverse">; def LLVM_CtPopOp : LLVM_UnaryIntrinsicOp<"ctpop">; +def LLVM_MemcpyOp : LLVM_ZeroResultIntrOp<"memcpy", [0, 1, 2]>, + Arguments<(ins LLVM_Type:$dst, LLVM_Type:$src, + LLVM_Type:$len, LLVM_Type:$isVolatile)>; +def LLVM_MemcpyInlineOp : LLVM_ZeroResultIntrOp<"memcpy.inline", [0, 1, 2]>, + Arguments<(ins LLVM_Type:$dst, LLVM_Type:$src, + LLVM_Type:$len, LLVM_Type:$isVolatile)>; + // // Vector Reductions. // diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 19c451fa3fe..9376d53dc99 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Support/MathExtras.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" @@ -184,6 +185,10 @@ LLVM::LLVMType LLVMTypeConverter::getIndexType() { return LLVM::LLVMType::getIntNTy(llvmDialect, getIndexTypeBitwidth()); } +unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) { + return module->getDataLayout().getPointerSizeInBits(addressSpace); +} + Type LLVMTypeConverter::convertIndexType(IndexType type) { return getIndexType(); } @@ -769,6 +774,51 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc, results.push_back(d.memRefDescPtr(builder, loc)); } +void UnrankedMemRefDescriptor::computeSizes( + OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, + ArrayRef values, SmallVectorImpl &sizes) { + if (values.empty()) + return; + + // Cache the index type. + LLVM::LLVMType indexType = typeConverter.getIndexType(); + + // Initialize shared constants. + Value one = createIndexAttrConstant(builder, loc, indexType, 1); + Value two = createIndexAttrConstant(builder, loc, indexType, 2); + Value pointerSize = createIndexAttrConstant( + builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8)); + Value indexSize = + createIndexAttrConstant(builder, loc, indexType, + ceilDiv(typeConverter.getIndexTypeBitwidth(), 8)); + + sizes.reserve(sizes.size() + values.size()); + for (UnrankedMemRefDescriptor desc : values) { + // Emit IR computing the memory necessary to store the descriptor. This + // assumes the descriptor to be + // { type*, type*, index, index[rank], index[rank] } + // and densely packed, so the total size is + // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index). + // TODO: consider including the actual size (including eventual padding due + // to data layout) into the unranked descriptor. + Value doublePointerSize = + builder.create(loc, indexType, two, pointerSize); + + // (1 + 2 * rank) * sizeof(index) + Value rank = desc.rank(builder, loc); + Value doubleRank = builder.create(loc, indexType, two, rank); + Value doubleRankIncremented = + builder.create(loc, indexType, doubleRank, one); + Value rankIndexSize = builder.create( + loc, indexType, doubleRankIncremented, indexSize); + + // Total allocation size. + Value allocationSize = builder.create( + loc, indexType, doublePointerSize, rankIndexSize); + sizes.push_back(allocationSize); + } +} + LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { return *typeConverter.getDialect(); } @@ -1863,6 +1913,104 @@ struct AllocOpLowering : public AllocLikeOpLowering { using AllocaOpLowering = AllocLikeOpLowering; +/// Copies the shaped descriptor part to (if `toDynamic` is set) or from +/// (otherwise) the dynamically allocated memory for any operands that were +/// unranked descriptors originally. +static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + TypeRange origTypes, + SmallVectorImpl &operands, + bool toDynamic) { + assert(origTypes.size() == operands.size() && + "expected as may original types as operands"); + + // Find operands of unranked memref type and store them. + SmallVector unrankedMemrefs; + for (unsigned i = 0, e = operands.size(); i < e; ++i) { + if (!origTypes[i].isa()) + continue; + unrankedMemrefs.emplace_back(operands[i]); + } + + if (unrankedMemrefs.empty()) + return success(); + + // Compute allocation sizes. + SmallVector sizes; + UnrankedMemRefDescriptor::computeSizes(builder, loc, typeConverter, + unrankedMemrefs, sizes); + + // Get frequently used types. + auto voidType = LLVM::LLVMType::getVoidTy(typeConverter.getDialect()); + auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(typeConverter.getDialect()); + auto i1Type = LLVM::LLVMType::getInt1Ty(typeConverter.getDialect()); + LLVM::LLVMType indexType = typeConverter.getIndexType(); + + // Find the malloc and free, or declare them if necessary. + auto module = builder.getInsertionPoint()->getParentOfType(); + auto mallocFunc = module.lookupSymbol("malloc"); + if (!mallocFunc && toDynamic) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + mallocFunc = builder.create( + builder.getUnknownLoc(), "malloc", + LLVM::LLVMType::getFunctionTy( + voidPtrType, llvm::makeArrayRef(indexType), /*isVarArg=*/false)); + } + auto freeFunc = module.lookupSymbol("free"); + if (!freeFunc && !toDynamic) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + freeFunc = builder.create( + builder.getUnknownLoc(), "free", + LLVM::LLVMType::getFunctionTy(voidType, llvm::makeArrayRef(voidPtrType), + /*isVarArg=*/false)); + } + + // Initialize shared constants. + Value zero = + builder.create(loc, i1Type, builder.getBoolAttr(false)); + + unsigned unrankedMemrefPos = 0; + for (unsigned i = 0, e = operands.size(); i < e; ++i) { + Type type = origTypes[i]; + if (!type.isa()) + continue; + Value allocationSize = sizes[unrankedMemrefPos++]; + UnrankedMemRefDescriptor desc(operands[i]); + + // Allocate memory, copy, and free the source if necessary. + Value memory = + toDynamic + ? builder.create(loc, mallocFunc, allocationSize) + .getResult(0) + : builder.create(loc, voidPtrType, allocationSize, + /*alignment=*/0); + + Value source = desc.memRefDescPtr(builder, loc); + builder.create(loc, memory, source, allocationSize, zero); + if (!toDynamic) + builder.create(loc, freeFunc, source); + + // Create a new descriptor. The same descriptor can be returned multiple + // times, attempting to modify its pointer can lead to memory leaks + // (allocated twice and overwritten) or double frees (the caller does not + // know if the descriptor points to the same memory). + Type descriptorType = typeConverter.convertType(type); + if (!descriptorType) + return failure(); + auto updatedDesc = + UnrankedMemRefDescriptor::undef(builder, loc, descriptorType); + Value rank = desc.rank(builder, loc); + updatedDesc.setRank(builder, loc, rank); + updatedDesc.setMemRefDescPtr(builder, loc, memory); + + operands[i] = updatedDesc; + } + + return success(); +} + // A CallOp automatically promotes MemRefType to a sequence of alloca/store and // passes the pointer to the MemRef across function boundaries. template @@ -1882,13 +2030,6 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { unsigned numResults = callOp.getNumResults(); auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); - for (Type resType : resultTypes) { - assert(!resType.isa() && - "Returning unranked memref is not supported. Pass result as an" - "argument instead."); - (void)resType; - } - if (numResults != 0) { if (!(packedResult = this->typeConverter.packFunctionResults(resultTypes))) @@ -1900,25 +2041,25 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { auto newOp = rewriter.create(op->getLoc(), packedResult, promoted, op->getAttrs()); - // If < 2 results, packing did not do anything and we can just return. - if (numResults < 2) { - rewriter.replaceOp(op, newOp.getResults()); - return success(); - } - - // Otherwise, it had been converted to an operation producing a structure. - // Extract individual results from the structure and return them as list. - // TODO(aminim, ntv, riverriddle, zinenko): this seems like patching around - // a particular interaction between MemRefType and CallOp lowering. Find a - // way to avoid special casing. SmallVector results; - results.reserve(numResults); - for (unsigned i = 0; i < numResults; ++i) { - auto type = this->typeConverter.convertType(op->getResult(i).getType()); - results.push_back(rewriter.create( - op->getLoc(), type, newOp.getOperation()->getResult(0), - rewriter.getI64ArrayAttr(i))); + if (numResults < 2) { + // If < 2 results, packing did not do anything and we can just return. + results.append(newOp.result_begin(), newOp.result_end()); + } else { + // Otherwise, it had been converted to an operation producing a structure. + // Extract individual results from the structure and return them as list. + results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + auto type = this->typeConverter.convertType(op->getResult(i).getType()); + results.push_back(rewriter.create( + op->getLoc(), type, newOp.getOperation()->getResult(0), + rewriter.getI64ArrayAttr(i))); + } } + if (failed(copyUnrankedDescriptors( + rewriter, op->getLoc(), this->typeConverter, op->getResultTypes(), + results, /*toDynamic=*/false))) + return failure(); rewriter.replaceOp(op, results); return success(); @@ -2397,6 +2538,10 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern { matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { unsigned numArguments = op->getNumOperands(); + auto updatedOperands = llvm::to_vector<4>(operands); + copyUnrankedDescriptors(rewriter, op->getLoc(), typeConverter, + op->getOperands().getTypes(), updatedOperands, + /*toDynamic=*/true); // If ReturnOp has 0 or 1 operand, create it and return immediately. if (numArguments == 0) { @@ -2406,7 +2551,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern { } if (numArguments == 1) { rewriter.replaceOpWithNewOp( - op, ArrayRef(), operands.front(), op->getAttrs()); + op, ArrayRef(), updatedOperands, op->getAttrs()); return success(); } @@ -2418,7 +2563,7 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern { Value packed = rewriter.create(op->getLoc(), packedType); for (unsigned i = 0; i < numArguments; ++i) { packed = rewriter.create( - op->getLoc(), packedType, packed, operands[i], + op->getLoc(), packedType, packed, updatedOperands[i], rewriter.getI64ArrayAttr(i)); } rewriter.replaceOpWithNewOp(op, ArrayRef(), packed, diff --git a/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir b/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir index 87bdab2680f..e17bf3e2422 100644 --- a/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir +++ b/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir @@ -109,3 +109,134 @@ func @other_callee(%arg0: memref, %arg1: index) attributes { llvm.emit_c_ // EMIT_C_ATTRIBUTE: @_mlir_ciface_other_callee // EMIT_C_ATTRIBUTE: llvm.call @other_callee + +//===========================================================================// +// Calling convention on returning unranked memrefs. +//===========================================================================// + +// CHECK-LABEL: llvm.func @return_var_memref_caller +func @return_var_memref_caller(%arg0: memref<4x3xf32>) { + // CHECK: %[[CALL_RES:.*]] = llvm.call @return_var_memref + %0 = call @return_var_memref(%arg0) : (memref<4x3xf32>) -> memref<*xf32> + + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : index) + // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : index) + // These sizes may depend on the data layout, not matching specific values. + // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant + // CHECK: %[[IDX_SIZE:.*]] = llvm.mlir.constant + + // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]] + // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm<"{ i64, i8* }"> + // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]] + // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]] + // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]] + // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]] + // CHECK: %[[FALSE:.*]] = llvm.mlir.constant(false) + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOC_SIZE]] x !llvm.i8 + // CHECK: %[[SOURCE:.*]] = llvm.extractvalue %[[CALL_RES]][1] + // CHECK: "llvm.intr.memcpy"(%[[ALLOCA]], %[[SOURCE]], %[[ALLOC_SIZE]], %[[FALSE]]) + // CHECK: llvm.call @free(%[[SOURCE]]) + // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ i64, i8* }"> + // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm<"{ i64, i8* }"> + // CHECK: %[[DESC_1:.*]] = llvm.insertvalue %[[RANK]], %[[DESC]][0] + // CHECK: llvm.insertvalue %[[ALLOCA]], %[[DESC_1]][1] + return +} + +// CHECK-LABEL: llvm.func @return_var_memref +func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> { + // Match the construction of the unranked descriptor. + // CHECK: %[[ALLOCA:.*]] = llvm.alloca + // CHECK: %[[MEMORY:.*]] = llvm.bitcast %[[ALLOCA]] + // CHECK: %[[DESC_0:.*]] = llvm.mlir.undef : !llvm<"{ i64, i8* }"> + // CHECK: %[[DESC_1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC_0]][0] + // CHECK: %[[DESC_2:.*]] = llvm.insertvalue %[[MEMORY]], %[[DESC_1]][1] + %0 = memref_cast %arg0: memref<4x3xf32> to memref<*xf32> + + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : index) + // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : index) + // These sizes may depend on the data layout, not matching specific values. + // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant + // CHECK: %[[IDX_SIZE:.*]] = llvm.mlir.constant + + // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]] + // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[DESC_2]][0] : !llvm<"{ i64, i8* }"> + // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]] + // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]] + // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]] + // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]] + // CHECK: %[[FALSE:.*]] = llvm.mlir.constant(false) + // CHECK: %[[ALLOCATED:.*]] = llvm.call @malloc(%[[ALLOC_SIZE]]) + // CHECK: %[[SOURCE:.*]] = llvm.extractvalue %[[DESC_2]][1] + // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED]], %[[SOURCE]], %[[ALLOC_SIZE]], %[[FALSE]]) + // CHECK: %[[NEW_DESC:.*]] = llvm.mlir.undef : !llvm<"{ i64, i8* }"> + // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[DESC_2]][0] : !llvm<"{ i64, i8* }"> + // CHECK: %[[NEW_DESC_1:.*]] = llvm.insertvalue %[[RANK]], %[[NEW_DESC]][0] + // CHECK: %[[NEW_DESC_2:.*]] = llvm.insertvalue %[[ALLOCATED]], %[[NEW_DESC_1]][1] + // CHECL: llvm.return %[[NEW_DESC_2]] + return %0 : memref<*xf32> +} + +// CHECK-LABEL: llvm.func @return_two_var_memref_caller +func @return_two_var_memref_caller(%arg0: memref<4x3xf32>) { + // Only check that we create two different descriptors using different + // memory, and deallocate both sources. The size computation is same as for + // the single result. + // CHECK: %[[CALL_RES:.*]] = llvm.call @return_two_var_memref + // CHECK: %[[RES_1:.*]] = llvm.extractvalue %[[CALL_RES]][0] + // CHECK: %[[RES_2:.*]] = llvm.extractvalue %[[CALL_RES]][1] + %0:2 = call @return_two_var_memref(%arg0) : (memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) + + // CHECK: %[[ALLOCA_1:.*]] = llvm.alloca %{{.*}} x !llvm.i8 + // CHECK: %[[SOURCE_1:.*]] = llvm.extractvalue %[[RES_1:.*]][1] : ![[DESC_TYPE:.*]] + // CHECK: "llvm.intr.memcpy"(%[[ALLOCA_1]], %[[SOURCE_1]], %{{.*}}, %[[FALSE:.*]]) + // CHECK: llvm.call @free(%[[SOURCE_1]]) + // CHECK: %[[DESC_1:.*]] = llvm.mlir.undef : ![[DESC_TYPE]] + // CHECK: %[[DESC_11:.*]] = llvm.insertvalue %{{.*}}, %[[DESC_1]][0] + // CHECK: llvm.insertvalue %[[ALLOCA_1]], %[[DESC_11]][1] + + // CHECK: %[[ALLOCA_2:.*]] = llvm.alloca %{{.*}} x !llvm.i8 + // CHECK: %[[SOURCE_2:.*]] = llvm.extractvalue %[[RES_2:.*]][1] + // CHECK: "llvm.intr.memcpy"(%[[ALLOCA_2]], %[[SOURCE_2]], %{{.*}}, %[[FALSE]]) + // CHECK: llvm.call @free(%[[SOURCE_2]]) + // CHECK: %[[DESC_2:.*]] = llvm.mlir.undef : ![[DESC_TYPE]] + // CHECK: %[[DESC_21:.*]] = llvm.insertvalue %{{.*}}, %[[DESC_2]][0] + // CHECK: llvm.insertvalue %[[ALLOCA_2]], %[[DESC_21]][1] + return +} + +// CHECK-LABEL: llvm.func @return_two_var_memref +func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) { + // Match the construction of the unranked descriptor. + // CHECK: %[[ALLOCA:.*]] = llvm.alloca + // CHECK: %[[MEMORY:.*]] = llvm.bitcast %[[ALLOCA]] + // CHECK: %[[DESC_0:.*]] = llvm.mlir.undef : !llvm<"{ i64, i8* }"> + // CHECK: %[[DESC_1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC_0]][0] + // CHECK: %[[DESC_2:.*]] = llvm.insertvalue %[[MEMORY]], %[[DESC_1]][1] + %0 = memref_cast %arg0 : memref<4x3xf32> to memref<*xf32> + + // Only check that we allocate the memory for each operand of the "return" + // separately, even if both operands are the same value. The calling + // convention requires the caller to free them and the caller cannot know + // whether they are the same value or not. + // CHECK: %[[ALLOCATED_1:.*]] = llvm.call @malloc(%{{.*}}) + // CHECK: %[[SOURCE_1:.*]] = llvm.extractvalue %[[DESC_2]][1] + // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_1]], %[[SOURCE_1]], %{{.*}}, %[[FALSE:.*]]) + // CHECK: %[[RES_1:.*]] = llvm.mlir.undef + // CHECK: %[[RES_11:.*]] = llvm.insertvalue %{{.*}}, %[[RES_1]][0] + // CHECK: %[[RES_12:.*]] = llvm.insertvalue %[[ALLOCATED_1]], %[[RES_11]][1] + + // CHECK: %[[ALLOCATED_2:.*]] = llvm.call @malloc(%{{.*}}) + // CHECK: %[[SOURCE_2:.*]] = llvm.extractvalue %[[DESC_2]][1] + // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_2]], %[[SOURCE_2]], %{{.*}}, %[[FALSE]]) + // CHECK: %[[RES_2:.*]] = llvm.mlir.undef + // CHECK: %[[RES_21:.*]] = llvm.insertvalue %{{.*}}, %[[RES_2]][0] + // CHECK: %[[RES_22:.*]] = llvm.insertvalue %[[ALLOCATED_2]], %[[RES_21]][1] + + // CHECK: %[[RESULTS:.*]] = llvm.mlir.undef : !llvm<"{ { i64, i8* }, { i64, i8* } }"> + // CHECK: %[[RESULTS_1:.*]] = llvm.insertvalue %[[RES_12]], %[[RESULTS]] + // CHECK: %[[RESULTS_2:.*]] = llvm.insertvalue %[[RES_22]], %[[RESULTS_1]] + // CHECK: llvm.return %[[RESULTS_2]] + return %0, %0 : memref<*xf32>, memref<*xf32> +} + diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir index e782d5de1aa..a6ce8d9e219 100644 --- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir +++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir @@ -1,7 +1,9 @@ // RUN: mlir-opt %s | mlir-opt | FileCheck %s -// CHECK-LABEL: func @ops(%arg0: !llvm.i32, %arg1: !llvm.float) -func @ops(%arg0 : !llvm.i32, %arg1 : !llvm.float) { +// CHECK-LABEL: func @ops +func @ops(%arg0: !llvm.i32, %arg1: !llvm.float, + %arg2: !llvm<"i8*">, %arg3: !llvm<"i8*">, + %arg4: !llvm.i1) { // Integer arithmetic binary operations. // // CHECK-NEXT: %0 = llvm.add %arg0, %arg0 : !llvm.i32 @@ -109,6 +111,17 @@ func @ops(%arg0 : !llvm.i32, %arg1 : !llvm.float) { // CHECK: "llvm.intr.ctpop"(%{{.*}}) : (!llvm.i32) -> !llvm.i32 %33 = "llvm.intr.ctpop"(%arg0) : (!llvm.i32) -> !llvm.i32 +// CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32, !llvm.i1) -> () + "llvm.intr.memcpy"(%arg2, %arg3, %arg0, %arg4) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32, !llvm.i1) -> () + +// CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32, !llvm.i1) -> () + "llvm.intr.memcpy"(%arg2, %arg3, %arg0, %arg4) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32, !llvm.i1) -> () + +// CHECK: %[[SZ:.*]] = llvm.mlir.constant + %sz = llvm.mlir.constant(10: i64) : !llvm.i64 +// CHECK: "llvm.intr.memcpy.inline"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> () + "llvm.intr.memcpy.inline"(%arg2, %arg3, %sz, %arg4) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> () + // CHECK: llvm.return llvm.return } @@ -315,4 +328,4 @@ func @useFenceInst() { // CHECK: release llvm.fence release return -} \ No newline at end of file +} diff --git a/mlir/test/Target/llvmir-intrinsics.mlir b/mlir/test/Target/llvmir-intrinsics.mlir index e04b40e916f..45292124aed 100644 --- a/mlir/test/Target/llvmir-intrinsics.mlir +++ b/mlir/test/Target/llvmir-intrinsics.mlir @@ -202,6 +202,17 @@ llvm.func @masked_intrinsics(%A: !llvm<"<7 x float>*">, %mask: !llvm<"<7 x i1>"> llvm.return } +// CHECK-LABEL: @memcpy_test +llvm.func @memcpy_test(%arg0: !llvm.i32, %arg1: !llvm.i1, %arg2: !llvm<"i8*">, %arg3: !llvm<"i8*">) { + // CHECK: call void @llvm.memcpy.p0i8.p0i8.i32(i8* %{{.*}}, i8* %{{.*}}, i32 %{{.*}}, i1 %{{.*}}) + "llvm.intr.memcpy"(%arg2, %arg3, %arg0, %arg1) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32, !llvm.i1) -> () + %sz = llvm.mlir.constant(10: i64) : !llvm.i64 + // CHECK: call void @llvm.memcpy.inline.p0i8.p0i8.i64(i8* %{{.*}}, i8* %{{.*}}, i64 10, i1 %{{.*}}) + "llvm.intr.memcpy.inline"(%arg2, %arg3, %sz, %arg1) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) -> () + llvm.return +} + + // Check that intrinsics are declared with appropriate types. // CHECK-DAG: declare float @llvm.fma.f32(float, float, float) // CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0 @@ -231,3 +242,5 @@ llvm.func @masked_intrinsics(%A: !llvm<"<7 x float>*">, %mask: !llvm<"<7 x i1>"> // CHECK-DAG: declare void @llvm.matrix.column.major.store.v48f32.p0f32(<48 x float>, float* nocapture writeonly, i64, i1 immarg, i32 immarg, i32 immarg) // CHECK-DAG: declare <7 x float> @llvm.masked.load.v7f32.p0v7f32(<7 x float>*, i32 immarg, <7 x i1>, <7 x float>) // CHECK-DAG: declare void @llvm.masked.store.v7f32.p0v7f32(<7 x float>, <7 x float>*, i32 immarg, <7 x i1>) +// CHECK-DAG: declare void @llvm.memcpy.p0i8.p0i8.i32(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i32, i1 immarg) +// CHECK-DAG: declare void @llvm.memcpy.inline.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i64 immarg, i1 immarg) diff --git a/mlir/test/mlir-cpu-runner/unranked_memref.mlir b/mlir/test/mlir-cpu-runner/unranked_memref.mlir index 0eb68ac0336..df760f593db 100644 --- a/mlir/test/mlir-cpu-runner/unranked_memref.mlir +++ b/mlir/test/mlir-cpu-runner/unranked_memref.mlir @@ -18,6 +18,21 @@ // CHECK: rank = 0 // 122 is ASCII for 'z'. // CHECK: [z] +// +// CHECK: rank = 2 +// CHECK-SAME: sizes = [4, 3] +// CHECK-SAME: strides = [3, 1] +// CHECK-COUNT-4: [1, 1, 1] +// +// CHECK: rank = 2 +// CHECK-SAME: sizes = [4, 3] +// CHECK-SAME: strides = [3, 1] +// CHECK-COUNT-4: [1, 1, 1] +// +// CHECK: rank = 2 +// CHECK-SAME: sizes = [4, 3] +// CHECK-SAME: strides = [3, 1] +// CHECK-COUNT-4: [1, 1, 1] func @main() -> () { %A = alloc() : memref<10x3xf32, 0> %f2 = constant 2.00000e+00 : f32 @@ -48,8 +63,40 @@ func @main() -> () { call @print_memref_i8(%U4) : (memref<*xi8>) -> () dealloc %A : memref<10x3xf32, 0> + + call @return_var_memref_caller() : () -> () + call @return_two_var_memref_caller() : () -> () return } func @print_memref_i8(memref<*xi8>) attributes { llvm.emit_c_interface } func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface } + +func @return_two_var_memref_caller() { + %0 = alloca() : memref<4x3xf32> + %c0f32 = constant 1.0 : f32 + linalg.fill(%0, %c0f32) : memref<4x3xf32>, f32 + %1:2 = call @return_two_var_memref(%0) : (memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) + call @print_memref_f32(%1#0) : (memref<*xf32>) -> () + call @print_memref_f32(%1#1) : (memref<*xf32>) -> () + return + } + + func @return_two_var_memref(%arg0: memref<4x3xf32>) -> (memref<*xf32>, memref<*xf32>) { + %0 = memref_cast %arg0 : memref<4x3xf32> to memref<*xf32> + return %0, %0 : memref<*xf32>, memref<*xf32> +} + +func @return_var_memref_caller() { + %0 = alloca() : memref<4x3xf32> + %c0f32 = constant 1.0 : f32 + linalg.fill(%0, %c0f32) : memref<4x3xf32>, f32 + %1 = call @return_var_memref(%0) : (memref<4x3xf32>) -> memref<*xf32> + call @print_memref_f32(%1) : (memref<*xf32>) -> () + return +} + +func @return_var_memref(%arg0: memref<4x3xf32>) -> memref<*xf32> { + %0 = memref_cast %arg0: memref<4x3xf32> to memref<*xf32> + return %0 : memref<*xf32> +}