static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat";
static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat";
static constexpr const char *kBindMemRef3DFloat = "bindMemRef3DFloat";
+static constexpr const char *kBindMemRef1DInt = "bindMemRef1DInt";
+static constexpr const char *kBindMemRef2DInt = "bindMemRef2DInt";
+static constexpr const char *kBindMemRef3DInt = "bindMemRef3DInt";
static constexpr const char *kCInterfaceVulkanLaunch =
"_mlir_ciface_vulkanLaunch";
static constexpr const char *kDeinitVulkan = "deinitVulkan";
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
- llvmMemRef1DFloat = getMemRefType(1);
- llvmMemRef2DFloat = getMemRefType(2);
- llvmMemRef3DFloat = getMemRefType(3);
+ llvmMemRef1DFloat = getMemRefType(1, llvmFloatType);
+ llvmMemRef2DFloat = getMemRefType(2, llvmFloatType);
+ llvmMemRef3DFloat = getMemRefType(3, llvmFloatType);
+ llvmMemRef1DInt = getMemRefType(1, llvmInt32Type);
+ llvmMemRef2DInt = getMemRefType(2, llvmInt32Type);
+ llvmMemRef3DInt = getMemRefType(3, llvmInt32Type);
}
- LLVM::LLVMType getMemRefType(uint32_t rank) {
+ LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) {
// According to the MLIR doc memref argument is converted into a
// pointer-to-struct argument of type:
// template <typename Elem, size_t Rank>
// int64_t sizes[Rank]; // omitted when rank == 0
// int64_t strides[Rank]; // omitted when rank == 0
// };
- auto llvmPtrToFloatType = getFloatType().getPointerTo();
+ auto llvmPtrToElementType = elemenType.getPointerTo();
auto llvmArrayRankElementSizeType =
LLVM::LLVMType::getArrayTy(getInt64Type(), rank);
// Create a type
- // `!llvm<"{ float*, float*, i64, [`rank` x i64], [`rank` x i64]}">`.
+ // `!llvm<"{ `element-type`*, `element-type`*, i64,
+ // [`rank` x i64], [`rank` x i64]}">`.
return LLVM::LLVMType::getStructTy(
llvmDialect,
- {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(),
+ {llvmPtrToElementType, llvmPtrToElementType, getInt64Type(),
llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
}
LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; }
LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; }
LLVM::LLVMType getMemRef3DFloat() { return llvmMemRef3DFloat; }
+ LLVM::LLVMType getMemRef1DInt() { return llvmMemRef1DInt; }
+ LLVM::LLVMType getMemRef2DInt() { return llvmMemRef2DInt; }
+ LLVM::LLVMType getMemRef3DInt() { return llvmMemRef3DInt; }
/// Creates a LLVM global for the given `name`.
Value createEntryPointNameConstant(StringRef name, Location loc,
/// Collects SPIRV attributes from the given `vulkanLaunchCallOp`.
void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp);
- /// Deduces a rank from the given 'ptrToMemRefDescriptor`.
- LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank);
+ /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`.
+ LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
+ uint32_t &rank, LLVM::LLVMType &type);
+
+ /// Returns a string representation from the given `type`.
+ StringRef stringifyType(LLVM::LLVMType type) {
+ if (type.isFloatTy())
+ return "Float";
+ if (type.isIntegerTy())
+ return "Int";
+
+ llvm_unreachable("unsupported type");
+ }
public:
void runOnOperation() override;
LLVM::LLVMType llvmMemRef1DFloat;
LLVM::LLVMType llvmMemRef2DFloat;
LLVM::LLVMType llvmMemRef3DFloat;
+ LLVM::LLVMType llvmMemRef1DInt;
+ LLVM::LLVMType llvmMemRef2DInt;
+ LLVM::LLVMType llvmMemRef3DInt;
// TODO: Use an associative array to support multiple vulkan launch calls.
std::pair<StringAttr, StringAttr> spirvAttributes;
auto ptrToMemRefDescriptor = en.value();
uint32_t rank = 0;
- if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) {
+ LLVM::LLVMType type;
+ if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
cInterfaceVulkanLaunchCallOp.emitError()
<< "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
return signalPassFailure();
}
- auto symbolName = llvm::formatv("bindMemRef{0}DFloat", rank).str();
+ auto symbolName =
+ llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str();
// Create call to `bindMemRef`.
builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getVoidType()},
}
}
-LogicalResult
-VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor,
- uint32_t &rank) {
+LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
+ Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) {
auto llvmPtrDescriptorTy =
ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMType>();
if (!llvmPtrDescriptorTy)
// };
if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy())
return failure();
+
+ type = llvmDescriptorTy.getStructElementType(0).getPointerElementTy();
if (llvmDescriptorTy.getStructNumElements() == 3) {
rank = 0;
return success();
}
-
rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements();
return success();
}
/*isVarArg=*/false));
}
- if (!module.lookupSymbol(kBindMemRef1DFloat)) {
- builder.create<LLVM::LLVMFuncOp>(
- loc, kBindMemRef1DFloat,
- LLVM::LLVMType::getFunctionTy(getVoidType(),
- {getPointerType(), getInt32Type(),
- getInt32Type(),
- getMemRef1DFloat().getPointerTo()},
- /*isVarArg=*/false));
+#define CREATE_VULKAN_BIND_FUNC(MemRefType) \
+ if (!module.lookupSymbol(kBind##MemRefType)) { \
+ builder.create<LLVM::LLVMFuncOp>( \
+ loc, kBind##MemRefType, \
+ LLVM::LLVMType::getFunctionTy(getVoidType(), \
+ {getPointerType(), getInt32Type(), \
+ getInt32Type(), \
+ get##MemRefType().getPointerTo()}, \
+ /*isVarArg=*/false)); \
}
- if (!module.lookupSymbol(kBindMemRef2DFloat)) {
- builder.create<LLVM::LLVMFuncOp>(
- loc, kBindMemRef2DFloat,
- LLVM::LLVMType::getFunctionTy(getVoidType(),
- {getPointerType(), getInt32Type(),
- getInt32Type(),
- getMemRef2DFloat().getPointerTo()},
- /*isVarArg=*/false));
- }
-
- if (!module.lookupSymbol(kBindMemRef3DFloat)) {
- builder.create<LLVM::LLVMFuncOp>(
- loc, kBindMemRef3DFloat,
- LLVM::LLVMType::getFunctionTy(getVoidType(),
- {getPointerType(), getInt32Type(),
- getInt32Type(),
- getMemRef3DFloat().getPointerTo()},
- /*isVarArg=*/false));
- }
+ CREATE_VULKAN_BIND_FUNC(MemRef1DFloat);
+ CREATE_VULKAN_BIND_FUNC(MemRef2DFloat);
+ CREATE_VULKAN_BIND_FUNC(MemRef3DFloat);
+ CREATE_VULKAN_BIND_FUNC(MemRef1DInt);
+ CREATE_VULKAN_BIND_FUNC(MemRef2DInt);
+ CREATE_VULKAN_BIND_FUNC(MemRef3DInt);
if (!module.lookupSymbol(kInitVulkan)) {
builder.create<LLVM::LLVMFuncOp>(
--- /dev/null
+// RUN: mlir-vulkan-runner %s --shared-libs=%vulkan_wrapper_library_dir/libvulkan-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+// CHECK-COUNT-64: [3, 3, 3, 3, 3, 3, 3, 3]
+module attributes {
+ gpu.container_module,
+ spv.target_env = #spv.target_env<
+ #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+ gpu.module @kernels {
+ gpu.func @kernel_addi(%arg0 : memref<8xi32>, %arg1 : memref<8x8xi32>, %arg2 : memref<8x8x8xi32>)
+ kernel attributes { spv.entry_point_abi = {local_size = dense<[1, 1, 1]>: vector<3xi32>}} {
+ %x = "gpu.block_id"() {dimension = "x"} : () -> index
+ %y = "gpu.block_id"() {dimension = "y"} : () -> index
+ %z = "gpu.block_id"() {dimension = "z"} : () -> index
+ %0 = load %arg0[%x] : memref<8xi32>
+ %1 = load %arg1[%y, %x] : memref<8x8xi32>
+ %2 = addi %0, %1 : i32
+ store %2, %arg2[%z, %y, %x] : memref<8x8x8xi32>
+ gpu.return
+ }
+ }
+
+ func @main() {
+ %arg0 = alloc() : memref<8xi32>
+ %arg1 = alloc() : memref<8x8xi32>
+ %arg2 = alloc() : memref<8x8x8xi32>
+ %value0 = constant 0 : i32
+ %value1 = constant 1 : i32
+ %value2 = constant 2 : i32
+ %arg3 = memref_cast %arg0 : memref<8xi32> to memref<?xi32>
+ %arg4 = memref_cast %arg1 : memref<8x8xi32> to memref<?x?xi32>
+ %arg5 = memref_cast %arg2 : memref<8x8x8xi32> to memref<?x?x?xi32>
+ call @fillResource1DInt(%arg3, %value1) : (memref<?xi32>, i32) -> ()
+ call @fillResource2DInt(%arg4, %value2) : (memref<?x?xi32>, i32) -> ()
+ call @fillResource3DInt(%arg5, %value0) : (memref<?x?x?xi32>, i32) -> ()
+
+ %cst1 = constant 1 : index
+ %cst8 = constant 8 : index
+ "gpu.launch_func"(%cst8, %cst8, %cst8, %cst1, %cst1, %cst1, %arg0, %arg1, %arg2) { kernel = @kernels::@kernel_addi }
+ : (index, index, index, index, index, index, memref<8xi32>, memref<8x8xi32>, memref<8x8x8xi32>) -> ()
+ %arg6 = memref_cast %arg5 : memref<?x?x?xi32> to memref<*xi32>
+ call @print_memref_i32(%arg6) : (memref<*xi32>) -> ()
+ return
+ }
+ func @fillResource1DInt(%0 : memref<?xi32>, %1 : i32)
+ func @fillResource2DInt(%0 : memref<?x?xi32>, %1 : i32)
+ func @fillResource3DInt(%0 : memref<?x?x?xi32>, %1 : i32)
+ func @print_memref_i32(%ptr : memref<*xi32>)
+}
+
->setResourceData(setIndex, bindIndex, memBuffer);
}
+/// Binds the given 1D int memref to the given descriptor set and descriptor
+/// index.
+void bindMemRef1DInt(void *vkRuntimeManager, DescriptorSetIndex setIndex,
+ BindingIndex bindIndex,
+ MemRefDescriptor<int32_t, 1> *ptr) {
+ VulkanHostMemoryBuffer memBuffer{
+ ptr->allocated, static_cast<uint32_t>(ptr->sizes[0] * sizeof(int32_t))};
+ reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+ ->setResourceData(setIndex, bindIndex, memBuffer);
+}
+
+/// Binds the given 2D int memref to the given descriptor set and descriptor
+/// index.
+void bindMemRef2DInt(void *vkRuntimeManager, DescriptorSetIndex setIndex,
+ BindingIndex bindIndex,
+ MemRefDescriptor<int32_t, 2> *ptr) {
+ VulkanHostMemoryBuffer memBuffer{
+ ptr->allocated,
+ static_cast<uint32_t>(ptr->sizes[0] * ptr->sizes[1] * sizeof(int32_t))};
+ reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+ ->setResourceData(setIndex, bindIndex, memBuffer);
+}
+
+/// Binds the given 3D int memref to the given descriptor set and descriptor
+/// index.
+void bindMemRef3DInt(void *vkRuntimeManager, DescriptorSetIndex setIndex,
+ BindingIndex bindIndex,
+ MemRefDescriptor<int32_t, 3> *ptr) {
+ VulkanHostMemoryBuffer memBuffer{
+ ptr->allocated, static_cast<uint32_t>(ptr->sizes[0] * ptr->sizes[1] *
+ ptr->sizes[2] * sizeof(int32_t))};
+ reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
+ ->setResourceData(setIndex, bindIndex, memBuffer);
+}
+
/// Fills the given 1D float memref with the given float value.
void _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
float value) {
std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
value);
}
+
+/// Fills the given 1D int memref with the given int value.
+void _mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT
+ int32_t value) {
+ std::fill_n(ptr->allocated, ptr->sizes[0], value);
+}
+
+/// Fills the given 2D int memref with the given int value.
+void _mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT
+ int32_t value) {
+ std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
+}
+
+/// Fills the given 3D int memref with the given int value.
+void _mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT
+ int32_t value) {
+ std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
+ value);
+}
}