\r
struct ComputeParams\r
{\r
+ size_t numElements; /* Number of uint32_t elements in the input buffer and, equally, in the output buffer. */\r
int localSizeX;\r
int localSizeY;\r
int localSizeZ;\r
-};\r
\r
-class SwiftShaderVulkanComputeTest : public testing::TestWithParam<ComputeParams> {};\r
+ friend std::ostream& operator<<(std::ostream& os, const ComputeParams& params) { /* Writes a readable "ComputeParams{...}" dump of all four fields to os and returns os. */\r
+ return os << "ComputeParams{" <<\r
+ "numElements: " << params.numElements << ", " <<\r
+ "localSizeX: " << params.localSizeX << ", " <<\r
+ "localSizeY: " << params.localSizeY << ", " <<\r
+ "localSizeZ: " << params.localSizeZ <<\r
+ "}";\r
+ }\r
+};\r
\r
-INSTANTIATE_TEST_CASE_P(ComputeParams, SwiftShaderVulkanComputeTest, testing::Values(\r
- ComputeParams{1, 1, 1},\r
- ComputeParams{2, 1, 1},\r
- ComputeParams{4, 1, 1},\r
- ComputeParams{8, 1, 1},\r
- ComputeParams{16, 1, 1},\r
- ComputeParams{32, 1, 1}\r
-));\r
+// Base class for parameterized compute tests whose shader reads from an input\r
+// buffer and writes to an output buffer of the same length (numElements uint32_t values each).\r
+class SwiftShaderVulkanBufferToBufferComputeTest : public testing::TestWithParam<ComputeParams>\r
+{\r
+public:\r
+ void test(const std::string& shader, /* shader source text; compiled with compileSpirv() */\r
+ std::function<uint32_t(uint32_t idx)> input, /* returns the value stored at input element idx before the dispatch */\r
+ std::function<uint32_t(uint32_t idx)> expected); /* returns the value the output element idx must hold after the dispatch */\r
+};\r
\r
-TEST_P(SwiftShaderVulkanComputeTest, Memcpy)\r
+void SwiftShaderVulkanBufferToBufferComputeTest::test( /* Fills the input buffer from input(), dispatches the shader, then checks every output element against expected() and the guard words against overwrites. */\r
+ const std::string& shader,\r
+ std::function<uint32_t(uint32_t idx)> input,\r
+ std::function<uint32_t(uint32_t idx)> expected)\r
{\r
+ auto code = compileSpirv(shader.c_str()); /* Compile the caller-supplied shader text before any Vulkan setup. */\r
+\r
Driver driver;\r
ASSERT_TRUE(driver.loadSwiftShader());\r
\r
- auto params = GetParam();\r
-\r
- std::stringstream src;\r
- src <<\r
- "OpCapability Shader\n"\r
- "OpMemoryModel Logical GLSL450\n"\r
- "OpEntryPoint GLCompute %1 \"main\" %2\n"\r
- "OpExecutionMode %1 LocalSize " <<\r
- params.localSizeX << " " <<\r
- params.localSizeY << " " <<\r
- params.localSizeZ << "\n" <<\r
- "OpDecorate %3 ArrayStride 4\n"\r
- "OpMemberDecorate %4 0 Offset 0\n"\r
- "OpDecorate %4 BufferBlock\n"\r
- "OpDecorate %5 DescriptorSet 0\n"\r
- "OpDecorate %5 Binding 1\n"\r
- "OpDecorate %2 BuiltIn GlobalInvocationId\n"\r
- "OpDecorate %6 ArrayStride 4\n"\r
- "OpMemberDecorate %7 0 Offset 0\n"\r
- "OpDecorate %7 BufferBlock\n"\r
- "OpDecorate %8 DescriptorSet 0\n"\r
- "OpDecorate %8 Binding 0\n"\r
- "%9 = OpTypeVoid\n"\r
- "%10 = OpTypeFunction %9\n"\r
- "%11 = OpTypeInt 32 1\n"\r
- "%3 = OpTypeRuntimeArray %11\n"\r
- "%4 = OpTypeStruct %3\n"\r
- "%12 = OpTypePointer Uniform %4\n"\r
- "%5 = OpVariable %12 Uniform\n"\r
- "%13 = OpConstant %11 0\n"\r
- "%14 = OpTypeInt 32 0\n"\r
- "%15 = OpTypeVector %14 3\n"\r
- "%16 = OpTypePointer Input %15\n"\r
- "%2 = OpVariable %16 Input\n"\r
- "%17 = OpConstant %14 0\n"\r
- "%18 = OpTypePointer Input %14\n"\r
- "%6 = OpTypeRuntimeArray %11\n"\r
- "%7 = OpTypeStruct %6\n"\r
- "%19 = OpTypePointer Uniform %7\n"\r
- "%8 = OpVariable %19 Uniform\n"\r
- "%20 = OpTypePointer Uniform %11\n"\r
- "%21 = OpConstant %11 1\n"\r
- "%1 = OpFunction %9 None %10\n"\r
- "%22 = OpLabel\n"\r
- "%23 = OpAccessChain %18 %2 %17\n"\r
- "%24 = OpLoad %14 %23\n"\r
- "%25 = OpAccessChain %20 %8 %13 %24\n"\r
- "%26 = OpLoad %11 %25\n"\r
- "%27 = OpAccessChain %20 %5 %13 %24\n"\r
- "OpStore %27 %26\n"\r
- "OpReturn\n"\r
- "OpFunctionEnd\n";\r
-\r
- auto code = compileSpirv(src.str().c_str());\r
-\r
const VkInstanceCreateInfo createInfo = {\r
VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO, // sType\r
nullptr, // pNext\r
VK_ASSERT(Device::CreateComputeDevice(&driver, instance, &device));\r
ASSERT_TRUE(device.IsValid());\r
\r
- constexpr int NUM_ELEMENTS = 256;\r
-\r
- struct Buffers\r
- {\r
- uint32_t magic0;\r
- uint32_t in[NUM_ELEMENTS];\r
- uint32_t magic1;\r
- uint32_t out[NUM_ELEMENTS];\r
- uint32_t magic2;\r
- };\r
-\r
- constexpr uint32_t magic0 = 0x01234567;\r
- constexpr uint32_t magic1 = 0x89abcdef;\r
- constexpr uint32_t magic2 = 0xfedcba99;\r
+ // Memory layout of the allocation (NUM_ELEMENTS here is the runtime numElements), equivalent to: struct Buffers\r
+ // {\r
+ // uint32_t magic0;\r
+ // uint32_t in[NUM_ELEMENTS];\r
+ // uint32_t magic1;\r
+ // uint32_t out[NUM_ELEMENTS];\r
+ // uint32_t magic2;\r
+ // };\r
+ static constexpr uint32_t magic0 = 0x01234567;\r
+ static constexpr uint32_t magic1 = 0x89abcdef;\r
+ static constexpr uint32_t magic2 = 0xfedcba99;\r
+ size_t numElements = GetParam().numElements;\r
+ size_t magic0Offset = 0; /* All *Offset values below are indices into a uint32_t array, not byte offsets. */\r
+ size_t inOffset = 1 + magic0Offset;\r
+ size_t magic1Offset = numElements + inOffset;\r
+ size_t outOffset = 1 + magic1Offset;\r
+ size_t magic2Offset = numElements + outOffset;\r
+ size_t buffersTotalElements = 1 + magic2Offset;\r
+ size_t buffersSize = sizeof(uint32_t) * buffersTotalElements; /* Total allocation size in bytes. */\r
\r
VkDeviceMemory memory;\r
- VK_ASSERT(device.AllocateMemory(sizeof(Buffers),\r
+ VK_ASSERT(device.AllocateMemory(buffersSize,\r
VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,\r
&memory));\r
\r
- Buffers* buffers;\r
- VK_ASSERT(device.MapMemory(memory, 0, sizeof(Buffers), 0, (void**)&buffers));\r
-\r
- memset(buffers, 0, sizeof(Buffers));\r
+ uint32_t* buffers;\r
+ VK_ASSERT(device.MapMemory(memory, 0, buffersSize, 0, (void**)&buffers));\r
\r
- buffers->magic0 = magic0;\r
- buffers->magic1 = magic1;\r
- buffers->magic2 = magic2;\r
+ buffers[magic0Offset] = magic0;\r
+ buffers[magic1Offset] = magic1;\r
+ buffers[magic2Offset] = magic2;\r
\r
- for(int i = 0; i < NUM_ELEMENTS; i++)\r
+ for(size_t i = 0; i < numElements; i++)\r
{\r
- buffers->in[i] = (uint32_t)i;\r
+ buffers[inOffset + i] = input(i);\r
}\r
\r
device.UnmapMemory(memory);\r
buffers = nullptr;\r
\r
VkBuffer bufferIn;\r
- VK_ASSERT(device.CreateStorageBuffer(memory, sizeof(Buffers::in), offsetof(Buffers, in), &bufferIn));\r
+ VK_ASSERT(device.CreateStorageBuffer(memory,\r
+ sizeof(uint32_t) * numElements,\r
+ sizeof(uint32_t) * inOffset,\r
+ &bufferIn));\r
\r
VkBuffer bufferOut;\r
- VK_ASSERT(device.CreateStorageBuffer(memory, sizeof(Buffers::out), offsetof(Buffers, out), &bufferOut));\r
+ VK_ASSERT(device.CreateStorageBuffer(memory,\r
+ sizeof(uint32_t) * numElements,\r
+ sizeof(uint32_t) * outOffset,\r
+ &bufferOut));\r
\r
VkShaderModule shaderModule;\r
VK_ASSERT(device.CreateShaderModule(code, &shaderModule));\r
driver.vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout, 0, 1, &descriptorSet,\r
0, nullptr);\r
\r
- driver.vkCmdDispatch(commandBuffer, NUM_ELEMENTS / params.localSizeX, 1, 1);\r
+ driver.vkCmdDispatch(commandBuffer, numElements / GetParam().localSizeX, 1, 1); /* NOTE(review): truncating integer division; assumes numElements is a multiple of localSizeX, otherwise the remainder elements are never dispatched. */\r
\r
VK_ASSERT(driver.vkEndCommandBuffer(commandBuffer));\r
\r
VK_ASSERT(device.QueueSubmitAndWait(commandBuffer));\r
\r
- VK_ASSERT(device.MapMemory(memory, 0, sizeof(Buffers), 0, (void**)&buffers));\r
+ VK_ASSERT(device.MapMemory(memory, 0, buffersSize, 0, (void**)&buffers));\r
\r
- for (int i = 0; i < NUM_ELEMENTS; ++i)\r
+ for (size_t i = 0; i < numElements; ++i)\r
{\r
- EXPECT_EQ(buffers->in[i], buffers->out[i]) << "Unexpected output at " << i;\r
+ auto got = buffers[i + outOffset];\r
+ EXPECT_EQ(expected(i), got) << "Unexpected output at " << i;\r
}\r
\r
// Check for writes outside of bounds.\r
- EXPECT_EQ(buffers->magic0, magic0);\r
- EXPECT_EQ(buffers->magic1, magic1);\r
- EXPECT_EQ(buffers->magic2, magic2);\r
+ EXPECT_EQ(buffers[magic0Offset], magic0);\r
+ EXPECT_EQ(buffers[magic1Offset], magic1);\r
+ EXPECT_EQ(buffers[magic2Offset], magic2);\r
\r
device.UnmapMemory(memory);\r
buffers = nullptr;\r
}\r
+\r
+INSTANTIATE_TEST_CASE_P(ComputeParams, SwiftShaderVulkanBufferToBufferComputeTest, testing::Values(\r
+ ComputeParams{512, 1, 1, 1}, /* Field order: {numElements, localSizeX, localSizeY, localSizeZ}. */\r
+ ComputeParams{512, 2, 1, 1},\r
+ ComputeParams{512, 4, 1, 1},\r
+ ComputeParams{512, 8, 1, 1},\r
+ ComputeParams{512, 16, 1, 1},\r
+ ComputeParams{512, 32, 1, 1},\r
+\r
+ // numElements values that are not a multiple of the SIMD lane count.\r
+ ComputeParams{3, 1, 1, 1},\r
+ ComputeParams{2, 1, 1, 1}\r
+));\r
+\r
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, Memcpy) /* Shader copies in.arr[x] to out.arr[x], where x = gl_GlobalInvocationId.x. */\r
+{\r
+ std::stringstream src;\r
+ src <<\r
+ "OpCapability Shader\n"\r
+ "OpMemoryModel Logical GLSL450\n"\r
+ "OpEntryPoint GLCompute %1 \"main\" %2\n"\r
+ "OpExecutionMode %1 LocalSize " <<\r
+ GetParam().localSizeX << " " <<\r
+ GetParam().localSizeY << " " <<\r
+ GetParam().localSizeZ << "\n" <<\r
+ "OpDecorate %3 ArrayStride 4\n"\r
+ "OpMemberDecorate %4 0 Offset 0\n"\r
+ "OpDecorate %4 BufferBlock\n"\r
+ "OpDecorate %5 DescriptorSet 0\n"\r
+ "OpDecorate %5 Binding 1\n"\r
+ "OpDecorate %2 BuiltIn GlobalInvocationId\n"\r
+ "OpDecorate %6 DescriptorSet 0\n"\r
+ "OpDecorate %6 Binding 0\n"\r
+ "%7 = OpTypeVoid\n"\r
+ "%8 = OpTypeFunction %7\n" // void()\r
+ "%9 = OpTypeInt 32 1\n" // int32\r
+ "%10 = OpTypeInt 32 0\n" // uint32\r
+ "%3 = OpTypeRuntimeArray %9\n" // int32[]\r
+ "%4 = OpTypeStruct %3\n" // struct{ int32[] }\r
+ "%11 = OpTypePointer Uniform %4\n" // struct{ int32[] }*\r
+ "%5 = OpVariable %11 Uniform\n" // struct{ int32[] }* out (binding 1, target of the OpStore)\r
+ "%12 = OpConstant %9 0\n" // int32(0)\r
+ "%13 = OpConstant %10 0\n" // uint32(0)\r
+ "%14 = OpTypeVector %10 3\n" // vec3<uint32>\r
+ "%15 = OpTypePointer Input %14\n" // vec3<uint32>*\r
+ "%2 = OpVariable %15 Input\n" // gl_GlobalInvocationId\r
+ "%16 = OpTypePointer Input %10\n" // uint32*\r
+ "%6 = OpVariable %11 Uniform\n" // struct{ int32[] }* in (binding 0, source of the OpLoad)\r
+ "%17 = OpTypePointer Uniform %9\n" // int32*\r
+ "%1 = OpFunction %7 None %8\n" // -- Function begin --\r
+ "%18 = OpLabel\n"\r
+ "%19 = OpAccessChain %16 %2 %13\n" // &gl_GlobalInvocationId.x\r
+ "%20 = OpLoad %10 %19\n" // gl_GlobalInvocationId.x\r
+ "%21 = OpAccessChain %17 %6 %12 %20\n" // &in.arr[gl_GlobalInvocationId.x]\r
+ "%22 = OpLoad %9 %21\n" // in.arr[gl_GlobalInvocationId.x]\r
+ "%23 = OpAccessChain %17 %5 %12 %20\n" // &out.arr[gl_GlobalInvocationId.x]\r
+ "OpStore %23 %22\n" // out.arr[gl_GlobalInvocationId.x] = in.arr[gl_GlobalInvocationId.x]\r
+ "OpReturn\n"\r
+ "OpFunctionEnd\n";\r
+\r
+ test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; }); /* input[i] = i; a plain copy must yield output[i] = i. */\r
+}\r
+\r
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchSimple) /* Copies in.arr[x] to out.arr[x], with the load and the store separated by a chain of unconditional OpBranch blocks. */\r
+{\r
+ std::stringstream src;\r
+ src <<\r
+ "OpCapability Shader\n"\r
+ "OpMemoryModel Logical GLSL450\n"\r
+ "OpEntryPoint GLCompute %1 \"main\" %2\n"\r
+ "OpExecutionMode %1 LocalSize " <<\r
+ GetParam().localSizeX << " " <<\r
+ GetParam().localSizeY << " " <<\r
+ GetParam().localSizeZ << "\n" <<\r
+ "OpDecorate %3 ArrayStride 4\n"\r
+ "OpMemberDecorate %4 0 Offset 0\n"\r
+ "OpDecorate %4 BufferBlock\n"\r
+ "OpDecorate %5 DescriptorSet 0\n"\r
+ "OpDecorate %5 Binding 1\n"\r
+ "OpDecorate %2 BuiltIn GlobalInvocationId\n"\r
+ "OpDecorate %6 DescriptorSet 0\n"\r
+ "OpDecorate %6 Binding 0\n"\r
+ "%7 = OpTypeVoid\n"\r
+ "%8 = OpTypeFunction %7\n" // void()\r
+ "%9 = OpTypeInt 32 1\n" // int32\r
+ "%10 = OpTypeInt 32 0\n" // uint32\r
+ "%3 = OpTypeRuntimeArray %9\n" // int32[]\r
+ "%4 = OpTypeStruct %3\n" // struct{ int32[] }\r
+ "%11 = OpTypePointer Uniform %4\n" // struct{ int32[] }*\r
+ "%5 = OpVariable %11 Uniform\n" // struct{ int32[] }* out (binding 1, target of the OpStore)\r
+ "%12 = OpConstant %9 0\n" // int32(0)\r
+ "%13 = OpConstant %10 0\n" // uint32(0)\r
+ "%14 = OpTypeVector %10 3\n" // vec3<uint32>\r
+ "%15 = OpTypePointer Input %14\n" // vec3<uint32>*\r
+ "%2 = OpVariable %15 Input\n" // gl_GlobalInvocationId\r
+ "%16 = OpTypePointer Input %10\n" // uint32*\r
+ "%6 = OpVariable %11 Uniform\n" // struct{ int32[] }* in (binding 0, source of the OpLoad)\r
+ "%17 = OpTypePointer Uniform %9\n" // int32*\r
+ "%1 = OpFunction %7 None %8\n" // -- Function begin --\r
+ "%18 = OpLabel\n"\r
+ "%19 = OpAccessChain %16 %2 %13\n" // &gl_GlobalInvocationId.x\r
+ "%20 = OpLoad %10 %19\n" // gl_GlobalInvocationId.x\r
+ "%21 = OpAccessChain %17 %6 %12 %20\n" // &in.arr[gl_GlobalInvocationId.x]\r
+ "%22 = OpLoad %9 %21\n" // in.arr[gl_GlobalInvocationId.x]\r
+ "%23 = OpAccessChain %17 %5 %12 %20\n" // &out.arr[gl_GlobalInvocationId.x]\r
+ // Start of branch logic\r
+ // %22 = in value\r
+ "OpBranch %24\n"\r
+ "%24 = OpLabel\n"\r
+ "OpBranch %25\n"\r
+ "%25 = OpLabel\n"\r
+ "OpBranch %26\n"\r
+ "%26 = OpLabel\n"\r
+ // %22 = out value (unchanged by the branches)\r
+ // End of branch logic\r
+ "OpStore %23 %22\n"\r
+ "OpReturn\n"\r
+ "OpFunctionEnd\n";\r
+\r
+ test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i; }); /* input[i] = i; the value passes through unchanged, so output[i] = i. */\r
+}\r
+\r
+TEST_P(SwiftShaderVulkanBufferToBufferComputeTest, BranchDeclareSSA) /* Stores in.arr[x] * 2, where the doubled value (%25) is defined in one block and consumed by the OpStore in a later block. */\r
+{\r
+ std::stringstream src;\r
+ src <<\r
+ "OpCapability Shader\n"\r
+ "OpMemoryModel Logical GLSL450\n"\r
+ "OpEntryPoint GLCompute %1 \"main\" %2\n"\r
+ "OpExecutionMode %1 LocalSize " <<\r
+ GetParam().localSizeX << " " <<\r
+ GetParam().localSizeY << " " <<\r
+ GetParam().localSizeZ << "\n" <<\r
+ "OpDecorate %3 ArrayStride 4\n"\r
+ "OpMemberDecorate %4 0 Offset 0\n"\r
+ "OpDecorate %4 BufferBlock\n"\r
+ "OpDecorate %5 DescriptorSet 0\n"\r
+ "OpDecorate %5 Binding 1\n"\r
+ "OpDecorate %2 BuiltIn GlobalInvocationId\n"\r
+ "OpDecorate %6 DescriptorSet 0\n"\r
+ "OpDecorate %6 Binding 0\n"\r
+ "%7 = OpTypeVoid\n"\r
+ "%8 = OpTypeFunction %7\n" // void()\r
+ "%9 = OpTypeInt 32 1\n" // int32\r
+ "%10 = OpTypeInt 32 0\n" // uint32\r
+ "%3 = OpTypeRuntimeArray %9\n" // int32[]\r
+ "%4 = OpTypeStruct %3\n" // struct{ int32[] }\r
+ "%11 = OpTypePointer Uniform %4\n" // struct{ int32[] }*\r
+ "%5 = OpVariable %11 Uniform\n" // struct{ int32[] }* out (binding 1, target of the OpStore)\r
+ "%12 = OpConstant %9 0\n" // int32(0)\r
+ "%13 = OpConstant %10 0\n" // uint32(0)\r
+ "%14 = OpTypeVector %10 3\n" // vec3<uint32>\r
+ "%15 = OpTypePointer Input %14\n" // vec3<uint32>*\r
+ "%2 = OpVariable %15 Input\n" // gl_GlobalInvocationId\r
+ "%16 = OpTypePointer Input %10\n" // uint32*\r
+ "%6 = OpVariable %11 Uniform\n" // struct{ int32[] }* in (binding 0, source of the OpLoad)\r
+ "%17 = OpTypePointer Uniform %9\n" // int32*\r
+ "%1 = OpFunction %7 None %8\n" // -- Function begin --\r
+ "%18 = OpLabel\n"\r
+ "%19 = OpAccessChain %16 %2 %13\n" // &gl_GlobalInvocationId.x\r
+ "%20 = OpLoad %10 %19\n" // gl_GlobalInvocationId.x\r
+ "%21 = OpAccessChain %17 %6 %12 %20\n" // &in.arr[gl_GlobalInvocationId.x]\r
+ "%22 = OpLoad %9 %21\n" // in.arr[gl_GlobalInvocationId.x]\r
+ "%23 = OpAccessChain %17 %5 %12 %20\n" // &out.arr[gl_GlobalInvocationId.x]\r
+ // Start of branch logic\r
+ // %22 = in value\r
+ "OpBranch %24\n"\r
+ "%24 = OpLabel\n"\r
+ "%25 = OpIAdd %9 %22 %22\n" // %25 = in + in = in*2\r
+ "OpBranch %26\n"\r
+ "%26 = OpLabel\n"\r
+ "OpBranch %27\n"\r
+ "%27 = OpLabel\n"\r
+ // %25 = out value\r
+ // End of branch logic\r
+ "OpStore %23 %25\n" // use SSA value defined in an earlier block (%24)\r
+ "OpReturn\n"\r
+ "OpFunctionEnd\n";\r
+\r
+ test(src.str(), [](uint32_t i) { return i; }, [](uint32_t i) { return i * 2; }); /* input[i] = i; the shader doubles it, so output[i] = 2*i. */\r
+}
\ No newline at end of file