1 // Copyright 2018 The SwiftShader Authors. All Rights Reserved.
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
15 #include "VkPipeline.hpp"
16 #include "VkShaderModule.hpp"
17 #include "Pipeline/SpirvShader.hpp"
19 #include "spirv-tools/optimizer.hpp"
// Translates a Vulkan primitive topology into the sw::DrawType value that the
// GraphicsPipeline constructor stores in context.drawType.
// NOTE(review): this listing omits some lines of the function (for example the
// statements that follow the adjacency and patch-list labels), so what happens
// for those topologies before the final return cannot be stated from here.
24 sw::DrawType Convert(VkPrimitiveTopology topology)
// Point, line and triangle topologies each map onto the sw::DRAW_* value of the same shape.
28 case VK_PRIMITIVE_TOPOLOGY_POINT_LIST:
29 return sw::DRAW_POINTLIST;
30 case VK_PRIMITIVE_TOPOLOGY_LINE_LIST:
31 return sw::DRAW_LINELIST;
32 case VK_PRIMITIVE_TOPOLOGY_LINE_STRIP:
33 return sw::DRAW_LINESTRIP;
34 case VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST:
35 return sw::DRAW_TRIANGLELIST;
36 case VK_PRIMITIVE_TOPOLOGY_TRIANGLE_STRIP:
37 return sw::DRAW_TRIANGLESTRIP;
38 case VK_PRIMITIVE_TOPOLOGY_TRIANGLE_FAN:
39 return sw::DRAW_TRIANGLEFAN;
// The four *_WITH_ADJACENCY topologies share one case body.
40 case VK_PRIMITIVE_TOPOLOGY_LINE_LIST_WITH_ADJACENCY:
41 case VK_PRIMITIVE_TOPOLOGY_LINE_STRIP_WITH_ADJACENCY:
42 case VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST_WITH_ADJACENCY:
43 case VK_PRIMITIVE_TOPOLOGY_TRIANGLE_STRIP_WITH_ADJACENCY:
44 // geometry shader specific
47 case VK_PRIMITIVE_TOPOLOGY_PATCH_LIST:
48 // tessellation shader specific
// Last statement of the function; presumably the fallback value for every
// topology that did not return above — confirm against the full source.
55 return sw::DRAW_TRIANGLELIST;
// Builds an sw::Rect from a VkRect2D: the offset supplies the first two
// constructor arguments and offset + extent supplies the last two.
// NOTE(review): presumably sw::Rect takes (x0, y0, x1, y1) — confirm. The sums
// mix the offset and extent fields directly; verify the result fits the
// coordinate type sw::Rect expects.
58 sw::Rect Convert(const VkRect2D& rect)
60 return sw::Rect(rect.offset.x, rect.offset.y, rect.offset.x + rect.extent.width, rect.offset.y + rect.extent.height);
// Picks the sw::StreamType for a vertex attribute format; the GraphicsPipeline
// constructor assigns the result to the `type` field of the attribute's stream.
// Formats are grouped by component width and signedness; the number of
// components does not influence the result.
63 sw::StreamType getStreamType(VkFormat format)
// 8-bit unsigned components (UNORM and UINT, including the packed A8B8G8R8 and B8G8R8A8 layouts).
67 case VK_FORMAT_R8_UNORM:
68 case VK_FORMAT_R8G8_UNORM:
69 case VK_FORMAT_R8G8B8A8_UNORM:
70 case VK_FORMAT_R8_UINT:
71 case VK_FORMAT_R8G8_UINT:
72 case VK_FORMAT_R8G8B8A8_UINT:
73 case VK_FORMAT_B8G8R8A8_UNORM:
74 case VK_FORMAT_A8B8G8R8_UNORM_PACK32:
75 case VK_FORMAT_A8B8G8R8_UINT_PACK32:
76 return sw::STREAMTYPE_BYTE;
// 8-bit signed components (SNORM and SINT).
77 case VK_FORMAT_R8_SNORM:
78 case VK_FORMAT_R8_SINT:
79 case VK_FORMAT_R8G8_SNORM:
80 case VK_FORMAT_R8G8_SINT:
81 case VK_FORMAT_R8G8B8A8_SNORM:
82 case VK_FORMAT_R8G8B8A8_SINT:
83 case VK_FORMAT_A8B8G8R8_SNORM_PACK32:
84 case VK_FORMAT_A8B8G8R8_SINT_PACK32:
85 return sw::STREAMTYPE_SBYTE;
// The packed 2/10/10/10 layout has its own stream type.
86 case VK_FORMAT_A2B10G10R10_UNORM_PACK32:
87 return sw::STREAMTYPE_2_10_10_10_UINT;
// 16-bit unsigned components.
88 case VK_FORMAT_R16_UNORM:
89 case VK_FORMAT_R16_UINT:
90 case VK_FORMAT_R16G16_UNORM:
91 case VK_FORMAT_R16G16_UINT:
92 case VK_FORMAT_R16G16B16A16_UNORM:
93 case VK_FORMAT_R16G16B16A16_UINT:
94 return sw::STREAMTYPE_USHORT;
// 16-bit signed components.
95 case VK_FORMAT_R16_SNORM:
96 case VK_FORMAT_R16_SINT:
97 case VK_FORMAT_R16G16_SNORM:
98 case VK_FORMAT_R16G16_SINT:
99 case VK_FORMAT_R16G16B16A16_SNORM:
100 case VK_FORMAT_R16G16B16A16_SINT:
101 return sw::STREAMTYPE_SHORT;
// 16-bit floating-point components.
102 case VK_FORMAT_R16_SFLOAT:
103 case VK_FORMAT_R16G16_SFLOAT:
104 case VK_FORMAT_R16G16B16A16_SFLOAT:
105 return sw::STREAMTYPE_HALF;
// 32-bit unsigned integer components.
106 case VK_FORMAT_R32_UINT:
107 case VK_FORMAT_R32G32_UINT:
108 case VK_FORMAT_R32G32B32_UINT:
109 case VK_FORMAT_R32G32B32A32_UINT:
110 return sw::STREAMTYPE_UINT;
// 32-bit signed integer components.
111 case VK_FORMAT_R32_SINT:
112 case VK_FORMAT_R32G32_SINT:
113 case VK_FORMAT_R32G32B32_SINT:
114 case VK_FORMAT_R32G32B32A32_SINT:
115 return sw::STREAMTYPE_INT;
// 32-bit floating-point components.
116 case VK_FORMAT_R32_SFLOAT:
117 case VK_FORMAT_R32G32_SFLOAT:
118 case VK_FORMAT_R32G32B32_SFLOAT:
119 case VK_FORMAT_R32G32B32A32_SFLOAT:
120 return sw::STREAMTYPE_FLOAT;
// Last statement of the function; presumably the value returned for formats
// not listed above — the lines between the last case and this return are not
// visible in this listing, confirm against the full source.
125 return sw::STREAMTYPE_BYTE;
// Classifies a vertex attribute format by its number of components; the
// GraphicsPipeline constructor assigns the result to the `count` field of the
// attribute's stream.
// NOTE(review): the return statement that closes each group of case labels is
// not visible in this listing. Presumably the groups return 1, 2, 3 and 4 in
// the order they appear — confirm against the full source.
128 uint32_t getNumberOfChannels(VkFormat format)
// Single-component (R) formats.
132 case VK_FORMAT_R8_UNORM:
133 case VK_FORMAT_R8_SNORM:
134 case VK_FORMAT_R8_UINT:
135 case VK_FORMAT_R8_SINT:
136 case VK_FORMAT_R16_UNORM:
137 case VK_FORMAT_R16_SNORM:
138 case VK_FORMAT_R16_UINT:
139 case VK_FORMAT_R16_SINT:
140 case VK_FORMAT_R16_SFLOAT:
141 case VK_FORMAT_R32_UINT:
142 case VK_FORMAT_R32_SINT:
143 case VK_FORMAT_R32_SFLOAT:
// Two-component (RG) formats.
145 case VK_FORMAT_R8G8_UNORM:
146 case VK_FORMAT_R8G8_SNORM:
147 case VK_FORMAT_R8G8_UINT:
148 case VK_FORMAT_R8G8_SINT:
149 case VK_FORMAT_R16G16_UNORM:
150 case VK_FORMAT_R16G16_SNORM:
151 case VK_FORMAT_R16G16_UINT:
152 case VK_FORMAT_R16G16_SINT:
153 case VK_FORMAT_R16G16_SFLOAT:
154 case VK_FORMAT_R32G32_UINT:
155 case VK_FORMAT_R32G32_SINT:
156 case VK_FORMAT_R32G32_SFLOAT:
// Three-component (RGB) formats; only the 32-bit variants are listed.
158 case VK_FORMAT_R32G32B32_UINT:
159 case VK_FORMAT_R32G32B32_SINT:
160 case VK_FORMAT_R32G32B32_SFLOAT:
// Four-component formats, including the packed 32-bit layouts.
162 case VK_FORMAT_R8G8B8A8_UNORM:
163 case VK_FORMAT_R8G8B8A8_SNORM:
164 case VK_FORMAT_R8G8B8A8_UINT:
165 case VK_FORMAT_R8G8B8A8_SINT:
166 case VK_FORMAT_B8G8R8A8_UNORM:
167 case VK_FORMAT_A8B8G8R8_UNORM_PACK32:
168 case VK_FORMAT_A8B8G8R8_SNORM_PACK32:
169 case VK_FORMAT_A8B8G8R8_UINT_PACK32:
170 case VK_FORMAT_A8B8G8R8_SINT_PACK32:
171 case VK_FORMAT_A2B10G10R10_UNORM_PACK32:
172 case VK_FORMAT_R16G16B16A16_UNORM:
173 case VK_FORMAT_R16G16B16A16_SNORM:
174 case VK_FORMAT_R16G16B16A16_UINT:
175 case VK_FORMAT_R16G16B16A16_SINT:
176 case VK_FORMAT_R16G16B16A16_SFLOAT:
177 case VK_FORMAT_R32G32B32A32_UINT:
178 case VK_FORMAT_R32G32B32A32_SINT:
179 case VK_FORMAT_R32G32B32A32_SFLOAT:
// Captures the fixed-function state of a graphics pipeline from pCreateInfo.
// Each create-info sub-structure is first compared against the subset of
// features this implementation handles, then the supported values are copied
// into `context`, `scissor`, `viewport` and `blendConstants`.
// `mem` is not referenced by any line visible in this listing.
// NOTE(review): the statements guarded by the `if` conditions below are not
// visible in this listing, so what happens when one of the checks is true
// cannot be stated from here.
193 GraphicsPipeline::GraphicsPipeline(const VkGraphicsPipelineCreateInfo* pCreateInfo, void* mem)
// Top-level restrictions: no flags, exactly two shader stages, no tessellation
// or dynamic state, subpass 0 only, and no base pipeline.
195 if((pCreateInfo->flags != 0) ||
196 (pCreateInfo->stageCount != 2) ||
197 (pCreateInfo->pTessellationState != nullptr) ||
198 (pCreateInfo->pDynamicState != nullptr) ||
199 (pCreateInfo->subpass != 0) ||
200 (pCreateInfo->basePipelineHandle != VK_NULL_HANDLE) ||
201 (pCreateInfo->basePipelineIndex != 0))
// Stage 0 is expected to be the vertex stage, with no flags and either no
// specialization info or an empty one (no map entries, no data).
206 const VkPipelineShaderStageCreateInfo& vertexStage = pCreateInfo->pStages[0];
207 if((vertexStage.stage != VK_SHADER_STAGE_VERTEX_BIT) ||
208 (vertexStage.flags != 0) ||
209 !((vertexStage.pSpecializationInfo == nullptr) ||
210 ((vertexStage.pSpecializationInfo->mapEntryCount == 0) &&
211 (vertexStage.pSpecializationInfo->dataSize == 0))))
// Stage 1 is expected to be the fragment stage, under the same restrictions.
216 const VkPipelineShaderStageCreateInfo& fragmentStage = pCreateInfo->pStages[1];
217 if((fragmentStage.stage != VK_SHADER_STAGE_FRAGMENT_BIT) ||
218 (fragmentStage.flags != 0) ||
219 !((fragmentStage.pSpecializationInfo == nullptr) ||
220 ((fragmentStage.pSpecializationInfo->mapEntryCount == 0) &&
221 (fragmentStage.pSpecializationInfo->dataSize == 0))))
// Vertex input state: flags are expected to be zero.
226 const VkPipelineVertexInputStateCreateInfo* vertexInputState = pCreateInfo->pVertexInputState;
227 if(vertexInputState->flags != 0)
// Copy the stride of each vertex buffer binding into the input stream selected
// by its binding number, and check for an input rate other than per-vertex.
// NOTE(review): pVertexBindingDescriptions is dereferenced without the loop
// index `i`, so every iteration reads element 0. This looks like it should be
// indexed by `i` — confirm and fix in the full source.
232 for(uint32_t i = 0; i < vertexInputState->vertexBindingDescriptionCount; i++)
234 const VkVertexInputBindingDescription* vertexBindingDescription = vertexInputState->pVertexBindingDescriptions;
235 context.input[vertexBindingDescription->binding].stride = vertexBindingDescription->stride;
236 if(vertexBindingDescription->inputRate != VK_VERTEX_INPUT_RATE_VERTEX)
// Describe each vertex attribute on its input stream: component count, stream
// type, and whether the data is normalized (everything except non-normalized
// integer formats). The stream is selected by the attribute's `binding`.
// NOTE(review): same pattern as above — pVertexAttributeDescriptions is not
// indexed by `i`, so only element 0 is ever read.
242 for(uint32_t i = 0; i < vertexInputState->vertexAttributeDescriptionCount; i++)
244 const VkVertexInputAttributeDescription* vertexAttributeDescriptions = vertexInputState->pVertexAttributeDescriptions;
245 sw::Stream& input = context.input[vertexAttributeDescriptions->binding];
246 input.count = getNumberOfChannels(vertexAttributeDescriptions->format);
247 input.type = getStreamType(vertexAttributeDescriptions->format);
248 input.normalized = !sw::Surface::isNonNormalizedInteger(vertexAttributeDescriptions->format);
// Checks for an attribute whose location differs from its binding, and for a
// non-zero attribute offset.
250 if(vertexAttributeDescriptions->location != vertexAttributeDescriptions->binding)
254 if(vertexAttributeDescriptions->offset != 0)
// Input assembly: flags and primitive restart are expected to be zero; the
// topology is converted to the renderer's draw type.
260 const VkPipelineInputAssemblyStateCreateInfo* assemblyState = pCreateInfo->pInputAssemblyState;
261 if((assemblyState->flags != 0) ||
262 (assemblyState->primitiveRestartEnable != 0))
267 context.drawType = Convert(assemblyState->topology);
// Viewport state: exactly one viewport and one scissor rectangle are expected;
// the first of each is stored in the pipeline.
269 const VkPipelineViewportStateCreateInfo* viewportState = pCreateInfo->pViewportState;
272 if((viewportState->flags != 0) ||
273 (viewportState->viewportCount != 1) ||
274 (viewportState->scissorCount != 1))
279 scissor = Convert(viewportState->pScissors[0]);
280 viewport = viewportState->pViewports[0];
// Rasterization state: checks for flags, depth clamping and any polygon mode
// other than FILL.
283 const VkPipelineRasterizationStateCreateInfo* rasterizationState = pCreateInfo->pRasterizationState;
284 if((rasterizationState->flags != 0) ||
285 (rasterizationState->depthClampEnable != 0) ||
286 (rasterizationState->polygonMode != VK_POLYGON_MODE_FILL))
291 context.rasterizerDiscard = rasterizationState->rasterizerDiscardEnable;
292 context.cullMode = rasterizationState->cullMode;
293 context.frontFacingCCW = rasterizationState->frontFace == VK_FRONT_FACE_COUNTER_CLOCKWISE;
// Both depth bias terms are forced to zero when depth bias is disabled.
294 context.depthBias = (rasterizationState->depthBiasEnable ? rasterizationState->depthBiasConstantFactor : 0.0f);
295 context.slopeDepthBias = (rasterizationState->depthBiasEnable ? rasterizationState->depthBiasSlopeFactor : 0.0f);
// Multisample state: checks for anything other than single-sample rendering
// with no sample shading, a sample mask that is absent or all ones in its
// first word, and no alpha-to-coverage or alpha-to-one.
297 const VkPipelineMultisampleStateCreateInfo* multisampleState = pCreateInfo->pMultisampleState;
300 if((multisampleState->flags != 0) ||
301 (multisampleState->rasterizationSamples != VK_SAMPLE_COUNT_1_BIT) ||
302 (multisampleState->sampleShadingEnable != 0) ||
303 !((multisampleState->pSampleMask == nullptr) ||
304 (*(multisampleState->pSampleMask) == 0xFFFFFFFFu)) ||
305 (multisampleState->alphaToCoverageEnable != 0) ||
306 (multisampleState->alphaToOneEnable != 0))
// Depth/stencil state is optional and only read when the pointer is non-null.
312 const VkPipelineDepthStencilStateCreateInfo* depthStencilState = pCreateInfo->pDepthStencilState;
313 if(depthStencilState)
// Checks for flags, an enabled depth bounds test, or depth bounds other than [0, 1].
315 if((depthStencilState->flags != 0) ||
316 (depthStencilState->depthBoundsTestEnable != 0) ||
317 (depthStencilState->minDepthBounds != 0.0f) ||
318 (depthStencilState->maxDepthBounds != 1.0f))
323 context.depthBufferEnable = depthStencilState->depthTestEnable;
324 context.depthWriteEnable = depthStencilState->depthWriteEnable;
325 context.depthCompareMode = depthStencilState->depthCompareOp;
// Two-sided stencil is switched on together with the stencil test.
327 context.stencilEnable = context.twoSidedStencil = depthStencilState->stencilTestEnable;
328 if(context.stencilEnable)
// Front-face stencil state goes into the un-suffixed context fields.
330 context.stencilMask = depthStencilState->front.compareMask;
331 context.stencilCompareMode = depthStencilState->front.compareOp;
332 context.stencilZFailOperation = depthStencilState->front.depthFailOp;
333 context.stencilFailOperation = depthStencilState->front.failOp;
334 context.stencilPassOperation = depthStencilState->front.passOp;
335 context.stencilReference = depthStencilState->front.reference;
336 context.stencilWriteMask = depthStencilState->front.writeMask;
// Back-face stencil state goes into the *CCW context fields.
338 context.stencilMaskCCW = depthStencilState->back.compareMask;
339 context.stencilCompareModeCCW = depthStencilState->back.compareOp;
340 context.stencilZFailOperationCCW = depthStencilState->back.depthFailOp;
341 context.stencilFailOperationCCW = depthStencilState->back.failOp;
342 context.stencilPassOperationCCW = depthStencilState->back.passOp;
343 context.stencilReferenceCCW = depthStencilState->back.reference;
344 context.stencilWriteMaskCCW = depthStencilState->back.writeMask;
// Color blend state: checks for flags, or a logic op combined with more than
// one attachment.
348 const VkPipelineColorBlendStateCreateInfo* colorBlendState = pCreateInfo->pColorBlendState;
351 if((colorBlendState->flags != 0) ||
352 ((colorBlendState->logicOpEnable != 0) &&
353 (colorBlendState->attachmentCount > 1)))
358 context.colorLogicOpEnabled = colorBlendState->logicOpEnable;
359 context.logicalOperation = colorBlendState->logicOp;
// The four blend constants are stored as the r, g, b, a members in that order.
360 blendConstants.r = colorBlendState->blendConstants[0];
361 blendConstants.g = colorBlendState->blendConstants[1];
362 blendConstants.b = colorBlendState->blendConstants[2];
363 blendConstants.a = colorBlendState->blendConstants[3];
// Per-attachment blend state is only read when there is exactly one attachment.
365 if(colorBlendState->attachmentCount == 1)
367 const VkPipelineColorBlendAttachmentState& attachment = colorBlendState->pAttachments[0];
// Checks for a color write mask that does not enable all four components.
368 if(attachment.colorWriteMask != (VK_COLOR_COMPONENT_R_BIT | VK_COLOR_COMPONENT_G_BIT | VK_COLOR_COMPONENT_B_BIT | VK_COLOR_COMPONENT_A_BIT))
373 context.alphaBlendEnable = attachment.blendEnable;
// Separate alpha blending is flagged whenever the alpha op or either alpha
// factor differs from its color counterpart.
374 context.separateAlphaBlendEnable = (attachment.alphaBlendOp != attachment.colorBlendOp) ||
375 (attachment.dstAlphaBlendFactor != attachment.dstColorBlendFactor) ||
376 (attachment.srcAlphaBlendFactor != attachment.srcColorBlendFactor);
377 context.blendOperationStateAlpha = attachment.alphaBlendOp;
378 context.blendOperationState = attachment.colorBlendOp;
379 context.destBlendFactorStateAlpha = attachment.dstAlphaBlendFactor;
380 context.destBlendFactorState = attachment.dstColorBlendFactor;
381 context.sourceBlendFactorStateAlpha = attachment.srcAlphaBlendFactor;
382 context.sourceBlendFactorState = attachment.srcColorBlendFactor;
// Releases shader objects owned by the pipeline. The fragment shader deleted
// here is the sw::SpirvShader allocated with `new` in compileShaders.
// `pAllocator` is not referenced by any line visible in this listing.
// NOTE(review): compileShaders also allocates `vertexShader`; its release is
// not visible in this listing — confirm it is deleted as well.
387 void GraphicsPipeline::destroyPipeline(const VkAllocationCallbacks* pAllocator)
390 delete fragmentShader;
// Presumably reports the number of bytes of extra storage a GraphicsPipeline
// created from pCreateInfo needs — the body is not visible in this listing,
// confirm against the full source.
393 size_t GraphicsPipeline::ComputeRequiredAllocationSize(const VkGraphicsPipelineCreateInfo* pCreateInfo)
// Builds an sw::SpirvShader for every stage listed in pCreateInfo and stores
// it in the matching shader member and in `context`.
// Each module's SPIR-V is first run through the spirv-tools optimizer so that
// functions are inlined and specialization constants become plain constants.
// `pAllocator` is not referenced by any line visible in this listing.
398 void GraphicsPipeline::compileShaders(const VkAllocationCallbacks* pAllocator, const VkGraphicsPipelineCreateInfo* pCreateInfo)
400 for (auto pStage = pCreateInfo->pStages; pStage != pCreateInfo->pStages + pCreateInfo->stageCount; pStage++)
// Fetch the stage's module code and set up an optimizer for the
// SPV_ENV_VULKAN_1_1 environment, starting with the exhaustive inlining pass.
402 auto module = Cast(pStage->module);
404 auto code = module->getCode();
405 spvtools::Optimizer opt{SPV_ENV_VULKAN_1_1};
406 opt.RegisterPass(spvtools::CreateInlineExhaustivePass());
408 // If the pipeline uses specialization, apply the specializations before freezing
409 if (pStage->pSpecializationInfo)
// Maps each specialization constant ID to the 32-bit words of its new default value.
411 std::unordered_map<uint32_t, std::vector<uint32_t>> specializations;
412 for (auto i = 0u; i < pStage->pSpecializationInfo->mapEntryCount; ++i)
414 auto const &e = pStage->pSpecializationInfo->pMapEntries[i];
// NOTE(review): the next line is the tail of a statement whose beginning is
// not visible in this listing; the `value_ptr` used below is presumably
// declared there. pData is read as an array of 32-bit words, and both e.offset
// and e.size are converted to word counts with integer division, so values
// that are not multiples of sizeof(uint32_t) are truncated.
416 static_cast<uint32_t const *>(pStage->pSpecializationInfo->pData) + e.offset / sizeof(uint32_t);
417 specializations.emplace(e.constantID,
418 std::vector<uint32_t>{value_ptr, value_ptr + e.size / sizeof(uint32_t)});
420 opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass(specializations));
422 // Freeze specialization constants into normal constants, and propagate through
423 opt.RegisterPass(spvtools::CreateFreezeSpecConstantValuePass());
424 opt.RegisterPass(spvtools::CreateFoldSpecConstantOpAndCompositePass());
// Run the registered passes; the optimized words are written to postOptCode.
// NOTE(review): the value returned by opt.Run is discarded, so a failed
// optimizer run is not detected here.
426 std::vector<uint32_t> postOptCode;
427 opt.Run(code.data(), code.size(), &postOptCode);
429 // TODO: also pass in any pipeline state which will affect shader compilation
430 auto spirvShader = new sw::SpirvShader{postOptCode};
// Hand the new shader to the member (and context slot) that matches the stage.
432 switch (pStage->stage)
434 case VK_SHADER_STAGE_VERTEX_BIT:
435 context.vertexShader = vertexShader = spirvShader;
438 case VK_SHADER_STAGE_FRAGMENT_BIT:
439 context.pixelShader = fragmentShader = spirvShader;
// NOTE(review): on this path spirvShader is not stored in any member by the
// lines visible here — check that the allocation is not leaked.
443 UNIMPLEMENTED("Unsupported stage");
// Derives how many primitives vertexCount vertices produce for the pipeline's
// draw type (context.drawType).
// NOTE(review): vertexCount is unsigned, so the strip and fan cases wrap
// around to a very large value when vertexCount is smaller than the constant
// being subtracted — confirm callers never pass such counts.
448 uint32_t GraphicsPipeline::computePrimitiveCount(uint32_t vertexCount) const
450 switch(context.drawType)
// The statement for the point-list case is not visible in this listing.
452 case sw::DRAW_POINTLIST:
// Line list: one line per pair of vertices; integer division drops a leftover vertex.
454 case sw::DRAW_LINELIST:
455 return vertexCount / 2;
// Line strip: each vertex after the first adds one line.
456 case sw::DRAW_LINESTRIP:
457 return vertexCount - 1;
// Triangle list: one triangle per three vertices; integer division drops leftovers.
458 case sw::DRAW_TRIANGLELIST:
459 return vertexCount / 3;
// Triangle strip and fan: each vertex after the first two adds one triangle.
460 case sw::DRAW_TRIANGLESTRIP:
461 return vertexCount - 2;
462 case sw::DRAW_TRIANGLEFAN:
463 return vertexCount - 2;
// Read-only accessor returning a reference to an sw::Context. The body is not
// visible in this listing; presumably it returns the `context` member filled
// in by the constructor — confirm against the full source.
471 const sw::Context& GraphicsPipeline::getContext() const
// Read-only accessor returning a reference to an sw::Rect. The body is not
// visible in this listing; presumably it returns the `scissor` member set by
// the constructor — confirm against the full source.
476 const sw::Rect& GraphicsPipeline::getScissor() const
// Read-only accessor returning a reference to a VkViewport. The body is not
// visible in this listing; presumably it returns the `viewport` member set by
// the constructor — confirm against the full source.
481 const VkViewport& GraphicsPipeline::getViewport() const
// Read-only accessor for the blend constants that the constructor copies from
// the color blend state (r, g, b, a).
486 const sw::Color<float>& GraphicsPipeline::getBlendConstants() const
488 return blendConstants;
// Constructor of the compute pipeline object, taking the Vulkan create info
// and a `mem` pointer. The body is not visible in this listing, so what (if
// anything) it reads from pCreateInfo or mem cannot be stated from here.
491 ComputePipeline::ComputePipeline(const VkComputePipelineCreateInfo* pCreateInfo, void* mem)
// Compute-pipeline counterpart of GraphicsPipeline::destroyPipeline. The body
// is not visible in this listing, so which resources it releases (if any)
// cannot be stated from here.
495 void ComputePipeline::destroyPipeline(const VkAllocationCallbacks* pAllocator)
499 size_t ComputePipeline::ComputeRequiredAllocationSize(const VkComputePipelineCreateInfo* pCreateInfo)