// limitations under the License.
#include <Pipeline/SpirvShader.hpp>
+#include <spirv-tools/optimizer.hpp>
#include "VkPipeline.hpp"
#include "VkShaderModule.hpp"
void GraphicsPipeline::compileShaders(const VkAllocationCallbacks* pAllocator, const VkGraphicsPipelineCreateInfo* pCreateInfo)
{
- for (auto pStage = pCreateInfo->pStages; pStage != pCreateInfo->pStages + pCreateInfo->stageCount; pStage++) {
+ for (auto pStage = pCreateInfo->pStages; pStage != pCreateInfo->pStages + pCreateInfo->stageCount; pStage++)
+ {
auto module = Cast(pStage->module);
- // TODO: apply prep passes using SPIRV-Opt here.
- // - Apply and freeze specializations, etc.
auto code = module->getCode();
+ spvtools::Optimizer opt{SPV_ENV_VULKAN_1_1};
+ opt.RegisterPass(spvtools::CreateInlineExhaustivePass());
+
+ // If the pipeline uses specialization, apply the specializations before freezing
+ if (pStage->pSpecializationInfo)
+ {
+ std::unordered_map<uint32_t, std::vector<uint32_t>> specializations;
+ for (auto i = 0u; i < pStage->pSpecializationInfo->mapEntryCount; ++i)
+ {
+ auto const &e = pStage->pSpecializationInfo->pMapEntries[i];
+ auto value_ptr =
+ static_cast<uint32_t const *>(pStage->pSpecializationInfo->pData) + e.offset / sizeof(uint32_t);
+ specializations.emplace(e.constantID,
+ std::vector<uint32_t>{value_ptr, value_ptr + e.size / sizeof(uint32_t)});
+ }
+ opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass(specializations));
+ }
+ // Freeze specialization constants into normal constants, and propagate through
+ opt.RegisterPass(spvtools::CreateFreezeSpecConstantValuePass());
+ opt.RegisterPass(spvtools::CreateFoldSpecConstantOpAndCompositePass());
+
+ std::vector<uint32_t> postOptCode;
+ opt.Run(code.data(), code.size(), &postOptCode);
- // TODO: pass in additional information here:
- // - any NOS from pCreateInfo which we'll actually need
- auto spirvShader = new sw::SpirvShader{code};
+ // TODO: also pass in any pipeline state which will affect shader compilation
+ auto spirvShader = new sw::SpirvShader{postOptCode};
- switch (pStage->stage) {
- case VK_SHADER_STAGE_VERTEX_BIT:
- context.vertexShader = vertexShader = spirvShader;
- break;
+ switch (pStage->stage)
+ {
+ case VK_SHADER_STAGE_VERTEX_BIT:
+ context.vertexShader = vertexShader = spirvShader;
+ break;
- case VK_SHADER_STAGE_FRAGMENT_BIT:
- context.pixelShader = fragmentShader = spirvShader;
- break;
+ case VK_SHADER_STAGE_FRAGMENT_BIT:
+ context.pixelShader = fragmentShader = spirvShader;
+ break;
- default:
- UNIMPLEMENTED("Unsupported stage");
+ default:
+ UNIMPLEMENTED("Unsupported stage");
}
}
}