1 // Copyright 2018 The SwiftShader Authors. All Rights Reserved.
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
15 #include <spirv/unified1/spirv.hpp>
16 #include <spirv/unified1/GLSL.std.450.h>
17 #include "SpirvShader.hpp"
18 #include "System/Math.hpp"
19 #include "Vulkan/VkBuffer.hpp"
20 #include "Vulkan/VkDebug.hpp"
21 #include "Vulkan/VkPipelineLayout.hpp"
22 #include "Device/Config.hpp"
25 #undef Bool // b/127920555
// Returns true if any lane of the SIMD boolean vector is "true".
// Lanes follow the all-bits-set (~0) / zero convention, so the per-lane
// sign bit is sufficient: SignMask collapses the four sign bits into an
// int, which is non-zero iff at least one lane is set.
// NOTE(review): enclosing braces are elided in this chunk of the source.
30 rr::RValue<rr::Bool> AnyTrue(rr::RValue<sw::SIMD::Int> const &ints)
32 return rr::SignMask(ints) != 0;
// Returns true if any lane of the SIMD boolean vector is "false".
// Complementing the lanes first turns "any false" into "any sign bit set"
// on the inverted value — the dual of AnyTrue above.
// NOTE(review): enclosing braces are elided in this chunk of the source.
35 rr::RValue<rr::Bool> AnyFalse(rr::RValue<sw::SIMD::Int> const &ints)
37 return rr::SignMask(~ints) != 0;
// Monotonically increasing ID source for SpirvShader instances.
// NOTE(review): 'volatile' does not make the increment in the constructor
// atomic — if shaders can be constructed concurrently this needs a real
// atomic; confirm against the callers' threading model.
43 volatile int SpirvShader::serialCounter = 1; // Start at 1, 0 is invalid shader.
// Constructor: single analysis pass over the SPIR-V instruction stream.
// Records execution modes, (member) decorations, types, variables,
// constants and basic-block boundaries so that the later emit passes can
// generate Reactor code. No code is emitted here.
// NOTE(review): several interior lines (opening braces, 'break;'
// statements and a few case labels) are elided in this chunk; the visible
// code is kept byte-identical below.
45 SpirvShader::SpirvShader(InsnStore const &insns)
46 : insns{insns}, inputs{MAX_INTERFACE_COMPONENTS},
47 outputs{MAX_INTERFACE_COMPONENTS},
48 serialID{serialCounter++}, modes{}
// An empty module is malformed; every valid shader has instructions.
50 ASSERT(insns.size() > 0);
52 // Simplifying assumptions (to be satisfied by earlier transformations)
53 // - There is exactly one entrypoint in the module, and it's the one we want
54 // - The only input/output OpVariables present are those used by the entrypoint
// Tracks the block currently being scanned; an ID of 0 means "not inside
// a block".
56 Block::ID currentBlock;
57 InsnIterator blockStart;
59 for (auto insn : *this)
61 switch (insn.opcode())
63 case spv::OpExecutionMode:
64 ProcessExecutionMode(insn);
// OpDecorate (case label elided above): record the decoration against
// its target id; the optional literal operand defaults to 0.
69 TypeOrObjectID targetId = insn.word(1);
70 auto decoration = static_cast<spv::Decoration>(insn.word(2));
71 decorations[targetId].Apply(
73 insn.wordCount() > 3 ? insn.word(3) : 0);
75 if (decoration == spv::DecorationCentroid)
76 modes.NeedsCentroid = true;
80 case spv::OpMemberDecorate:
82 Type::ID targetId = insn.word(1);
83 auto memberIndex = insn.word(2);
84 auto &d = memberDecorations[targetId];
85 if (memberIndex >= d.size())
86 d.resize(memberIndex + 1); // on demand; exact size would require another pass...
87 auto decoration = static_cast<spv::Decoration>(insn.word(3));
90 insn.wordCount() > 4 ? insn.word(4) : 0);
92 if (decoration == spv::DecorationCentroid)
93 modes.NeedsCentroid = true;
97 case spv::OpDecorationGroup:
98 // Nothing to do here. We don't need to record the definition of the group; we'll just have
99 // the bundle of decorations float around. If we were to ever walk the decorations directly,
100 // we might think about introducing this as a real Object.
103 case spv::OpGroupDecorate:
105 auto const &srcDecorations = decorations[insn.word(1)];
106 for (auto i = 2u; i < insn.wordCount(); i++)
108 // remaining operands are targets to apply the group to.
109 decorations[insn.word(i)].Apply(srcDecorations);
114 case spv::OpGroupMemberDecorate:
116 auto const &srcDecorations = decorations[insn.word(1)];
117 for (auto i = 2u; i < insn.wordCount(); i += 2)
119 // remaining operands are pairs of <id>, literal for members to apply to.
120 auto &d = memberDecorations[insn.word(i)];
121 auto memberIndex = insn.word(i + 1);
122 if (memberIndex >= d.size())
123 d.resize(memberIndex + 1); // on demand resize, see above...
124 d[memberIndex].Apply(srcDecorations);
// OpLabel (case label elided above): opens a new basic block; nesting
// blocks is invalid, hence the assert.
131 ASSERT(currentBlock.value() == 0);
132 currentBlock = Block::ID(insn.word(1));
137 // Branch Instructions (subset of Termination Instructions):
139 case spv::OpBranchConditional:
144 // Termination instruction:
146 case spv::OpUnreachable:
// A terminator closes the current block: record its [start, end) range
// (blockEnd is one past the terminator) and reset the tracking state.
148 ASSERT(currentBlock.value() != 0);
149 auto blockEnd = insn; blockEnd++;
150 blocks[currentBlock] = Block(blockStart, blockEnd);
151 currentBlock = Block::ID(0);
153 if (insn.opcode() == spv::OpKill)
155 modes.ContainsKill = true;
160 case spv::OpLoopMerge:
161 case spv::OpSelectionMerge:
162 break; // Nothing to do in analysis pass.
164 case spv::OpTypeVoid:
165 case spv::OpTypeBool:
167 case spv::OpTypeFloat:
168 case spv::OpTypeVector:
169 case spv::OpTypeMatrix:
170 case spv::OpTypeImage:
171 case spv::OpTypeSampler:
172 case spv::OpTypeSampledImage:
173 case spv::OpTypeArray:
174 case spv::OpTypeRuntimeArray:
175 case spv::OpTypeStruct:
176 case spv::OpTypePointer:
177 case spv::OpTypeFunction:
// (presumably DeclareType(insn); — the call itself is elided in this chunk)
181 case spv::OpVariable:
183 Type::ID typeId = insn.word(1);
184 Object::ID resultId = insn.word(2);
185 auto storageClass = static_cast<spv::StorageClass>(insn.word(3));
186 if (insn.wordCount() > 4)
187 UNIMPLEMENTED("Variable initializers not yet supported");
189 auto &object = defs[resultId];
190 object.kind = Object::Kind::Variable;
191 object.definition = insn;
192 object.type = typeId;
193 object.pointerBase = insn.word(2); // base is itself
195 ASSERT(getType(typeId).storageClass == storageClass);
197 switch (storageClass)
199 case spv::StorageClassInput:
200 case spv::StorageClassOutput:
201 ProcessInterfaceVariable(object);
203 case spv::StorageClassUniform:
204 case spv::StorageClassStorageBuffer:
205 case spv::StorageClassPushConstant:
// Descriptor/push-constant backed storage lives in external memory,
// so the object is a physical pointer rather than a register value.
206 object.kind = Object::Kind::PhysicalPointer;
209 case spv::StorageClassPrivate:
210 case spv::StorageClassFunction:
211 break; // Correctly handled.
213 case spv::StorageClassUniformConstant:
214 case spv::StorageClassWorkgroup:
215 case spv::StorageClassCrossWorkgroup:
216 case spv::StorageClassGeneric:
217 case spv::StorageClassAtomicCounter:
218 case spv::StorageClassImage:
219 UNIMPLEMENTED("StorageClass %d not yet implemented", (int)storageClass);
223 UNREACHABLE("Unexpected StorageClass %d", storageClass); // See Appendix A of the Vulkan spec.
229 case spv::OpConstant:
230 CreateConstant(insn).constantValue[0] = insn.word(3);
232 case spv::OpConstantFalse:
233 CreateConstant(insn).constantValue[0] = 0; // represent boolean false as zero
235 case spv::OpConstantTrue:
236 CreateConstant(insn).constantValue[0] = ~0u; // represent boolean true as all bits set
238 case spv::OpConstantNull:
241 // TODO: consider a real LLVM-level undef. For now, zero is a perfectly good value.
242 // OpConstantNull forms a constant of arbitrary type, all zeros.
243 auto &object = CreateConstant(insn);
244 auto &objectTy = getType(object.type);
245 for (auto i = 0u; i < objectTy.sizeInComponents; i++)
247 object.constantValue[i] = 0;
251 case spv::OpConstantComposite:
253 auto &object = CreateConstant(insn);
// Flatten each constituent's components into this constant's value
// array ('offset' declaration is elided in this chunk).
255 for (auto i = 0u; i < insn.wordCount() - 3; i++)
257 auto &constituent = getObject(insn.word(i + 3));
258 auto &constituentTy = getType(constituent.type);
259 for (auto j = 0u; j < constituentTy.sizeInComponents; j++)
260 object.constantValue[offset++] = constituent.constantValue[j];
263 auto objectId = Object::ID(insn.word(2));
264 auto decorationsIt = decorations.find(objectId);
265 if (decorationsIt != decorations.end() &&
266 decorationsIt->second.BuiltIn == spv::BuiltInWorkgroupSize)
268 // https://www.khronos.org/registry/vulkan/specs/1.1/html/vkspec.html#interfaces-builtin-variables :
269 // Decorating an object with the WorkgroupSize built-in
270 // decoration will make that object contain the dimensions
271 // of a local workgroup. If an object is decorated with the
272 // WorkgroupSize decoration, this must take precedence over
273 // any execution mode set for LocalSize.
274 // The object decorated with WorkgroupSize must be declared
275 // as a three-component vector of 32-bit integers.
276 ASSERT(getType(object.type).sizeInComponents == 3);
277 modes.WorkgroupSizeX = object.constantValue[0];
278 modes.WorkgroupSizeY = object.constantValue[1];
279 modes.WorkgroupSizeZ = object.constantValue[2];
284 case spv::OpCapability:
285 break; // Various capabilities will be declared, but none affect our code generation at this point.
286 case spv::OpMemoryModel:
287 break; // Memory model does not affect our code generation until we decide to do Vulkan Memory Model support.
289 case spv::OpEntryPoint:
291 case spv::OpFunction:
292 ASSERT(mainBlockId.value() == 0); // Multiple functions found
293 // Scan forward to find the function's label.
294 for (auto it = insn; it != end() && mainBlockId.value() == 0; it++)
298 case spv::OpFunction:
299 case spv::OpFunctionParameter:
// (OpLabel case elided): the first label after the function header
// is the entry block of the sole function.
302 mainBlockId = Block::ID(it.word(1));
305 WARN("Unexpected opcode '%s' following OpFunction", OpcodeName(it.opcode()).c_str());
308 ASSERT(mainBlockId.value() != 0); // Function's OpLabel not found
310 case spv::OpFunctionEnd:
311 // Due to preprocessing, the entrypoint and its function provide no value.
313 case spv::OpExtInstImport:
314 // We will only support the GLSL 450 extended instruction set, so no point in tracking the ID we assign it.
315 // Valid shaders will not attempt to import any other instruction sets.
316 if (0 != strcmp("GLSL.std.450", reinterpret_cast<char const *>(insn.wordPointer(2))))
318 UNIMPLEMENTED("Only GLSL extended instruction set is supported");
322 case spv::OpMemberName:
324 case spv::OpSourceContinued:
325 case spv::OpSourceExtension:
328 case spv::OpModuleProcessed:
330 // No semantic impact
333 case spv::OpFunctionParameter:
334 case spv::OpFunctionCall:
335 case spv::OpSpecConstant:
336 case spv::OpSpecConstantComposite:
337 case spv::OpSpecConstantFalse:
338 case spv::OpSpecConstantOp:
339 case spv::OpSpecConstantTrue:
340 // These should have all been removed by preprocessing passes. If we see them here,
341 // our assumptions are wrong and we will probably generate wrong code.
342 UNIMPLEMENTED("%s should have already been lowered.", OpcodeName(insn.opcode()).c_str());
345 case spv::OpFConvert:
346 case spv::OpSConvert:
347 case spv::OpUConvert:
348 UNIMPLEMENTED("No valid uses for Op*Convert until we support multiple bit widths");
352 case spv::OpAccessChain:
353 case spv::OpInBoundsAccessChain:
354 case spv::OpCompositeConstruct:
355 case spv::OpCompositeInsert:
356 case spv::OpCompositeExtract:
357 case spv::OpVectorShuffle:
358 case spv::OpVectorTimesScalar:
359 case spv::OpMatrixTimesScalar:
360 case spv::OpMatrixTimesVector:
361 case spv::OpVectorTimesMatrix:
362 case spv::OpMatrixTimesMatrix:
363 case spv::OpVectorExtractDynamic:
364 case spv::OpVectorInsertDynamic:
365 case spv::OpNot: // Unary ops
368 case spv::OpLogicalNot:
369 case spv::OpIAdd: // Binary ops
380 case spv::OpFOrdEqual:
381 case spv::OpFUnordEqual:
382 case spv::OpFOrdNotEqual:
383 case spv::OpFUnordNotEqual:
384 case spv::OpFOrdLessThan:
385 case spv::OpFUnordLessThan:
386 case spv::OpFOrdGreaterThan:
387 case spv::OpFUnordGreaterThan:
388 case spv::OpFOrdLessThanEqual:
389 case spv::OpFUnordLessThanEqual:
390 case spv::OpFOrdGreaterThanEqual:
391 case spv::OpFUnordGreaterThanEqual:
396 case spv::OpINotEqual:
397 case spv::OpUGreaterThan:
398 case spv::OpSGreaterThan:
399 case spv::OpUGreaterThanEqual:
400 case spv::OpSGreaterThanEqual:
401 case spv::OpULessThan:
402 case spv::OpSLessThan:
403 case spv::OpULessThanEqual:
404 case spv::OpSLessThanEqual:
405 case spv::OpShiftRightLogical:
406 case spv::OpShiftRightArithmetic:
407 case spv::OpShiftLeftLogical:
408 case spv::OpBitwiseOr:
409 case spv::OpBitwiseXor:
410 case spv::OpBitwiseAnd:
411 case spv::OpLogicalOr:
412 case spv::OpLogicalAnd:
413 case spv::OpLogicalEqual:
414 case spv::OpLogicalNotEqual:
415 case spv::OpUMulExtended:
416 case spv::OpSMulExtended:
418 case spv::OpConvertFToU:
419 case spv::OpConvertFToS:
420 case spv::OpConvertSToF:
421 case spv::OpConvertUToF:
430 case spv::OpDPdxCoarse:
432 case spv::OpDPdyCoarse:
434 case spv::OpFwidthCoarse:
435 case spv::OpDPdxFine:
436 case spv::OpDPdyFine:
437 case spv::OpFwidthFine:
438 case spv::OpAtomicLoad:
440 // Instructions that yield an intermediate value
442 Type::ID typeId = insn.word(1);
443 Object::ID resultId = insn.word(2);
444 auto &object = defs[resultId];
445 object.type = typeId;
446 object.kind = Object::Kind::Value;
447 object.definition = insn;
449 if (insn.opcode() == spv::OpAccessChain || insn.opcode() == spv::OpInBoundsAccessChain)
451 // interior ptr has two parts:
452 // - logical base ptr, common across all lanes and known at compile time
// Access chains inherit the pointer base of the object they index into.
454 Object::ID baseId = insn.word(3);
455 object.pointerBase = getObject(baseId).pointerBase;
461 case spv::OpAtomicStore:
462 // Don't need to do anything during analysis pass
466 UNIMPLEMENTED("%s", OpcodeName(insn.opcode()).c_str());
// Recursive depth-first traversal: collects into 'reachable' every block
// reachable from 'id' by following each block's out-edges. The set also
// serves as the visited marker, so cycles (loops) terminate.
473 void SpirvShader::TraverseReachableBlocks(Block::ID id, SpirvShader::Block::Set& reachable)
475 if (reachable.count(id) == 0)
477 reachable.emplace(id);
478 for (auto out : getBlock(id).outs)
480 TraverseReachableBlocks(out, reachable);
// Computes each block's incoming-edge set ('ins') by inverting the
// out-edges of every block reachable from the entry block. Unreachable
// blocks contribute no in-edges, so they are effectively ignored later.
485 void SpirvShader::AssignBlockIns()
487 Block::Set reachable;
488 TraverseReachableBlocks(mainBlockId, reachable);
490 for (auto &it : blocks)
492 auto &blockId = it.first;
493 if (reachable.count(blockId) > 0)
495 for (auto &outId : it.second.outs)
// Every out-edge must point at a block recorded during parsing.
497 auto outIt = blocks.find(outId);
498 ASSERT_MSG(outIt != blocks.end(), "Block %d has a non-existent out %d", blockId.value(), outId.value());
499 auto &out = outIt->second;
500 out.ins.emplace(blockId);
// Registers a type-declaring instruction in the 'types' map: stores its
// definition and flat component size, marks built-in blocks (structs with
// any built-in member), and records element type / storage class for the
// composite and pointer cases.
506 void SpirvShader::DeclareType(InsnIterator insn)
508 Type::ID resultId = insn.word(1);
510 auto &type = types[resultId];
511 type.definition = insn;
512 type.sizeInComponents = ComputeTypeSize(insn);
514 // A structure is a builtin block if it has a builtin
515 // member. All members of such a structure are builtins.
516 switch (insn.opcode())
518 case spv::OpTypeStruct:
520 auto d = memberDecorations.find(resultId);
521 if (d != memberDecorations.end())
523 for (auto &m : d->second)
// (the HasBuiltIn check on 'm' is elided in this chunk)
527 type.isBuiltInBlock = true;
534 case spv::OpTypePointer:
536 Type::ID elementTypeId = insn.word(3);
537 type.element = elementTypeId;
// Pointers to built-in blocks propagate the flag so interface walking
// can detect them via the pointer type.
538 type.isBuiltInBlock = getType(elementTypeId).isBuiltInBlock;
539 type.storageClass = static_cast<spv::StorageClass>(insn.word(2));
542 case spv::OpTypeVector:
543 case spv::OpTypeMatrix:
544 case spv::OpTypeArray:
545 case spv::OpTypeRuntimeArray:
547 Type::ID elementTypeId = insn.word(2);
548 type.element = elementTypeId;
// Registers the object for a constant-producing instruction and allocates
// its constantValue array, sized to the type's flat component count.
// The caller fills in the actual component values.
556 SpirvShader::Object& SpirvShader::CreateConstant(InsnIterator insn)
558 Type::ID typeId = insn.word(1);
559 Object::ID resultId = insn.word(2);
560 auto &object = defs[resultId];
561 auto &objectTy = getType(typeId);
562 object.type = typeId;
563 object.kind = Object::Kind::Constant;
564 object.definition = insn;
565 object.constantValue = std::unique_ptr<uint32_t[]>(new uint32_t[objectTy.sizeInComponents]);
// Registers an Input/Output OpVariable in the shader's interface maps.
// Built-in blocks are decomposed member-by-member into the builtin
// interface; variables decorated directly as built-ins get a single
// entry; everything else is walked as user-defined Location/Component
// slots via VisitInterface.
569 void SpirvShader::ProcessInterfaceVariable(Object &object)
571 auto &objectTy = getType(object.type);
572 ASSERT(objectTy.storageClass == spv::StorageClassInput || objectTy.storageClass == spv::StorageClassOutput);
574 ASSERT(objectTy.opcode() == spv::OpTypePointer);
575 auto pointeeTy = getType(objectTy.element);
// Select input or output tables based on the variable's storage class.
577 auto &builtinInterface = (objectTy.storageClass == spv::StorageClassInput) ? inputBuiltins : outputBuiltins;
578 auto &userDefinedInterface = (objectTy.storageClass == spv::StorageClassInput) ? inputs : outputs;
580 ASSERT(object.opcode() == spv::OpVariable);
581 Object::ID resultId = object.definition.word(2);
583 if (objectTy.isBuiltInBlock)
585 // walk the builtin block, registering each of its members separately.
586 auto m = memberDecorations.find(objectTy.element);
587 ASSERT(m != memberDecorations.end()); // otherwise we wouldn't have marked the type chain
588 auto &structType = pointeeTy.definition;
// ('word' / 'offset' loop counters are initialized on lines elided in
// this chunk; 'offset' accumulates the flat component position.)
591 for (auto &member : m->second)
593 auto &memberType = getType(structType.word(word))
595 if (member.HasBuiltIn)
597 builtinInterface[member.BuiltIn] = {resultId, offset, memberType.sizeInComponents};
600 offset += memberType.sizeInComponents;
606 auto d = decorations.find(resultId);
607 if (d != decorations.end() && d->second.HasBuiltIn)
609 builtinInterface[d->second.BuiltIn] = {resultId, 0, pointeeTy.sizeInComponents};
613 object.kind = Object::Kind::InterfaceVariable;
614 VisitInterface(resultId,
615 [&userDefinedInterface](Decorations const &d, AttribType type) {
616 // Populate a single scalar slot in the interface from a collection of decorations and the intended component type.
617 auto scalarSlot = (d.Location << 2) | d.Component;
618 ASSERT(scalarSlot >= 0 &&
619 scalarSlot < static_cast<int32_t>(userDefinedInterface.size()));
621 auto &slot = userDefinedInterface[scalarSlot];
624 slot.NoPerspective = d.NoPerspective;
625 slot.Centroid = d.Centroid;
// Records an OpExecutionMode instruction into the 'modes' summary used by
// the pipeline (fragment test ordering, depth export behavior, compute
// local size). Unsupported modes hit UNIMPLEMENTED.
630 void SpirvShader::ProcessExecutionMode(InsnIterator insn)
632 auto mode = static_cast<spv::ExecutionMode>(insn.word(2));
635 case spv::ExecutionModeEarlyFragmentTests:
636 modes.EarlyFragmentTests = true;
638 case spv::ExecutionModeDepthReplacing:
639 modes.DepthReplacing = true;
641 case spv::ExecutionModeDepthGreater:
642 modes.DepthGreater = true;
644 case spv::ExecutionModeDepthLess:
645 modes.DepthLess = true;
647 case spv::ExecutionModeDepthUnchanged:
648 modes.DepthUnchanged = true;
650 case spv::ExecutionModeLocalSize:
// May later be overridden by a WorkgroupSize built-in constant (see
// the OpConstantComposite handling in the constructor).
651 modes.WorkgroupSizeX = insn.word(3);
652 modes.WorkgroupSizeY = insn.word(4);
653 modes.WorkgroupSizeZ = insn.word(5);
655 case spv::ExecutionModeOriginUpperLeft:
656 // This is always the case for a Vulkan shader. Do nothing.
659 UNIMPLEMENTED("No other execution modes are permitted");
// Returns the flat size of a type in 32-bit components. Composite sizes
// are built from the already-declared element types (SPIR-V declares
// types bottom-up). Opaque/descriptor-backed types report zero.
663 uint32_t SpirvShader::ComputeTypeSize(InsnIterator insn)
665 // Types are always built from the bottom up (with the exception of forward ptrs, which
666 // don't appear in Vulkan shaders. Therefore, we can always assume our component parts have
667 // already been described (and so their sizes determined)
668 switch (insn.opcode())
670 case spv::OpTypeVoid:
671 case spv::OpTypeSampler:
672 case spv::OpTypeImage:
673 case spv::OpTypeSampledImage:
674 case spv::OpTypeFunction:
675 case spv::OpTypeRuntimeArray:
676 // Objects that don't consume any space.
677 // Descriptor-backed objects currently only need exist at compile-time.
678 // Runtime arrays don't appear in places where their size would be interesting
681 case spv::OpTypeBool:
682 case spv::OpTypeFloat:
684 // All the fundamental types are 1 component. If we ever add support for 8/16/64-bit components,
685 // we might need to change this, but only 32 bit components are required for Vulkan 1.1.
688 case spv::OpTypeVector:
689 case spv::OpTypeMatrix:
690 // Vectors and matrices both consume element count * element size.
691 return getType(insn.word(2)).sizeInComponents * insn.word(3);
693 case spv::OpTypeArray:
695 // Element count * element size. Array sizes come from constant ids.
696 auto arraySize = GetConstantInt(insn.word(3));
697 return getType(insn.word(2)).sizeInComponents * arraySize;
700 case spv::OpTypeStruct:
// Struct size is the sum of its members' sizes.
703 for (uint32_t i = 2u; i < insn.wordCount(); i++)
705 size += getType(insn.word(i)).sizeInComponents;
710 case spv::OpTypePointer:
711 // Runtime representation of a pointer is a per-lane index.
712 // Note: clients are expected to look through the pointer if they want the pointee size instead.
716 // Some other random insn.
717 UNIMPLEMENTED("Only types are supported");
// Distinguishes storage whose memory layout follows explicit offsets
// (Uniform / StorageBuffer / PushConstant — shared across lanes) from
// storage laid out per-lane. The return statements for the two branches
// are elided in this chunk; presumably the listed classes return false
// and the default returns true — confirm against the full source.
722 bool SpirvShader::IsStorageInterleavedByLane(spv::StorageClass storageClass)
724 switch (storageClass)
726 case spv::StorageClassUniform:
727 case spv::StorageClassStorageBuffer:
728 case spv::StorageClassPushConstant:
// Recursive worker for VisitInterface (the 'template<typename F>' header
// line is elided in this chunk). Walks the type tree of an interface
// variable, invoking f(decorations, attribType) for each scalar
// component, and returns the next available location.
736 int SpirvShader::VisitInterfaceInner(Type::ID id, Decorations d, F f) const
738 // Recursively walks variable definition and its type tree, taking into account
739 // any explicit Location or Component decorations encountered; where explicit
740 // Locations or Components are not specified, assigns them sequentially.
741 // Collected decorations are carried down toward the leaves and across
742 // siblings; Effect of decorations intentionally does not flow back up the tree.
744 // F is a functor to be called with the effective decoration set for every component.
746 // Returns the next available location, and calls f().
748 // This covers the rules in Vulkan 1.1 spec, 14.1.4 Location Assignment.
750 ApplyDecorationsForId(&d, id);
752 auto const &obj = getType(id);
755 case spv::OpTypePointer:
756 return VisitInterfaceInner<F>(obj.definition.word(3), d, f);
757 case spv::OpTypeMatrix:
758 for (auto i = 0u; i < obj.definition.word(3); i++, d.Location++)
760 // consumes same components of N consecutive locations
761 VisitInterfaceInner<F>(obj.definition.word(2), d, f);
764 case spv::OpTypeVector:
765 for (auto i = 0u; i < obj.definition.word(3); i++, d.Component++)
767 // consumes N consecutive components in the same location
768 VisitInterfaceInner<F>(obj.definition.word(2), d, f);
770 return d.Location + 1;
771 case spv::OpTypeFloat:
772 f(d, ATTRIBTYPE_FLOAT);
773 return d.Location + 1;
// (OpTypeInt case label elided): signedness flag picks INT vs UINT.
775 f(d, obj.definition.word(3) ? ATTRIBTYPE_INT : ATTRIBTYPE_UINT);
776 return d.Location + 1;
777 case spv::OpTypeBool:
778 f(d, ATTRIBTYPE_UINT);
779 return d.Location + 1;
780 case spv::OpTypeStruct:
782 // iterate over members, which may themselves have Location/Component decorations
783 for (auto i = 0u; i < obj.definition.wordCount() - 2; i++)
785 ApplyDecorationsForIdMember(&d, id, i);
786 d.Location = VisitInterfaceInner<F>(obj.definition.word(i + 2), d, f);
787 d.Component = 0; // Implicit locations always have component=0
791 case spv::OpTypeArray:
793 auto arraySize = GetConstantInt(obj.definition.word(3));
794 for (auto i = 0u; i < arraySize; i++)
796 d.Location = VisitInterfaceInner<F>(obj.definition.word(2), d, f);
801 // Intentionally partial; most opcodes do not participate in type hierarchies
// Public entry for interface walking (template header elided in this
// chunk): applies the variable's own decorations, then recurses through
// its pointer type, calling f once per scalar component.
807 void SpirvShader::VisitInterface(Object::ID id, F f) const
809 // Walk a variable definition and call f for each component in it.
811 ApplyDecorationsForId(&d, id);
813 auto def = getObject(id).definition;
814 ASSERT(def.opcode() == spv::OpVariable);
815 VisitInterfaceInner<F>(def.word(1), d, f);
// Resolves an access chain into external (descriptor-backed) memory,
// honoring explicit Offset / ArrayStride / MatrixStride decorations.
// Constant indexes fold into a compile-time offset; dynamic indexes
// accumulate into a per-lane SIMD offset. Result is in sizeof(float)
// units.
818 SIMD::Int SpirvShader::WalkExplicitLayoutAccessChain(Object::ID id, uint32_t numIndexes, uint32_t const *indexIds, SpirvRoutine *routine) const
820 // Produce a offset into external memory in sizeof(float) units
822 int constantOffset = 0;
823 SIMD::Int dynamicOffset = SIMD::Int(0);
824 auto &baseObject = getObject(id);
825 Type::ID typeId = getType(baseObject.type).element;
// ('d' — the running Decorations — is declared on a line elided here.)
827 ApplyDecorationsForId(&d, baseObject.type);
829 // The <base> operand is an intermediate value itself, ie produced by a previous OpAccessChain.
830 // Start with its offset and build from there.
831 if (baseObject.kind == Object::Kind::Value)
833 dynamicOffset += routine->getIntermediate(id).Int(0);
836 for (auto i = 0u; i < numIndexes; i++)
838 auto & type = getType(typeId);
839 switch (type.definition.opcode())
841 case spv::OpTypeStruct:
843 int memberIndex = GetConstantInt(indexIds[i]);
844 ApplyDecorationsForIdMember(&d, typeId, memberIndex);
// Struct members use their explicit Offset decoration directly.
846 constantOffset += d.Offset / sizeof(float);
847 typeId = type.definition.word(2u + memberIndex);
850 case spv::OpTypeArray:
851 case spv::OpTypeRuntimeArray:
853 // TODO: b/127950082: Check bounds.
854 ApplyDecorationsForId(&d, typeId);
855 ASSERT(d.HasArrayStride);
856 auto & obj = getObject(indexIds[i]);
857 if (obj.kind == Object::Kind::Constant)
858 constantOffset += d.ArrayStride/sizeof(float) * GetConstantInt(indexIds[i]);
860 dynamicOffset += SIMD::Int(d.ArrayStride / sizeof(float)) * routine->getIntermediate(indexIds[i]).Int(0);
861 typeId = type.element;
864 case spv::OpTypeMatrix:
866 // TODO: b/127950082: Check bounds.
867 ApplyDecorationsForId(&d, typeId);
868 ASSERT(d.HasMatrixStride);
869 auto & obj = getObject(indexIds[i]);
870 if (obj.kind == Object::Kind::Constant)
871 constantOffset += d.MatrixStride/sizeof(float) * GetConstantInt(indexIds[i]);
873 dynamicOffset += SIMD::Int(d.MatrixStride / sizeof(float)) * routine->getIntermediate(indexIds[i]).Int(0);
874 typeId = type.element;
877 case spv::OpTypeVector:
// Vector components are tightly packed: stride is one float.
879 auto & obj = getObject(indexIds[i]);
880 if (obj.kind == Object::Kind::Constant)
881 constantOffset += GetConstantInt(indexIds[i]);
883 dynamicOffset += routine->getIntermediate(indexIds[i]).Int(0);
884 typeId = type.element;
888 UNIMPLEMENTED("Unexpected type '%s' in WalkExplicitLayoutAccessChain", OpcodeName(type.definition.opcode()).c_str());
892 return dynamicOffset + SIMD::Int(constantOffset);
// Resolves an access chain into location-oriented (register-file style)
// memory, where composite layout is the packed sizeInComponents layout
// rather than explicit decorations. Returns a per-lane component offset:
// compile-time-constant indexes fold into 'constantOffset', dynamic ones
// accumulate into the SIMD 'dynamicOffset'.
895 SIMD::Int SpirvShader::WalkAccessChain(Object::ID id, uint32_t numIndexes, uint32_t const *indexIds, SpirvRoutine *routine) const
897 // TODO: avoid doing per-lane work in some cases if we can?
898 // Produce a *component* offset into location-oriented memory
900 int constantOffset = 0;
901 SIMD::Int dynamicOffset = SIMD::Int(0);
902 auto &baseObject = getObject(id);
903 Type::ID typeId = getType(baseObject.type).element;
905 // The <base> operand is an intermediate value itself, ie produced by a previous OpAccessChain.
906 // Start with its offset and build from there.
907 if (baseObject.kind == Object::Kind::Value)
909 dynamicOffset += routine->getIntermediate(id).Int(0);
912 for (auto i = 0u; i < numIndexes; i++)
914 auto & type = getType(typeId);
915 switch(type.opcode())
917 case spv::OpTypeStruct:
919 int memberIndex = GetConstantInt(indexIds[i]);
// Sum the packed sizes of all preceding members.
920 int offsetIntoStruct = 0;
921 for (auto j = 0; j < memberIndex; j++) {
922 auto memberType = type.definition.word(2u + j);
923 offsetIntoStruct += getType(memberType).sizeInComponents;
925 constantOffset += offsetIntoStruct;
926 typeId = type.definition.word(2u + memberIndex);
930 case spv::OpTypeVector:
931 case spv::OpTypeMatrix:
932 case spv::OpTypeArray:
933 case spv::OpTypeRuntimeArray:
935 // TODO: b/127950082: Check bounds.
936 auto stride = getType(type.element).sizeInComponents;
937 auto & obj = getObject(indexIds[i]);
938 if (obj.kind == Object::Kind::Constant)
939 constantOffset += stride * GetConstantInt(indexIds[i]);
941 dynamicOffset += SIMD::Int(stride) * routine->getIntermediate(indexIds[i]).Int(0);
942 typeId = type.element;
947 UNIMPLEMENTED("Unexpected type '%s' in WalkAccessChain", OpcodeName(type.opcode()).c_str());
951 return dynamicOffset + SIMD::Int(constantOffset);
// Compile-time-only variant of WalkAccessChain: all indexes are literal
// integers (as in OpCompositeExtract/Insert), so the result is a plain
// constant component offset with no per-lane work.
954 uint32_t SpirvShader::WalkLiteralAccessChain(Type::ID typeId, uint32_t numIndexes, uint32_t const *indexes) const
956 uint32_t constantOffset = 0;
958 for (auto i = 0u; i < numIndexes; i++)
960 auto & type = getType(typeId);
961 switch(type.opcode())
963 case spv::OpTypeStruct:
965 int memberIndex = indexes[i];
// Sum the packed sizes of all preceding members.
966 int offsetIntoStruct = 0;
967 for (auto j = 0; j < memberIndex; j++) {
968 auto memberType = type.definition.word(2u + j);
969 offsetIntoStruct += getType(memberType).sizeInComponents;
971 constantOffset += offsetIntoStruct;
972 typeId = type.definition.word(2u + memberIndex);
976 case spv::OpTypeVector:
977 case spv::OpTypeMatrix:
978 case spv::OpTypeArray:
980 auto elementType = type.definition.word(2);
981 auto stride = getType(elementType).sizeInComponents;
982 constantOffset += stride * indexes[i];
983 typeId = elementType;
988 UNIMPLEMENTED("Unexpected type in WalkLiteralAccessChain");
992 return constantOffset;
// Applies one decoration (with its optional literal argument) to this
// decoration set, setting both the value and its matching Has* flag
// where one exists. Unknown decorations are intentionally ignored.
995 void SpirvShader::Decorations::Apply(spv::Decoration decoration, uint32_t arg)
999 case spv::DecorationLocation:
1001 Location = static_cast<int32_t>(arg);
1003 case spv::DecorationComponent:
1004 HasComponent = true;
1007 case spv::DecorationDescriptorSet:
1008 HasDescriptorSet = true;
1009 DescriptorSet = arg;
1011 case spv::DecorationBinding:
1015 case spv::DecorationBuiltIn:
1017 BuiltIn = static_cast<spv::BuiltIn>(arg);
1019 case spv::DecorationFlat:
1022 case spv::DecorationNoPerspective:
1023 NoPerspective = true;
1025 case spv::DecorationCentroid:
1028 case spv::DecorationBlock:
1031 case spv::DecorationBufferBlock:
1034 case spv::DecorationOffset:
1036 Offset = static_cast<int32_t>(arg);
1038 case spv::DecorationArrayStride:
1039 HasArrayStride = true;
1040 ArrayStride = static_cast<int32_t>(arg);
1042 case spv::DecorationMatrixStride:
1043 HasMatrixStride = true;
1044 MatrixStride = static_cast<int32_t>(arg);
1047 // Intentionally partial, there are many decorations we just don't care about.
// Merges another decoration set into this one (used for OpGroupDecorate):
// valued decorations copy over only when present in 'src' (guarded by
// the Has* flags), while pure boolean flags are simply OR-ed in.
1052 void SpirvShader::Decorations::Apply(const sw::SpirvShader::Decorations &src)
1054 // Apply a decoration group to this set of decorations
1058 BuiltIn = src.BuiltIn;
1061 if (src.HasLocation)
1064 Location = src.Location;
1067 if (src.HasComponent)
1069 HasComponent = true;
1070 Component = src.Component;
1073 if (src.HasDescriptorSet)
1075 HasDescriptorSet = true;
1076 DescriptorSet = src.DescriptorSet;
1082 Binding = src.Binding;
1088 Offset = src.Offset;
1091 if (src.HasArrayStride)
1093 HasArrayStride = true;
1094 ArrayStride = src.ArrayStride;
1097 if (src.HasMatrixStride)
1099 HasMatrixStride = true;
1100 MatrixStride = src.MatrixStride;
1104 NoPerspective |= src.NoPerspective;
1105 Centroid |= src.Centroid;
1107 BufferBlock |= src.BufferBlock;
// Merges any decorations recorded for 'id' into *d; a no-op if the id
// has no decorations.
1110 void SpirvShader::ApplyDecorationsForId(Decorations *d, TypeOrObjectID id) const
1112 auto it = decorations.find(id);
1113 if (it != decorations.end())
1114 d->Apply(it->second);
// Merges decorations recorded for member 'member' of struct type 'id'
// into *d. The bounds check covers the on-demand-resized member vector,
// which may be shorter than the struct's member count.
1117 void SpirvShader::ApplyDecorationsForIdMember(Decorations *d, Type::ID id, uint32_t member) const
1119 auto it = memberDecorations.find(id);
1120 if (it != memberDecorations.end() && member < it->second.size())
1122 d->Apply(it->second[member]);
// Returns the scalar value of an integer OpConstant at translation time.
// Only valid for OpConstant of integer type (asserted); used for array
// sizes and literal access-chain indexes.
1126 uint32_t SpirvShader::GetConstantInt(Object::ID id) const
1128 // Slightly hackish access to constants very early in translation.
1129 // General consumption of constants by other instructions should
1130 // probably be just lowered to Reactor.
1132 // TODO: not encountered yet since we only use this for array sizes etc,
1133 // but is possible to construct integer constant 0 via OpConstantNull.
1134 auto insn = getObject(id).definition;
1135 ASSERT(insn.opcode() == spv::OpConstant);
1136 ASSERT(getType(insn.word(1)).opcode() == spv::OpTypeInt);
1137 return insn.word(3);
// Prolog pass: allocates an lvalue in the routine for every OpVariable
// whose pointee occupies at least one component, so later loads/stores
// have storage to target.
1142 void SpirvShader::emitProlog(SpirvRoutine *routine) const
1144 for (auto insn : *this)
1146 switch (insn.opcode())
1148 case spv::OpVariable:
1150 Type::ID resultPointerTypeId = insn.word(1);
1151 auto resultPointerType = getType(resultPointerTypeId);
1152 auto pointeeType = getType(resultPointerType.element);
1154 if(pointeeType.sizeInComponents > 0) // TODO: what to do about zero-slot objects?
1156 Object::ID resultId = insn.word(2);
1157 routine->createLvalue(resultId, pointeeType.sizeInComponents);
1162 // Nothing else produces interface variables, so can all be safely ignored.
// Main emission entry point: emits the module-level instructions that
// precede the first OpLabel, then emits the control-flow graph starting
// at the entry block. ('state' is constructed on a line elided in this
// chunk.)
1168 void SpirvShader::emit(SpirvRoutine *routine, RValue<SIMD::Int> const &activeLaneMask) const
1171 state.setActiveLaneMask(activeLaneMask);
1172 state.routine = routine;
1174 // Emit everything up to the first label
1175 // TODO: Separate out dispatch of block from non-block instructions?
1176 for (auto insn : *this)
1178 if (insn.opcode() == spv::OpLabel)
1182 EmitInstruction(insn, &state);
1185 // Emit all the blocks starting from mainBlockId.
1186 EmitBlocks(mainBlockId, &state);
// Work-list driven block emission: processes a fresh pending queue seeded
// from 'id' (seeding line elided in this chunk), dispatching each block
// by its kind. The previous pending queue is saved and restored so this
// can be re-entered (e.g. while emitting a loop body). 'ignore' lets the
// caller exclude a block (such as a merge block) from this pass.
1189 void SpirvShader::EmitBlocks(Block::ID id, EmitState *state, Block::ID ignore /* = 0 */) const
1191 auto oldPending = state->pending;
1193 std::queue<Block::ID> pending;
1194 state->pending = &pending;
1196 while (pending.size() > 0)
1198 auto id = pending.front();
1201 auto const &block = getBlock(id);
1207 state->currentBlock = id;
1212 case Block::StructuredBranchConditional:
1213 case Block::UnstructuredBranchConditional:
1214 case Block::StructuredSwitch:
1215 case Block::UnstructuredSwitch:
1224 UNREACHABLE("Unexpected Block Kind: %d", int(block.kind));
1228 state->pending = oldPending;
// Emits the instructions of one block in order. A Terminator result ends
// the block (handling elided in this chunk); any other non-Continue
// result is a logic error.
1231 void SpirvShader::EmitInstructions(InsnIterator begin, InsnIterator end, EmitState *state) const
1233 for (auto insn = begin; insn != end; insn++)
1235 auto res = EmitInstruction(insn, state);
1238 case EmitResult::Continue:
1240 case EmitResult::Terminator:
1243 UNREACHABLE("Unexpected EmitResult %d", int(res));
// Emits a non-loop block. Defers itself (re-queues) until all incoming
// blocks have been emitted, ORs the incoming edge masks into this
// block's active lane mask, emits the body, then queues the successors.
1249 void SpirvShader::EmitNonLoop(EmitState *state) const
1251 auto blockId = state->currentBlock;
1252 auto block = getBlock(blockId);
1254 // Ensure all incoming blocks have been generated.
1255 auto depsDone = true;
1256 for (auto in : block.ins)
1258 if (state->visited.count(in) == 0)
1260 state->pending->emplace(in);
// ('depsDone = false;' and the brace closing this loop are elided in
// this chunk.)
1267 // come back to this once the dependencies have been generated
1268 state->pending->emplace(blockId);
1272 if (!state->visited.emplace(blockId).second)
1274 return; // Already generated this block.
1277 if (blockId != mainBlockId)
1279 // Set the activeLaneMask.
1280 Intermediate activeLaneMask(1);
1281 activeLaneMask.move(0, SIMD::Int(0));
1282 for (auto in : block.ins)
// A lane is active here if it was active on any incoming edge.
1284 auto inMask = GetActiveLaneMaskEdge(state, in, blockId);
1285 activeLaneMask.replace(0, activeLaneMask.Int(0) | inMask);
1287 state->setActiveLaneMask(activeLaneMask.Int(0));
1290 EmitInstructions(block.begin(), block.end(), state);
1292 for (auto out : block.outs)
1294 state->pending->emplace(out);
// Emits a loop whose header is state->currentBlock.
// Strategy: lanes that enter the loop keep iterating (masked) until every
// lane has exited; loopActiveLaneMask tracks which lanes are still looping.
// OpPhi values in the header are backed by allocas ('storage') so they can
// be primed before the loop and updated on each back edge.
1298 void SpirvShader::EmitLoop(EmitState *state) const
1300 auto blockId = state->currentBlock;
1301 auto block = getBlock(blockId);
1303 // Ensure all incoming non-back edge blocks have been generated.
1304 auto depsDone = true;
1305 for (auto in : block.ins)
1307 if (state->visited.count(in) == 0)
1309 if (!existsPath(blockId, in, block.mergeBlock)) // if not a loop back edge
1311 state->pending->emplace(in);
1319 // come back to this once the dependencies have been generated
1320 state->pending->emplace(blockId);
1324 if (!state->visited.emplace(blockId).second)
1326 return; // Already emitted this loop.
1329 // loopActiveLaneMask is the mask of lanes that are continuing to loop.
1330 // This is initialized with the incoming active lane masks.
1331 SIMD::Int loopActiveLaneMask = SIMD::Int(0);
1332 for (auto in : block.ins)
1334 if (!existsPath(blockId, in, block.mergeBlock)) // if not a loop back edge
1336 loopActiveLaneMask |= GetActiveLaneMaskEdge(state, in, blockId);
1340 // Generate an alloca for each of the loop's phis.
1341 // These will be primed with the incoming, non back edge Phi values
1342 // before the loop, and then updated just before the loop jumps back to
1346 Object::ID phiId; // The Phi identifier.
1347 Object::ID continueValue; // The source merge value from the loop.
1348 Array<SIMD::Int> storage; // The alloca.
1351 std::vector<LoopPhi> phis;
1353 // For each OpPhi between the block start and the merge instruction:
1354 for (auto insn = block.begin(); insn != block.mergeInstruction; insn++)
1356 if (insn.opcode() == spv::OpPhi)
1358 auto objectId = Object::ID(insn.word(2));
1359 auto &object = getObject(objectId);
1360 auto &type = getType(object.type);
1363 phi.phiId = Object::ID(insn.word(2));
1364 phi.storage = Array<SIMD::Int>(type.sizeInComponents);
1366 // Start with the Phi set to 0.
1367 for (uint32_t i = 0; i < type.sizeInComponents; i++)
1369 phi.storage[i] = SIMD::Int(0);
1372 // For each Phi source:
// Phi operands come in (value id, predecessor block id) pairs
// starting at word 3.
1373 for (uint32_t w = 3; w < insn.wordCount(); w += 2)
1375 auto varId = Object::ID(insn.word(w + 0));
1376 auto blockId = Block::ID(insn.word(w + 1));
1377 if (existsPath(state->currentBlock, blockId, block.mergeBlock))
1379 // This source is from a loop back-edge.
1380 ASSERT(phi.continueValue == 0 || phi.continueValue == varId);
1381 phi.continueValue = varId;
1385 // This source is from a preceding block.
// Merge the predecessor's value in, masked by the lanes that
// actually arrived via that edge.
1386 for (uint32_t i = 0; i < type.sizeInComponents; i++)
1388 auto in = GenericValue(this, state->routine, varId);
1389 auto mask = GetActiveLaneMaskEdge(state, blockId, state->currentBlock);
1390 phi.storage[i] = phi.storage[i] | (in.Int(i) & mask);
1395 phis.push_back(phi);
1399 // Create the loop basic blocks
1400 auto headerBasicBlock = Nucleus::createBasicBlock();
1401 auto mergeBasicBlock = Nucleus::createBasicBlock();
1403 // Start emitting code inside the loop.
1404 Nucleus::createBr(headerBasicBlock);
1405 Nucleus::setInsertBlock(headerBasicBlock);
1407 // Load the Phi values from storage.
1408 // This will load at the start of each loop.
1409 for (auto &phi : phis)
1411 auto &type = getType(getObject(phi.phiId).type);
1412 auto &dst = state->routine->createIntermediate(phi.phiId, type.sizeInComponents);
1413 for (unsigned int i = 0u; i < type.sizeInComponents; i++)
1415 dst.move(i, phi.storage[i]);
1419 // Load the active lane mask.
1420 state->setActiveLaneMask(loopActiveLaneMask);
1422 // Emit all the non-phi instructions in this loop header block.
1423 for (auto insn = block.begin(); insn != block.end(); insn++)
1425 if (insn.opcode() != spv::OpPhi)
1427 EmitInstruction(insn, state);
1431 // Emit all loop blocks, but don't emit the merge block yet.
// Only successors that can reach the loop header (without passing
// through the merge block) are part of the loop body.
1432 for (auto out : block.outs)
1434 if (existsPath(out, blockId, block.mergeBlock))
1436 EmitBlocks(out, state, block.mergeBlock);
1440 // Rebuild the loopActiveLaneMask from the loop back edges.
1441 loopActiveLaneMask = SIMD::Int(0);
1442 for (auto in : block.ins)
1444 if (existsPath(blockId, in, block.mergeBlock))
1446 loopActiveLaneMask |= GetActiveLaneMaskEdge(state, in, blockId);
1450 // Update loop phi values
1451 for (auto &phi : phis)
1453 if (phi.continueValue != 0)
1455 auto val = GenericValue(this, state->routine, phi.continueValue);
1456 auto &type = getType(getObject(phi.phiId).type);
1457 for (unsigned int i = 0u; i < type.sizeInComponents; i++)
1459 phi.storage[i] = val.Int(i);
1464 // Loop body now done.
1465 // If any lanes are still active, jump back to the loop header,
1466 // otherwise jump to the merge block.
1467 Nucleus::createCondBr(AnyTrue(loopActiveLaneMask).value, headerBasicBlock, mergeBasicBlock);
1469 // Continue emitting from the merge block.
1470 Nucleus::setInsertBlock(mergeBasicBlock);
1471 state->pending->emplace(block.mergeBlock);
// Central dispatch: emits code for a single SPIR-V instruction and returns
// whether emission of the current block should continue or has hit a
// terminator. Unhandled opcodes trip UNIMPLEMENTED.
1474 SpirvShader::EmitResult SpirvShader::EmitInstruction(InsnIterator insn, EmitState *state) const
1476 switch (insn.opcode())
// Types, constants, decorations, debug info and module metadata are all
// consumed during analysis; nothing to emit for them here.
1478 case spv::OpTypeVoid:
1479 case spv::OpTypeInt:
1480 case spv::OpTypeFloat:
1481 case spv::OpTypeBool:
1482 case spv::OpTypeVector:
1483 case spv::OpTypeArray:
1484 case spv::OpTypeRuntimeArray:
1485 case spv::OpTypeMatrix:
1486 case spv::OpTypeStruct:
1487 case spv::OpTypePointer:
1488 case spv::OpTypeFunction:
1489 case spv::OpExecutionMode:
1490 case spv::OpMemoryModel:
1491 case spv::OpFunction:
1492 case spv::OpFunctionEnd:
1493 case spv::OpConstant:
1494 case spv::OpConstantNull:
1495 case spv::OpConstantTrue:
1496 case spv::OpConstantFalse:
1497 case spv::OpConstantComposite:
1499 case spv::OpExtension:
1500 case spv::OpCapability:
1501 case spv::OpEntryPoint:
1502 case spv::OpExtInstImport:
1503 case spv::OpDecorate:
1504 case spv::OpMemberDecorate:
1505 case spv::OpGroupDecorate:
1506 case spv::OpGroupMemberDecorate:
1507 case spv::OpDecorationGroup:
1509 case spv::OpMemberName:
1511 case spv::OpSourceContinued:
1512 case spv::OpSourceExtension:
1515 case spv::OpModuleProcessed:
1517 // Nothing to do at emit time. These are either fully handled at analysis time,
1518 // or don't require any work at all.
1519 return EmitResult::Continue;
1522 return EmitResult::Continue;
1524 case spv::OpVariable:
1525 return EmitVariable(insn, state);
// Memory access — plain and atomic loads/stores share one path each.
1528 case spv::OpAtomicLoad:
1529 return EmitLoad(insn, state);
1532 case spv::OpAtomicStore:
1533 return EmitStore(insn, state);
1535 case spv::OpAccessChain:
1536 case spv::OpInBoundsAccessChain:
1537 return EmitAccessChain(insn, state);
// Composite and vector manipulation.
1539 case spv::OpCompositeConstruct:
1540 return EmitCompositeConstruct(insn, state);
1542 case spv::OpCompositeInsert:
1543 return EmitCompositeInsert(insn, state);
1545 case spv::OpCompositeExtract:
1546 return EmitCompositeExtract(insn, state);
1548 case spv::OpVectorShuffle:
1549 return EmitVectorShuffle(insn, state);
1551 case spv::OpVectorExtractDynamic:
1552 return EmitVectorExtractDynamic(insn, state);
1554 case spv::OpVectorInsertDynamic:
1555 return EmitVectorInsertDynamic(insn, state);
// Matrix/vector arithmetic.
1557 case spv::OpVectorTimesScalar:
1558 case spv::OpMatrixTimesScalar:
1559 return EmitVectorTimesScalar(insn, state);
1561 case spv::OpMatrixTimesVector:
1562 return EmitMatrixTimesVector(insn, state);
1564 case spv::OpVectorTimesMatrix:
1565 return EmitVectorTimesMatrix(insn, state);
1567 case spv::OpMatrixTimesMatrix:
1568 return EmitMatrixTimesMatrix(insn, state);
// Componentwise unary operators (negation, conversions, derivatives...).
1571 case spv::OpSNegate:
1572 case spv::OpFNegate:
1573 case spv::OpLogicalNot:
1574 case spv::OpConvertFToU:
1575 case spv::OpConvertFToS:
1576 case spv::OpConvertSToF:
1577 case spv::OpConvertUToF:
1578 case spv::OpBitcast:
1582 case spv::OpDPdxCoarse:
1584 case spv::OpDPdyCoarse:
1586 case spv::OpFwidthCoarse:
1587 case spv::OpDPdxFine:
1588 case spv::OpDPdyFine:
1589 case spv::OpFwidthFine:
1590 return EmitUnaryOp(insn, state);
// Componentwise binary operators (comparisons, bitwise, logical...).
1603 case spv::OpFOrdEqual:
1604 case spv::OpFUnordEqual:
1605 case spv::OpFOrdNotEqual:
1606 case spv::OpFUnordNotEqual:
1607 case spv::OpFOrdLessThan:
1608 case spv::OpFUnordLessThan:
1609 case spv::OpFOrdGreaterThan:
1610 case spv::OpFUnordGreaterThan:
1611 case spv::OpFOrdLessThanEqual:
1612 case spv::OpFUnordLessThanEqual:
1613 case spv::OpFOrdGreaterThanEqual:
1614 case spv::OpFUnordGreaterThanEqual:
1619 case spv::OpINotEqual:
1620 case spv::OpUGreaterThan:
1621 case spv::OpSGreaterThan:
1622 case spv::OpUGreaterThanEqual:
1623 case spv::OpSGreaterThanEqual:
1624 case spv::OpULessThan:
1625 case spv::OpSLessThan:
1626 case spv::OpULessThanEqual:
1627 case spv::OpSLessThanEqual:
1628 case spv::OpShiftRightLogical:
1629 case spv::OpShiftRightArithmetic:
1630 case spv::OpShiftLeftLogical:
1631 case spv::OpBitwiseOr:
1632 case spv::OpBitwiseXor:
1633 case spv::OpBitwiseAnd:
1634 case spv::OpLogicalOr:
1635 case spv::OpLogicalAnd:
1636 case spv::OpLogicalEqual:
1637 case spv::OpLogicalNotEqual:
1638 case spv::OpUMulExtended:
1639 case spv::OpSMulExtended:
1640 return EmitBinaryOp(insn, state);
1643 return EmitDot(insn, state);
1646 return EmitSelect(insn, state);
1648 case spv::OpExtInst:
1649 return EmitExtendedInstruction(insn, state);
1652 return EmitAny(insn, state);
1655 return EmitAll(insn, state);
// Control flow.
1658 return EmitBranch(insn, state);
1661 return EmitPhi(insn, state);
// Merge instructions carry structural info only; handled by block analysis.
1663 case spv::OpSelectionMerge:
1664 case spv::OpLoopMerge:
1665 return EmitResult::Continue;
1667 case spv::OpBranchConditional:
1668 return EmitBranchConditional(insn, state);
1671 return EmitSwitch(insn, state);
1673 case spv::OpUnreachable:
1674 return EmitUnreachable(insn, state);
1677 return EmitReturn(insn, state);
1680 UNIMPLEMENTED("opcode: %s", OpcodeName(insn.opcode()).c_str());
1684 return EmitResult::Continue;
// Emits an OpVariable: materializes the variable's backing storage or
// pointer according to its storage class.
// - Input: copy interface values from routine->inputs into the lvalue.
// - Uniform/StorageBuffer: resolve descriptor set + binding to a physical
//   device address (VkDescriptorBufferInfo -> buffer data + offset).
// - PushConstant: point directly at the routine's push-constant memory.
1687 SpirvShader::EmitResult SpirvShader::EmitVariable(InsnIterator insn, EmitState *state) const
1689 auto routine = state->routine;
1690 Object::ID resultId = insn.word(2);
1691 auto &object = getObject(resultId);
1692 auto &objectTy = getType(object.type);
1693 switch (objectTy.storageClass)
1695 case spv::StorageClassInput:
1697 if (object.kind == Object::Kind::InterfaceVariable)
1699 auto &dst = routine->getValue(resultId);
// Each location spans 4 scalar slots; Component selects within it.
1701 VisitInterface(resultId,
1702 [&](Decorations const &d, AttribType type) {
1703 auto scalarSlot = d.Location << 2 | d.Component;
1704 dst[offset++] = routine->inputs[scalarSlot];
1709 case spv::StorageClassUniform:
1710 case spv::StorageClassStorageBuffer:
1713 ApplyDecorationsForId(&d, resultId);
1714 ASSERT(d.DescriptorSet >= 0);
1715 ASSERT(d.Binding >= 0);
1717 size_t bindingOffset = routine->pipelineLayout->getBindingOffset(d.DescriptorSet, d.Binding);
// Chase descriptor set -> binding -> buffer -> raw data pointer.
1719 Pointer<Byte> set = routine->descriptorSets[d.DescriptorSet]; // DescriptorSet*
1720 Pointer<Byte> binding = Pointer<Byte>(set + bindingOffset); // VkDescriptorBufferInfo*
1721 Pointer<Byte> buffer = *Pointer<Pointer<Byte>>(binding + OFFSET(VkDescriptorBufferInfo, buffer)); // vk::Buffer*
1722 Pointer<Byte> data = *Pointer<Pointer<Byte>>(buffer + vk::Buffer::DataOffset); // void*
1723 Int offset = *Pointer<Int>(binding + OFFSET(VkDescriptorBufferInfo, offset));
1724 Pointer<Byte> address = data + offset;
1725 routine->physicalPointers[resultId] = address;
1728 case spv::StorageClassPushConstant:
1730 routine->physicalPointers[resultId] = routine->pushConstants;
1737 return EmitResult::Continue;
// Emits OpLoad / OpAtomicLoad: reads through a pointer into a fresh
// intermediate. Takes a slow masked per-lane path when the pointer has
// divergent per-lane offsets or any lane is inactive; otherwise does
// straight vector loads (lane-interleaved or contiguous, depending on the
// storage class's layout).
1740 SpirvShader::EmitResult SpirvShader::EmitLoad(InsnIterator insn, EmitState *state) const
1742 auto routine = state->routine;
1743 bool atomic = (insn.opcode() == spv::OpAtomicLoad);
1744 Object::ID resultId = insn.word(2);
1745 Object::ID pointerId = insn.word(3);
1746 auto &result = getObject(resultId);
1747 auto &resultTy = getType(result.type);
1748 auto &pointer = getObject(pointerId);
1749 auto &pointerBase = getObject(pointer.pointerBase);
1750 auto &pointerBaseTy = getType(pointerBase.type);
1751 std::memory_order memoryOrder = std::memory_order_relaxed;
// Atomic loads carry a memory-semantics operand (word 5) that is mapped
// onto a std::memory_order.
1755 Object::ID semanticsId = insn.word(5);
1756 auto memorySemantics = static_cast<spv::MemorySemanticsMask>(getObject(semanticsId).constantValue[0]);
1757 memoryOrder = MemoryOrder(memorySemantics);
1760 ASSERT(getType(pointer.type).element == result.type);
1761 ASSERT(Type::ID(insn.word(1)) == result.type);
1762 ASSERT(!atomic || getType(getType(pointer.type).element).opcode() == spv::OpTypeInt); // Vulkan 1.1: "Atomic instructions must declare a scalar 32-bit integer type, for the value pointed to by Pointer."
1764 if (pointerBaseTy.storageClass == spv::StorageClassImage)
1766 UNIMPLEMENTED("StorageClassImage load not yet implemented");
// Physical pointers (descriptor-backed memory) vs. routine-local lvalues.
1769 Pointer<Float> ptrBase;
1770 if (pointerBase.kind == Object::Kind::PhysicalPointer)
1772 ptrBase = routine->getPhysicalPointer(pointer.pointerBase);
1776 ptrBase = &routine->getValue(pointer.pointerBase)[0];
1779 bool interleavedByLane = IsStorageInterleavedByLane(pointerBaseTy.storageClass);
1780 auto anyInactiveLanes = AnyFalse(state->activeLaneMask());
1782 auto load = std::unique_ptr<SIMD::Float[]>(new SIMD::Float[resultTy.sizeInComponents]);
1784 If(pointer.kind == Object::Kind::Value || anyInactiveLanes)
1786 // Divergent offsets or masked lanes.
1787 auto offsets = pointer.kind == Object::Kind::Value ?
1788 As<SIMD::Int>(routine->getIntermediate(pointerId).Int(0)) :
1789 RValue<SIMD::Int>(SIMD::Int(0));
1790 for (auto i = 0u; i < resultTy.sizeInComponents; i++)
1792 // TODO: Use a (Float, Float, Float, Float) constructor here, if one becomes available.
// Scalar load per lane, skipping inactive lanes.
1793 for (int j = 0; j < SIMD::Width; j++)
1795 If(Extract(state->activeLaneMask(), j) != 0)
1797 Int offset = Int(i) + Extract(offsets, j);
1798 if (interleavedByLane) { offset = offset * SIMD::Width + j; }
1799 load[i] = Insert(load[i], Load(&ptrBase[offset], sizeof(float), atomic, memoryOrder), j);
1806 // No divergent offsets or masked lanes.
1807 if (interleavedByLane)
1809 // Lane-interleaved data.
1810 Pointer<SIMD::Float> src = ptrBase;
1811 for (auto i = 0u; i < resultTy.sizeInComponents; i++)
1813 load[i] = Load(&src[i], sizeof(float), atomic, memoryOrder); // TODO: optimize alignment
1818 // Non-interleaved data.
1819 for (auto i = 0u; i < resultTy.sizeInComponents; i++)
1821 load[i] = RValue<SIMD::Float>(Load(&ptrBase[i], sizeof(float), atomic, memoryOrder)); // TODO: optimize alignment
// Publish the loaded components as the instruction's result.
1826 auto &dst = routine->createIntermediate(resultId, resultTy.sizeInComponents);
1827 for (auto i = 0u; i < resultTy.sizeInComponents; i++)
1829 dst.move(i, load[i]);
1832 return EmitResult::Continue;
// Emits OpStore / OpAtomicStore: writes a value through a pointer.
// Mirrors EmitLoad's structure: a masked per-lane scalar path for divergent
// offsets or partially-active masks, and fast vector-store paths otherwise.
// The source may be a compile-time constant or a runtime intermediate.
1835 SpirvShader::EmitResult SpirvShader::EmitStore(InsnIterator insn, EmitState *state) const
1837 auto routine = state->routine;
1838 bool atomic = (insn.opcode() == spv::OpAtomicStore);
// OpAtomicStore's value operand is word 4; OpStore's is word 2.
1839 Object::ID pointerId = insn.word(1);
1840 Object::ID objectId = insn.word(atomic ? 4 : 2);
1841 auto &object = getObject(objectId);
1842 auto &pointer = getObject(pointerId);
1843 auto &pointerTy = getType(pointer.type);
1844 auto &elementTy = getType(pointerTy.element);
1845 auto &pointerBase = getObject(pointer.pointerBase);
1846 auto &pointerBaseTy = getType(pointerBase.type);
1847 std::memory_order memoryOrder = std::memory_order_relaxed;
// Atomic stores carry a memory-semantics operand (word 3).
1851 Object::ID semanticsId = insn.word(3);
1852 auto memorySemantics = static_cast<spv::MemorySemanticsMask>(getObject(semanticsId).constantValue[0]);
1853 memoryOrder = MemoryOrder(memorySemantics);
1856 ASSERT(!atomic || elementTy.opcode() == spv::OpTypeInt); // Vulkan 1.1: "Atomic instructions must declare a scalar 32-bit integer type, for the value pointed to by Pointer."
1858 if (pointerBaseTy.storageClass == spv::StorageClassImage)
1860 UNIMPLEMENTED("StorageClassImage store not yet implemented");
1863 Pointer<Float> ptrBase;
1864 if (pointerBase.kind == Object::Kind::PhysicalPointer)
1866 ptrBase = routine->getPhysicalPointer(pointer.pointerBase);
1870 ptrBase = &routine->getValue(pointer.pointerBase)[0];
1873 bool interleavedByLane = IsStorageInterleavedByLane(pointerBaseTy.storageClass);
1874 auto anyInactiveLanes = AnyFalse(state->activeLaneMask());
1876 if (object.kind == Object::Kind::Constant)
1878 // Constant source data.
1879 auto src = reinterpret_cast<float *>(object.constantValue.get());
1880 If(pointer.kind == Object::Kind::Value || anyInactiveLanes)
1882 // Divergent offsets or masked lanes.
1883 auto offsets = pointer.kind == Object::Kind::Value ?
1884 As<SIMD::Int>(routine->getIntermediate(pointerId).Int(0)) :
1885 RValue<SIMD::Int>(SIMD::Int(0));
1886 for (auto i = 0u; i < elementTy.sizeInComponents; i++)
// Scalar store per lane, skipping inactive lanes.
1888 for (int j = 0; j < SIMD::Width; j++)
1890 If(Extract(state->activeLaneMask(), j) != 0)
1892 Int offset = Int(i) + Extract(offsets, j);
1893 if (interleavedByLane) { offset = offset * SIMD::Width + j; }
1894 Store(RValue<Float>(src[i]), &ptrBase[offset], sizeof(float), atomic, memoryOrder);
1901 // Constant source data.
1902 // No divergent offsets or masked lanes.
1903 Pointer<SIMD::Float> dst = ptrBase;
1904 for (auto i = 0u; i < elementTy.sizeInComponents; i++)
1906 Store(RValue<SIMD::Float>(src[i]), &dst[i], sizeof(float), atomic, memoryOrder); // TODO: optimize alignment
1912 // Intermediate source data.
1913 auto &src = routine->getIntermediate(objectId);
1914 If(pointer.kind == Object::Kind::Value || anyInactiveLanes)
1916 // Divergent offsets or masked lanes.
1917 auto offsets = pointer.kind == Object::Kind::Value ?
1918 As<SIMD::Int>(routine->getIntermediate(pointerId).Int(0)) :
1919 RValue<SIMD::Int>(SIMD::Int(0));
1920 for (auto i = 0u; i < elementTy.sizeInComponents; i++)
1922 for (int j = 0; j < SIMD::Width; j++)
1924 If(Extract(state->activeLaneMask(), j) != 0)
1926 Int offset = Int(i) + Extract(offsets, j);
1927 if (interleavedByLane) { offset = offset * SIMD::Width + j; }
1928 Store(Extract(src.Float(i), j), &ptrBase[offset], sizeof(float), atomic, memoryOrder);
1935 // No divergent offsets or masked lanes.
1936 if (interleavedByLane)
1938 // Lane-interleaved data.
1939 Pointer<SIMD::Float> dst = ptrBase;
1940 for (auto i = 0u; i < elementTy.sizeInComponents; i++)
1942 Store(src.Float(i), &dst[i], sizeof(float), atomic, memoryOrder); // TODO: optimize alignment
1947 // Intermediate source data. Non-interleaved data.
1948 Pointer<SIMD::Float> dst = ptrBase;
1949 for (auto i = 0u; i < elementTy.sizeInComponents; i++)
1951 Store<SIMD::Float>(SIMD::Float(src.Float(i)), &dst[i], sizeof(float), atomic, memoryOrder); // TODO: optimize alignment
1957 return EmitResult::Continue;
// Emits OpAccessChain / OpInBoundsAccessChain: computes a per-lane offset
// from a base pointer through the given index list. Explicitly-laid-out
// storage classes (push constant / uniform / storage buffer) use the
// offset/stride decorations; everything else uses the packed in-register
// layout.
1960 SpirvShader::EmitResult SpirvShader::EmitAccessChain(InsnIterator insn, EmitState *state) const
1962 auto routine = state->routine;
1963 Type::ID typeId = insn.word(1);
1964 Object::ID resultId = insn.word(2);
1965 Object::ID baseId = insn.word(3);
// Indexes are the trailing operands, starting at word 4.
1966 uint32_t numIndexes = insn.wordCount() - 4;
1967 const uint32_t *indexes = insn.wordPointer(4);
1968 auto &type = getType(typeId);
// The result is a single-component (scalar offset) intermediate.
1969 ASSERT(type.sizeInComponents == 1);
1970 ASSERT(getObject(baseId).pointerBase == getObject(resultId).pointerBase);
1972 auto &dst = routine->createIntermediate(resultId, type.sizeInComponents);
1974 if(type.storageClass == spv::StorageClassPushConstant ||
1975 type.storageClass == spv::StorageClassUniform ||
1976 type.storageClass == spv::StorageClassStorageBuffer)
1978 dst.move(0, WalkExplicitLayoutAccessChain(baseId, numIndexes, indexes, routine));
1982 dst.move(0, WalkAccessChain(baseId, numIndexes, indexes, routine));
1985 return EmitResult::Continue;
// Emits OpCompositeConstruct: concatenates the components of all source
// operands (words 3..N) into the result, in order.
1988 SpirvShader::EmitResult SpirvShader::EmitCompositeConstruct(InsnIterator insn, EmitState *state) const
1990 auto routine = state->routine;
1991 auto &type = getType(insn.word(1));
1992 auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
1995 for (auto i = 0u; i < insn.wordCount() - 3; i++)
1997 Object::ID srcObjectId = insn.word(3u + i);
1998 auto & srcObject = getObject(srcObjectId);
1999 auto & srcObjectTy = getType(srcObject.type);
2000 GenericValue srcObjectAccess(this, routine, srcObjectId);
// Copy every component of this operand into the next result slots.
2002 for (auto j = 0u; j < srcObjectTy.sizeInComponents; j++)
2004 dst.move(offset++, srcObjectAccess.Float(j));
2008 return EmitResult::Continue;
// Emits OpCompositeInsert: copies the source composite (word 4) into the
// result, with the sub-object at the literal index chain (words 5..N)
// replaced by the new value (word 3).
2011 SpirvShader::EmitResult SpirvShader::EmitCompositeInsert(InsnIterator insn, EmitState *state) const
2013 auto routine = state->routine;
2014 Type::ID resultTypeId = insn.word(1);
2015 auto &type = getType(resultTypeId);
2016 auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
2017 auto &newPartObject = getObject(insn.word(3));
2018 auto &newPartObjectTy = getType(newPartObject.type);
// Flatten the literal index chain into a starting scalar component.
2019 auto firstNewComponent = WalkLiteralAccessChain(resultTypeId, insn.wordCount() - 5, insn.wordPointer(5));
2021 GenericValue srcObjectAccess(this, routine, insn.word(4));
2022 GenericValue newPartObjectAccess(this, routine, insn.word(3));
2024 // old components before
2025 for (auto i = 0u; i < firstNewComponent; i++)
2027 dst.move(i, srcObjectAccess.Float(i));
// new part (replaces the indexed sub-object)
2030 for (auto i = 0u; i < newPartObjectTy.sizeInComponents; i++)
2032 dst.move(firstNewComponent + i, newPartObjectAccess.Float(i));
2034 // old components after
2035 for (auto i = firstNewComponent + newPartObjectTy.sizeInComponents; i < type.sizeInComponents; i++)
2037 dst.move(i, srcObjectAccess.Float(i));
2040 return EmitResult::Continue;
// Emits OpCompositeExtract: copies the sub-object selected by the literal
// index chain (words 4..N) out of the composite (word 3) into the result.
2043 SpirvShader::EmitResult SpirvShader::EmitCompositeExtract(InsnIterator insn, EmitState *state) const
2045 auto routine = state->routine;
2046 auto &type = getType(insn.word(1));
2047 auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
2048 auto &compositeObject = getObject(insn.word(3));
2049 Type::ID compositeTypeId = compositeObject.definition.word(1);
// Flatten the literal index chain into the first scalar component to copy.
2050 auto firstComponent = WalkLiteralAccessChain(compositeTypeId, insn.wordCount() - 4, insn.wordPointer(4));
2052 GenericValue compositeObjectAccess(this, routine, insn.word(3));
2053 for (auto i = 0u; i < type.sizeInComponents; i++)
2055 dst.move(i, compositeObjectAccess.Float(firstComponent + i));
2058 return EmitResult::Continue;
// Emits OpVectorShuffle: builds the result by picking components from two
// source vectors (words 3 and 4) according to the literal selectors
// (words 5..N). Selectors index the concatenation of both vectors; the
// sentinel 0xFFFFFFFF means "undefined component".
2061 SpirvShader::EmitResult SpirvShader::EmitVectorShuffle(InsnIterator insn, EmitState *state) const
2063 auto routine = state->routine;
2064 auto &type = getType(insn.word(1));
2065 auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
2067 // Note: number of components in result type, first half type, and second
2068 // half type are all independent.
2069 auto &firstHalfType = getType(getObject(insn.word(3)).type);
2071 GenericValue firstHalfAccess(this, routine, insn.word(3));
2072 GenericValue secondHalfAccess(this, routine, insn.word(4));
2074 for (auto i = 0u; i < type.sizeInComponents; i++)
2076 auto selector = insn.word(5 + i);
2077 if (selector == static_cast<uint32_t>(-1))
2079 // Undefined value. Until we decide to do real undef values, zero is as good
2081 dst.move(i, RValue<SIMD::Float>(0.0f));
2083 else if (selector < firstHalfType.sizeInComponents)
2085 dst.move(i, firstHalfAccess.Float(selector));
// Selectors past the first vector's size index into the second vector.
2089 dst.move(i, secondHalfAccess.Float(selector - firstHalfType.sizeInComponents));
2093 return EmitResult::Continue;
// Emits OpVectorExtractDynamic: selects one component of a vector using a
// runtime (per-lane) index. Implemented branchlessly by OR-ing each
// component masked with an equality compare against the index.
2096 SpirvShader::EmitResult SpirvShader::EmitVectorExtractDynamic(InsnIterator insn, EmitState *state) const
2098 auto routine = state->routine;
2099 auto &type = getType(insn.word(1));
2100 auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
2101 auto &srcType = getType(getObject(insn.word(3)).type);
2103 GenericValue src(this, routine, insn.word(3));
2104 GenericValue index(this, routine, insn.word(4));
2106 SIMD::UInt v = SIMD::UInt(0);
2108 for (auto i = 0u; i < srcType.sizeInComponents; i++)
// CmpEQ yields all-ones for matching lanes, so exactly one component
// survives the AND per lane.
2110 v |= CmpEQ(index.UInt(0), SIMD::UInt(i)) & src.UInt(i);
2114 return EmitResult::Continue;
// Emits OpVectorInsertDynamic: copies the source vector, replacing the
// component at a runtime (per-lane) index with the given scalar. Done
// branchlessly with a per-component select mask.
2117 SpirvShader::EmitResult SpirvShader::EmitVectorInsertDynamic(InsnIterator insn, EmitState *state) const
2119 auto routine = state->routine;
2120 auto &type = getType(insn.word(1));
2121 auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
2123 GenericValue src(this, routine, insn.word(3));
2124 GenericValue component(this, routine, insn.word(4));
2125 GenericValue index(this, routine, insn.word(5));
2127 for (auto i = 0u; i < type.sizeInComponents; i++)
// mask is all-ones in lanes where this is the indexed component.
2129 SIMD::UInt mask = CmpEQ(SIMD::UInt(i), index.UInt(0));
2130 dst.move(i, (src.UInt(i) & ~mask) | (component.UInt(0) & mask));
2132 return EmitResult::Continue;
// Emits OpVectorTimesScalar / OpMatrixTimesScalar: multiplies every
// component of the left operand by the scalar right operand (matrices are
// stored flat, so the same componentwise loop covers both opcodes).
2135 SpirvShader::EmitResult SpirvShader::EmitVectorTimesScalar(InsnIterator insn, EmitState *state) const
2137 auto routine = state->routine;
2138 auto &type = getType(insn.word(1));
2139 auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
2140 auto lhs = GenericValue(this, routine, insn.word(3));
2141 auto rhs = GenericValue(this, routine, insn.word(4));
2143 for (auto i = 0u; i < type.sizeInComponents; i++)
2145 dst.move(i, lhs.Float(i) * rhs.Float(0));
2148 return EmitResult::Continue;
// Emits OpMatrixTimesVector: result[i] = dot(row i of lhs, rhs).
// The matrix is stored flat in column-major order, so element (i, j) lives
// at index i + type.sizeInComponents * j.
2151 SpirvShader::EmitResult SpirvShader::EmitMatrixTimesVector(InsnIterator insn, EmitState *state) const
2153 auto routine = state->routine;
2154 auto &type = getType(insn.word(1));
2155 auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
2156 auto lhs = GenericValue(this, routine, insn.word(3));
2157 auto rhs = GenericValue(this, routine, insn.word(4));
2158 auto rhsType = getType(getObject(insn.word(4)).type);
2160 for (auto i = 0u; i < type.sizeInComponents; i++)
2162 SIMD::Float v = lhs.Float(i) * rhs.Float(0);
2163 for (auto j = 1u; j < rhsType.sizeInComponents; j++)
2165 v += lhs.Float(i + type.sizeInComponents * j) * rhs.Float(j);
2170 return EmitResult::Continue;
// Emits OpVectorTimesMatrix: result[i] = dot(lhs, column i of rhs).
// The matrix is stored flat column-major: column i starts at index
// i * lhsType.sizeInComponents.
2173 SpirvShader::EmitResult SpirvShader::EmitVectorTimesMatrix(InsnIterator insn, EmitState *state) const
2175 auto routine = state->routine;
2176 auto &type = getType(insn.word(1));
2177 auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
2178 auto lhs = GenericValue(this, routine, insn.word(3));
2179 auto rhs = GenericValue(this, routine, insn.word(4));
2180 auto lhsType = getType(getObject(insn.word(3)).type);
2182 for (auto i = 0u; i < type.sizeInComponents; i++)
2184 SIMD::Float v = lhs.Float(0) * rhs.Float(i * lhsType.sizeInComponents);
2185 for (auto j = 1u; j < lhsType.sizeInComponents; j++)
2187 v += lhs.Float(j) * rhs.Float(i * lhsType.sizeInComponents + j);
2192 return EmitResult::Continue;
// Emits OpMatrixTimesMatrix: standard matrix product, with both matrices
// stored flat in column-major order. numAdds is the shared inner dimension
// (lhs columns == rhs rows), taken from the lhs type's column count.
2195 SpirvShader::EmitResult SpirvShader::EmitMatrixTimesMatrix(InsnIterator insn, EmitState *state) const
2197 auto routine = state->routine;
2198 auto &type = getType(insn.word(1));
2199 auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
2200 auto lhs = GenericValue(this, routine, insn.word(3));
2201 auto rhs = GenericValue(this, routine, insn.word(4));
// Result dimensions from the type definitions: word(3) of a matrix type
// is its column count; the column (vector) type's word(3) is the row count.
2203 auto numColumns = type.definition.word(3);
2204 auto numRows = getType(type.definition.word(2)).definition.word(3);
2205 auto numAdds = getType(getObject(insn.word(3)).type).definition.word(3);
2207 for (auto row = 0u; row < numRows; row++)
2209 for (auto col = 0u; col < numColumns; col++)
2211 SIMD::Float v = SIMD::Float(0);
2212 for (auto i = 0u; i < numAdds; i++)
2214 v += lhs.Float(i * numRows + row) * rhs.Float(col * numAdds + i);
2216 dst.move(numRows * col + row, v);
2220 return EmitResult::Continue;
// Emits componentwise unary operators: negation, logical not, int/float
// conversions, bitcast, and the fragment-shader derivative instructions
// (DPdx/DPdy/Fwidth, coarse and fine).
2223 SpirvShader::EmitResult SpirvShader::EmitUnaryOp(InsnIterator insn, EmitState *state) const
2225 auto routine = state->routine;
2226 auto &type = getType(insn.word(1));
2227 auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
2228 auto src = GenericValue(this, routine, insn.word(3));
2230 for (auto i = 0u; i < type.sizeInComponents; i++)
2232 switch (insn.opcode())
2235 case spv::OpLogicalNot: // logical not == bitwise not due to all-bits boolean representation
2236 dst.move(i, ~src.UInt(i));
2238 case spv::OpSNegate:
2239 dst.move(i, -src.Int(i));
2241 case spv::OpFNegate:
2242 dst.move(i, -src.Float(i));
2244 case spv::OpConvertFToU:
2245 dst.move(i, SIMD::UInt(src.Float(i)));
2247 case spv::OpConvertFToS:
2248 dst.move(i, SIMD::Int(src.Float(i)));
2250 case spv::OpConvertSToF:
2251 dst.move(i, SIMD::Float(src.Int(i)));
2253 case spv::OpConvertUToF:
2254 dst.move(i, SIMD::Float(src.UInt(i)));
2256 case spv::OpBitcast:
// Reinterpret bits; GenericValue accessors make this a plain move.
2257 dst.move(i, src.Float(i));
2260 dst.move(i, IsInf(src.Float(i)));
2263 dst.move(i, IsNan(src.Float(i)));
2266 case spv::OpDPdxCoarse:
2267 // Derivative instructions: FS invocations are laid out like so:
// The Extract indices below imply a 2x2 quad per SIMD group:
// lane1-lane0 is the x difference, lane2-lane0 the y difference.
2270 static_assert(SIMD::Width == 4, "All cross-lane instructions will need care when using a different width");
2271 dst.move(i, SIMD::Float(Extract(src.Float(i), 1) - Extract(src.Float(i), 0)));
2274 case spv::OpDPdyCoarse:
2275 dst.move(i, SIMD::Float(Extract(src.Float(i), 2) - Extract(src.Float(i), 0)));
2278 case spv::OpFwidthCoarse:
2279 dst.move(i, SIMD::Float(Abs(Extract(src.Float(i), 1) - Extract(src.Float(i), 0))
2280 + Abs(Extract(src.Float(i), 2) - Extract(src.Float(i), 0))));
2282 case spv::OpDPdxFine:
// Fine x-derivative: each row of the quad gets its own difference.
2284 auto firstRow = Extract(src.Float(i), 1) - Extract(src.Float(i), 0);
2285 auto secondRow = Extract(src.Float(i), 3) - Extract(src.Float(i), 2);
2286 SIMD::Float v = SIMD::Float(firstRow);
2287 v = Insert(v, secondRow, 2);
2288 v = Insert(v, secondRow, 3);
2292 case spv::OpDPdyFine:
// Fine y-derivative: each column of the quad gets its own difference.
2294 auto firstColumn = Extract(src.Float(i), 2) - Extract(src.Float(i), 0);
2295 auto secondColumn = Extract(src.Float(i), 3) - Extract(src.Float(i), 1);
2296 SIMD::Float v = SIMD::Float(firstColumn);
2297 v = Insert(v, secondColumn, 1);
2298 v = Insert(v, secondColumn, 3);
2302 case spv::OpFwidthFine:
// |dpdx| + |dpdy| using the fine per-row/per-column derivatives.
2304 auto firstRow = Extract(src.Float(i), 1) - Extract(src.Float(i), 0);
2305 auto secondRow = Extract(src.Float(i), 3) - Extract(src.Float(i), 2);
2306 SIMD::Float dpdx = SIMD::Float(firstRow);
2307 dpdx = Insert(dpdx, secondRow, 2);
2308 dpdx = Insert(dpdx, secondRow, 3);
2309 auto firstColumn = Extract(src.Float(i), 2) - Extract(src.Float(i), 0);
2310 auto secondColumn = Extract(src.Float(i), 3) - Extract(src.Float(i), 1);
2311 SIMD::Float dpdy = SIMD::Float(firstColumn);
2312 dpdy = Insert(dpdy, secondColumn, 1);
2313 dpdy = Insert(dpdy, secondColumn, 3);
2314 dst.move(i, Abs(dpdx) + Abs(dpdy));
2318 UNIMPLEMENTED("Unhandled unary operator %s", OpcodeName(insn.opcode()).c_str());
2322 return EmitResult::Continue;
// Emits code for a SPIR-V binary instruction (arithmetic, logical,
// comparison, shift, bitwise and the *MulExtended ops). Operates
// per-component over the flattened operand vectors; the result is written
// into a freshly created intermediate for the result id (word 2).
// NOTE(review): this chunk has lines elided — most opcode case labels,
// braces and break statements between the original line numbers are
// missing. Verify the switch structure against the complete source.
2325 SpirvShader::EmitResult SpirvShader::EmitBinaryOp(InsnIterator insn, EmitState *state) const
2327 auto routine = state->routine;
2328 auto &type = getType(insn.word(1));
2329 auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
2330 auto &lhsType = getType(getObject(insn.word(3)).type);
2331 auto lhs = GenericValue(this, routine, insn.word(3));
2332 auto rhs = GenericValue(this, routine, insn.word(4));
// Iterate over the components of the first operand's type; the extended
// multiply ops below additionally write component i + N for the second
// struct member.
2334 for (auto i = 0u; i < lhsType.sizeInComponents; i++)
2336 switch (insn.opcode())
// Integer add/sub/mul (case labels elided in this chunk).
2339 dst.move(i, lhs.Int(i) + rhs.Int(i));
2342 dst.move(i, lhs.Int(i) - rhs.Int(i));
2345 dst.move(i, lhs.Int(i) * rhs.Int(i));
// Signed division: sanitize the operands so neither divide-by-zero nor
// INT_MIN / -1 can trap on the host.
2349 SIMD::Int a = lhs.Int(i);
2350 SIMD::Int b = rhs.Int(i);
2351 b = b | CmpEQ(b, SIMD::Int(0)); // prevent divide-by-zero
2352 a = a | (CmpEQ(a, SIMD::Int(0x80000000)) & CmpEQ(b, SIMD::Int(-1))); // prevent integer overflow
// Unsigned division: force zero divisors to all-ones so the host divide
// is always defined.
2358 auto zeroMask = As<SIMD::UInt>(CmpEQ(rhs.Int(i), SIMD::Int(0)));
2359 dst.move(i, lhs.UInt(i) / (rhs.UInt(i) | zeroMask));
// Signed remainder (OpSRem, presumably — label elided): same operand
// sanitization as the signed divide above.
2364 SIMD::Int a = lhs.Int(i);
2365 SIMD::Int b = rhs.Int(i);
2366 b = b | CmpEQ(b, SIMD::Int(0)); // prevent divide-by-zero
2367 a = a | (CmpEQ(a, SIMD::Int(0x80000000)) & CmpEQ(b, SIMD::Int(-1))); // prevent integer overflow
// Signed modulo (OpSMod, presumably — label elided).
2373 SIMD::Int a = lhs.Int(i);
2374 SIMD::Int b = rhs.Int(i);
2375 b = b | CmpEQ(b, SIMD::Int(0)); // prevent divide-by-zero
2376 a = a | (CmpEQ(a, SIMD::Int(0x80000000)) & CmpEQ(b, SIMD::Int(-1))); // prevent integer overflow
2378 // If a and b have opposite signs, the remainder operation takes
2379 // the sign from a but OpSMod is supposed to take the sign of b.
2380 // Adding b will ensure that the result has the correct sign and
2381 // that it is still congruent to a modulo b.
2383 // See also http://mathforum.org/library/drmath/view/52343.html
2384 auto signDiff = CmpNEQ(CmpGE(a, SIMD::Int(0)), CmpGE(b, SIMD::Int(0)));
// NOTE(review): 'mod' is defined on a line elided from this chunk
// (presumably a % b) — confirm against the complete source.
2385 auto fixedMod = mod + (b & CmpNEQ(mod, SIMD::Int(0)) & signDiff);
2386 dst.move(i, As<SIMD::Float>(fixedMod));
// Unsigned modulo: same zero-divisor protection as the unsigned divide.
2391 auto zeroMask = As<SIMD::UInt>(CmpEQ(rhs.Int(i), SIMD::Int(0)));
2392 dst.move(i, lhs.UInt(i) % (rhs.UInt(i) | zeroMask));
// Integer / boolean comparisons: results are per-lane all-ones / all-zero masks.
2396 case spv::OpLogicalEqual:
2397 dst.move(i, CmpEQ(lhs.Int(i), rhs.Int(i)));
2399 case spv::OpINotEqual:
2400 case spv::OpLogicalNotEqual:
2401 dst.move(i, CmpNEQ(lhs.Int(i), rhs.Int(i)));
2403 case spv::OpUGreaterThan:
2404 dst.move(i, CmpGT(lhs.UInt(i), rhs.UInt(i)));
2406 case spv::OpSGreaterThan:
2407 dst.move(i, CmpGT(lhs.Int(i), rhs.Int(i)));
2409 case spv::OpUGreaterThanEqual:
2410 dst.move(i, CmpGE(lhs.UInt(i), rhs.UInt(i)));
2412 case spv::OpSGreaterThanEqual:
2413 dst.move(i, CmpGE(lhs.Int(i), rhs.Int(i)));
2415 case spv::OpULessThan:
2416 dst.move(i, CmpLT(lhs.UInt(i), rhs.UInt(i)));
2418 case spv::OpSLessThan:
2419 dst.move(i, CmpLT(lhs.Int(i), rhs.Int(i)));
2421 case spv::OpULessThanEqual:
2422 dst.move(i, CmpLE(lhs.UInt(i), rhs.UInt(i)));
2424 case spv::OpSLessThanEqual:
2425 dst.move(i, CmpLE(lhs.Int(i), rhs.Int(i)));
// Float arithmetic (case labels elided).
2428 dst.move(i, lhs.Float(i) + rhs.Float(i));
2431 dst.move(i, lhs.Float(i) - rhs.Float(i));
2434 dst.move(i, lhs.Float(i) * rhs.Float(i));
2437 dst.move(i, lhs.Float(i) / rhs.Float(i));
2440 // TODO(b/126873455): inaccurate for values greater than 2^24
2441 dst.move(i, lhs.Float(i) - rhs.Float(i) * Floor(lhs.Float(i) / rhs.Float(i)));
2444 dst.move(i, lhs.Float(i) % rhs.Float(i));
// Float comparisons: "Ord" requires both operands non-NaN; "Unord" is
// true if either operand is NaN (the CmpU* variants).
2446 case spv::OpFOrdEqual:
2447 dst.move(i, CmpEQ(lhs.Float(i), rhs.Float(i)));
2449 case spv::OpFUnordEqual:
2450 dst.move(i, CmpUEQ(lhs.Float(i), rhs.Float(i)));
2452 case spv::OpFOrdNotEqual:
2453 dst.move(i, CmpNEQ(lhs.Float(i), rhs.Float(i)));
2455 case spv::OpFUnordNotEqual:
2456 dst.move(i, CmpUNEQ(lhs.Float(i), rhs.Float(i)));
2458 case spv::OpFOrdLessThan:
2459 dst.move(i, CmpLT(lhs.Float(i), rhs.Float(i)));
2461 case spv::OpFUnordLessThan:
2462 dst.move(i, CmpULT(lhs.Float(i), rhs.Float(i)));
2464 case spv::OpFOrdGreaterThan:
2465 dst.move(i, CmpGT(lhs.Float(i), rhs.Float(i)));
2467 case spv::OpFUnordGreaterThan:
2468 dst.move(i, CmpUGT(lhs.Float(i), rhs.Float(i)));
2470 case spv::OpFOrdLessThanEqual:
2471 dst.move(i, CmpLE(lhs.Float(i), rhs.Float(i)));
2473 case spv::OpFUnordLessThanEqual:
2474 dst.move(i, CmpULE(lhs.Float(i), rhs.Float(i)));
2476 case spv::OpFOrdGreaterThanEqual:
2477 dst.move(i, CmpGE(lhs.Float(i), rhs.Float(i)));
2479 case spv::OpFUnordGreaterThanEqual:
2480 dst.move(i, CmpUGE(lhs.Float(i), rhs.Float(i)));
// Shifts and bitwise ops. The Logical* variants reuse the bitwise forms,
// relying on booleans being all-ones / all-zero masks.
2482 case spv::OpShiftRightLogical:
2483 dst.move(i, lhs.UInt(i) >> rhs.UInt(i));
2485 case spv::OpShiftRightArithmetic:
2486 dst.move(i, lhs.Int(i) >> rhs.Int(i));
2488 case spv::OpShiftLeftLogical:
2489 dst.move(i, lhs.UInt(i) << rhs.UInt(i));
2491 case spv::OpBitwiseOr:
2492 case spv::OpLogicalOr:
2493 dst.move(i, lhs.UInt(i) | rhs.UInt(i));
2495 case spv::OpBitwiseXor:
2496 dst.move(i, lhs.UInt(i) ^ rhs.UInt(i));
2498 case spv::OpBitwiseAnd:
2499 case spv::OpLogicalAnd:
2500 dst.move(i, lhs.UInt(i) & rhs.UInt(i));
2502 case spv::OpSMulExtended:
2503 // Extended ops: result is a structure containing two members of the same type as lhs & rhs.
2504 // In our flat view then, component i is the i'th component of the first member;
2505 // component i + N is the i'th component of the second member.
2506 dst.move(i, lhs.Int(i) * rhs.Int(i));
2507 dst.move(i + lhsType.sizeInComponents, MulHigh(lhs.Int(i), rhs.Int(i)));
2509 case spv::OpUMulExtended:
2510 dst.move(i, lhs.UInt(i) * rhs.UInt(i));
2511 dst.move(i + lhsType.sizeInComponents, MulHigh(lhs.UInt(i), rhs.UInt(i)));
2514 UNIMPLEMENTED("Unhandled binary operator %s", OpcodeName(insn.opcode()).c_str());
2518 return EmitResult::Continue;
2521 SpirvShader::EmitResult SpirvShader::EmitDot(InsnIterator insn, EmitState *state) const
2523 auto routine = state->routine;
2524 auto &type = getType(insn.word(1));
2525 ASSERT(type.sizeInComponents == 1);
2526 auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
2527 auto &lhsType = getType(getObject(insn.word(3)).type);
2528 auto lhs = GenericValue(this, routine, insn.word(3));
2529 auto rhs = GenericValue(this, routine, insn.word(4));
2531 dst.move(0, Dot(lhsType.sizeInComponents, lhs, rhs));
2532 return EmitResult::Continue;
2535 SpirvShader::EmitResult SpirvShader::EmitSelect(InsnIterator insn, EmitState *state) const
2537 auto routine = state->routine;
2538 auto &type = getType(insn.word(1));
2539 auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
2540 auto cond = GenericValue(this, routine, insn.word(3));
2541 auto lhs = GenericValue(this, routine, insn.word(4));
2542 auto rhs = GenericValue(this, routine, insn.word(5));
2544 for (auto i = 0u; i < type.sizeInComponents; i++)
2546 dst.move(i, (cond.Int(i) & lhs.Int(i)) | (~cond.Int(i) & rhs.Int(i))); // FIXME: IfThenElse()
2549 return EmitResult::Continue;
// Emits code for an OpExtInst instruction from the GLSL.std.450 extended
// instruction set. Word 4 selects the extended opcode; operand ids start
// at word 5. Results are written per-component into a new intermediate
// for the result id (word 2).
// NOTE(review): braces and break statements between the cases appear
// elided in this chunk — verify the switch structure against the
// complete source.
2552 SpirvShader::EmitResult SpirvShader::EmitExtendedInstruction(InsnIterator insn, EmitState *state) const
2554 auto routine = state->routine;
2555 auto &type = getType(insn.word(1));
2556 auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
2557 auto extInstIndex = static_cast<GLSLstd450>(insn.word(4));
2559 switch (extInstIndex)
2561 case GLSLstd450FAbs:
2563 auto src = GenericValue(this, routine, insn.word(5));
2564 for (auto i = 0u; i < type.sizeInComponents; i++)
2566 dst.move(i, Abs(src.Float(i)));
2570 case GLSLstd450SAbs:
2572 auto src = GenericValue(this, routine, insn.word(5));
2573 for (auto i = 0u; i < type.sizeInComponents; i++)
2575 dst.move(i, Abs(src.Int(i)));
// Cross product — operands are 3-component vectors; each result
// component is computed explicitly.
2579 case GLSLstd450Cross:
2581 auto lhs = GenericValue(this, routine, insn.word(5));
2582 auto rhs = GenericValue(this, routine, insn.word(6));
2583 dst.move(0, lhs.Float(1) * rhs.Float(2) - rhs.Float(1) * lhs.Float(2));
2584 dst.move(1, lhs.Float(2) * rhs.Float(0) - rhs.Float(2) * lhs.Float(0));
2585 dst.move(2, lhs.Float(0) * rhs.Float(1) - rhs.Float(0) * lhs.Float(1));
2588 case GLSLstd450Floor:
2590 auto src = GenericValue(this, routine, insn.word(5));
2591 for (auto i = 0u; i < type.sizeInComponents; i++)
2593 dst.move(i, Floor(src.Float(i)));
2597 case GLSLstd450Trunc:
2599 auto src = GenericValue(this, routine, insn.word(5));
2600 for (auto i = 0u; i < type.sizeInComponents; i++)
2602 dst.move(i, Trunc(src.Float(i)));
2606 case GLSLstd450Ceil:
2608 auto src = GenericValue(this, routine, insn.word(5));
2609 for (auto i = 0u; i < type.sizeInComponents; i++)
2611 dst.move(i, Ceil(src.Float(i)));
2615 case GLSLstd450Fract:
2617 auto src = GenericValue(this, routine, insn.word(5));
2618 for (auto i = 0u; i < type.sizeInComponents; i++)
2620 dst.move(i, Frac(src.Float(i)));
2624 case GLSLstd450Round:
2626 auto src = GenericValue(this, routine, insn.word(5));
2627 for (auto i = 0u; i < type.sizeInComponents; i++)
2629 dst.move(i, Round(src.Float(i)));
// Round-to-nearest-even: correct Round()'s result by one ulp toward the
// source when the fraction is exactly 0.5 and Round() produced an odd value.
2633 case GLSLstd450RoundEven:
2635 auto src = GenericValue(this, routine, insn.word(5));
2636 for (auto i = 0u; i < type.sizeInComponents; i++)
2638 auto x = Round(src.Float(i));
2639 // dst = round(src) + ((round(src) < src) * 2 - 1) * (fract(src) == 0.5) * isOdd(round(src));
2640 dst.move(i, x + ((SIMD::Float(CmpLT(x, src.Float(i)) & SIMD::Int(1)) * SIMD::Float(2.0f)) - SIMD::Float(1.0f)) *
2641 SIMD::Float(CmpEQ(Frac(src.Float(i)), SIMD::Float(0.5f)) & SIMD::Int(1)) * SIMD::Float(Int4(x) & SIMD::Int(1)));
2645 case GLSLstd450FMin:
2647 auto lhs = GenericValue(this, routine, insn.word(5));
2648 auto rhs = GenericValue(this, routine, insn.word(6));
2649 for (auto i = 0u; i < type.sizeInComponents; i++)
2651 dst.move(i, Min(lhs.Float(i), rhs.Float(i)));
2655 case GLSLstd450FMax:
2657 auto lhs = GenericValue(this, routine, insn.word(5));
2658 auto rhs = GenericValue(this, routine, insn.word(6));
2659 for (auto i = 0u; i < type.sizeInComponents; i++)
2661 dst.move(i, Max(lhs.Float(i), rhs.Float(i)));
2665 case GLSLstd450SMin:
2667 auto lhs = GenericValue(this, routine, insn.word(5));
2668 auto rhs = GenericValue(this, routine, insn.word(6));
2669 for (auto i = 0u; i < type.sizeInComponents; i++)
2671 dst.move(i, Min(lhs.Int(i), rhs.Int(i)));
2675 case GLSLstd450SMax:
2677 auto lhs = GenericValue(this, routine, insn.word(5));
2678 auto rhs = GenericValue(this, routine, insn.word(6));
2679 for (auto i = 0u; i < type.sizeInComponents; i++)
2681 dst.move(i, Max(lhs.Int(i), rhs.Int(i)));
2685 case GLSLstd450UMin:
2687 auto lhs = GenericValue(this, routine, insn.word(5));
2688 auto rhs = GenericValue(this, routine, insn.word(6));
2689 for (auto i = 0u; i < type.sizeInComponents; i++)
2691 dst.move(i, Min(lhs.UInt(i), rhs.UInt(i)));
2695 case GLSLstd450UMax:
2697 auto lhs = GenericValue(this, routine, insn.word(5));
2698 auto rhs = GenericValue(this, routine, insn.word(6));
2699 for (auto i = 0u; i < type.sizeInComponents; i++)
2701 dst.move(i, Max(lhs.UInt(i), rhs.UInt(i)));
// step(edge, x): 0.0 where x < edge, else 1.0 — the comparison mask
// selects the bit pattern of the constant 1.0f.
2705 case GLSLstd450Step:
2707 auto edge = GenericValue(this, routine, insn.word(5));
2708 auto x = GenericValue(this, routine, insn.word(6));
2709 for (auto i = 0u; i < type.sizeInComponents; i++)
2711 dst.move(i, CmpNLT(x.Float(i), edge.Float(i)) & As<SIMD::Int>(SIMD::Float(1.0f)));
// smoothstep: Hermite interpolation t*t*(3 - 2*t) of the clamped ratio.
2715 case GLSLstd450SmoothStep:
2717 auto edge0 = GenericValue(this, routine, insn.word(5));
2718 auto edge1 = GenericValue(this, routine, insn.word(6));
2719 auto x = GenericValue(this, routine, insn.word(7));
2720 for (auto i = 0u; i < type.sizeInComponents; i++)
2722 auto tx = Min(Max((x.Float(i) - edge0.Float(i)) /
2723 (edge1.Float(i) - edge0.Float(i)), SIMD::Float(0.0f)), SIMD::Float(1.0f));
2724 dst.move(i, tx * tx * (Float4(3.0f) - Float4(2.0f) * tx));
// mix(x, y, a) computed as a*(y-x) + x.
2728 case GLSLstd450FMix:
2730 auto x = GenericValue(this, routine, insn.word(5));
2731 auto y = GenericValue(this, routine, insn.word(6));
2732 auto a = GenericValue(this, routine, insn.word(7));
2733 for (auto i = 0u; i < type.sizeInComponents; i++)
2735 dst.move(i, a.Float(i) * (y.Float(i) - x.Float(i)) + x.Float(i));
2739 case GLSLstd450FClamp:
2741 auto x = GenericValue(this, routine, insn.word(5));
2742 auto minVal = GenericValue(this, routine, insn.word(6));
2743 auto maxVal = GenericValue(this, routine, insn.word(7));
2744 for (auto i = 0u; i < type.sizeInComponents; i++)
2746 dst.move(i, Min(Max(x.Float(i), minVal.Float(i)), maxVal.Float(i)));
2750 case GLSLstd450SClamp:
2752 auto x = GenericValue(this, routine, insn.word(5));
2753 auto minVal = GenericValue(this, routine, insn.word(6));
2754 auto maxVal = GenericValue(this, routine, insn.word(7));
2755 for (auto i = 0u; i < type.sizeInComponents; i++)
2757 dst.move(i, Min(Max(x.Int(i), minVal.Int(i)), maxVal.Int(i)));
2761 case GLSLstd450UClamp:
2763 auto x = GenericValue(this, routine, insn.word(5));
2764 auto minVal = GenericValue(this, routine, insn.word(6));
2765 auto maxVal = GenericValue(this, routine, insn.word(7));
2766 for (auto i = 0u; i < type.sizeInComponents; i++)
2768 dst.move(i, Min(Max(x.UInt(i), minVal.UInt(i)), maxVal.UInt(i)));
// sign(x): -1.0 for x < -0.0, +1.0 for x > +0.0, otherwise 0.0 — two
// mask-selected constants ORed together (both masks clear at zero/NaN).
2772 case GLSLstd450FSign:
2774 auto src = GenericValue(this, routine, insn.word(5));
2775 for (auto i = 0u; i < type.sizeInComponents; i++)
2777 auto neg = As<SIMD::Int>(CmpLT(src.Float(i), SIMD::Float(-0.0f))) & As<SIMD::Int>(SIMD::Float(-1.0f));
2778 auto pos = As<SIMD::Int>(CmpNLE(src.Float(i), SIMD::Float(+0.0f))) & As<SIMD::Int>(SIMD::Float(1.0f));
2779 dst.move(i, neg | pos);
2783 case GLSLstd450SSign:
2785 auto src = GenericValue(this, routine, insn.word(5));
2786 for (auto i = 0u; i < type.sizeInComponents; i++)
2788 auto neg = CmpLT(src.Int(i), SIMD::Int(0)) & SIMD::Int(-1);
2789 auto pos = CmpNLE(src.Int(i), SIMD::Int(0)) & SIMD::Int(1);
2790 dst.move(i, neg | pos);
// reflect(I, N) = I - 2*dot(I, N)*N.
2794 case GLSLstd450Reflect:
2796 auto I = GenericValue(this, routine, insn.word(5));
2797 auto N = GenericValue(this, routine, insn.word(6));
2799 SIMD::Float d = Dot(type.sizeInComponents, I, N);
2801 for (auto i = 0u; i < type.sizeInComponents; i++)
2803 dst.move(i, I.Float(i) - SIMD::Float(2.0f) * d * N.Float(i));
// refract(I, N, eta): returns the zero vector where k < 0 (total
// internal reflection) — the 'pos' mask zeroes those lanes.
2807 case GLSLstd450Refract:
2809 auto I = GenericValue(this, routine, insn.word(5));
2810 auto N = GenericValue(this, routine, insn.word(6));
2811 auto eta = GenericValue(this, routine, insn.word(7));
2813 SIMD::Float d = Dot(type.sizeInComponents, I, N);
2814 SIMD::Float k = SIMD::Float(1.0f) - eta.Float(0) * eta.Float(0) * (SIMD::Float(1.0f) - d * d);
2815 SIMD::Int pos = CmpNLT(k, SIMD::Float(0.0f));
2816 SIMD::Float t = (eta.Float(0) * d + Sqrt(k));
2818 for (auto i = 0u; i < type.sizeInComponents; i++)
2820 dst.move(i, pos & As<SIMD::Int>(eta.Float(0) * I.Float(i) - t * N.Float(i)));
// faceforward(N, I, Nref): N where dot(I, Nref) < 0, else -N.
2824 case GLSLstd450FaceForward:
2826 auto N = GenericValue(this, routine, insn.word(5));
2827 auto I = GenericValue(this, routine, insn.word(6));
2828 auto Nref = GenericValue(this, routine, insn.word(7));
2830 SIMD::Float d = Dot(type.sizeInComponents, I, Nref);
2831 SIMD::Int neg = CmpLT(d, SIMD::Float(0.0f));
2833 for (auto i = 0u; i < type.sizeInComponents; i++)
2835 auto n = N.Float(i);
2836 dst.move(i, (neg & As<SIMD::Int>(n)) | (~neg & As<SIMD::Int>(-n)));
// length(x) = sqrt(dot(x, x)); the operand's own component count is used,
// since the result type is scalar.
2840 case GLSLstd450Length:
2842 auto x = GenericValue(this, routine, insn.word(5));
2843 SIMD::Float d = Dot(getType(getObject(insn.word(5)).type).sizeInComponents, x, x);
2845 dst.move(0, Sqrt(d));
2848 case GLSLstd450Normalize:
2850 auto x = GenericValue(this, routine, insn.word(5));
2851 SIMD::Float d = Dot(getType(getObject(insn.word(5)).type).sizeInComponents, x, x);
2852 SIMD::Float invLength = SIMD::Float(1.0f) / Sqrt(d);
2854 for (auto i = 0u; i < type.sizeInComponents; i++)
2856 dst.move(i, invLength * x.Float(i));
2860 case GLSLstd450Distance:
2862 auto p0 = GenericValue(this, routine, insn.word(5));
2863 auto p1 = GenericValue(this, routine, insn.word(6));
2864 auto p0Type = getType(getObject(insn.word(5)).type);
2866 // sqrt(dot(p0-p1, p0-p1))
2867 SIMD::Float d = (p0.Float(0) - p1.Float(0)) * (p0.Float(0) - p1.Float(0));
2869 for (auto i = 1u; i < p0Type.sizeInComponents; i++)
2871 d += (p0.Float(i) - p1.Float(i)) * (p0.Float(i) - p1.Float(i));
2874 dst.move(0, Sqrt(d));
2878 UNIMPLEMENTED("Unhandled ExtInst %d", extInstIndex);
2881 return EmitResult::Continue;
// Maps a SPIR-V memory-semantics mask to the C++ memory order used for the
// corresponding atomic operation. Unrecognized masks fall through to
// UNREACHABLE and conservatively return acq_rel.
// NOTE(review): this switch matches exact mask values, not combinations of
// bits; the 'default:' label appears elided between lines 2892 and 2894
// of this chunk — confirm against the complete source.
2884 std::memory_order SpirvShader::MemoryOrder(spv::MemorySemanticsMask memorySemantics)
2886 switch(memorySemantics)
2888 case spv::MemorySemanticsMaskNone: return std::memory_order_relaxed;
2889 case spv::MemorySemanticsAcquireMask: return std::memory_order_acquire;
2890 case spv::MemorySemanticsReleaseMask: return std::memory_order_release;
2891 case spv::MemorySemanticsAcquireReleaseMask: return std::memory_order_acq_rel;
2892 case spv::MemorySemanticsSequentiallyConsistentMask: return std::memory_order_acq_rel; // Vulkan 1.1: "SequentiallyConsistent is treated as AcquireRelease"
2894 UNREACHABLE("MemorySemanticsMask %x", memorySemantics);
2895 return std::memory_order_acq_rel;
2899 SIMD::Float SpirvShader::Dot(unsigned numComponents, GenericValue const & x, GenericValue const & y) const
2901 SIMD::Float d = x.Float(0) * y.Float(0);
2903 for (auto i = 1u; i < numComponents; i++)
2905 d += x.Float(i) * y.Float(i);
2911 SpirvShader::EmitResult SpirvShader::EmitAny(InsnIterator insn, EmitState *state) const
2913 auto routine = state->routine;
2914 auto &type = getType(insn.word(1));
2915 ASSERT(type.sizeInComponents == 1);
2916 auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
2917 auto &srcType = getType(getObject(insn.word(3)).type);
2918 auto src = GenericValue(this, routine, insn.word(3));
2920 SIMD::UInt result = src.UInt(0);
2922 for (auto i = 1u; i < srcType.sizeInComponents; i++)
2924 result |= src.UInt(i);
2927 dst.move(0, result);
2928 return EmitResult::Continue;
2931 SpirvShader::EmitResult SpirvShader::EmitAll(InsnIterator insn, EmitState *state) const
2933 auto routine = state->routine;
2934 auto &type = getType(insn.word(1));
2935 ASSERT(type.sizeInComponents == 1);
2936 auto &dst = routine->createIntermediate(insn.word(2), type.sizeInComponents);
2937 auto &srcType = getType(getObject(insn.word(3)).type);
2938 auto src = GenericValue(this, routine, insn.word(3));
2940 SIMD::UInt result = src.UInt(0);
2942 for (auto i = 1u; i < srcType.sizeInComponents; i++)
2944 result &= src.UInt(i);
2947 dst.move(0, result);
2948 return EmitResult::Continue;
2951 SpirvShader::EmitResult SpirvShader::EmitBranch(InsnIterator insn, EmitState *state) const
2953 auto target = Block::ID(insn.word(1));
2954 auto edge = Block::Edge{state->currentBlock, target};
2955 state->edgeActiveLaneMasks.emplace(edge, state->activeLaneMask());
2956 return EmitResult::Terminator;
2959 SpirvShader::EmitResult SpirvShader::EmitBranchConditional(InsnIterator insn, EmitState *state) const
2961 auto block = getBlock(state->currentBlock);
2962 ASSERT(block.branchInstruction == insn);
2964 auto condId = Object::ID(block.branchInstruction.word(1));
2965 auto trueBlockId = Block::ID(block.branchInstruction.word(2));
2966 auto falseBlockId = Block::ID(block.branchInstruction.word(3));
2968 auto cond = GenericValue(this, state->routine, condId);
2969 ASSERT_MSG(getType(getObject(condId).type).sizeInComponents == 1, "Condition must be a Boolean type scalar");
2971 // TODO: Optimize for case where all lanes take same path.
2973 state->addOutputActiveLaneMaskEdge(trueBlockId, cond.Int(0));
2974 state->addOutputActiveLaneMaskEdge(falseBlockId, ~cond.Int(0));
2976 return EmitResult::Terminator;
// Emits code for OpSwitch: each case target receives the lanes whose
// selector equals that case's literal; the default target receives
// whatever lanes matched no case.
2979 SpirvShader::EmitResult SpirvShader::EmitSwitch(InsnIterator insn, EmitState *state) const
2981 auto block = getBlock(state->currentBlock);
2982 ASSERT(block.branchInstruction == insn);
2984 auto selId = Object::ID(block.branchInstruction.word(1));
2986 auto sel = GenericValue(this, state->routine, selId);
2987 ASSERT_MSG(getType(getObject(selId).type).sizeInComponents == 1, "Selector must be a scalar")
2989 auto numCases = (block.branchInstruction.wordCount() - 3) / 2;
2991 // TODO: Optimize for case where all lanes take same path.
// Start with all active lanes defaulting; each matching case removes
// its lanes from the default mask below.
2993 SIMD::Int defaultLaneMask = state->activeLaneMask();
2995 // Gather up the case label matches and calculate defaultLaneMask.
// NOTE(review): caseLabelMatches is reserved but never visibly populated
// in this chunk — a push_back may have been elided; confirm against the
// complete source.
2996 std::vector<RValue<SIMD::Int>> caseLabelMatches;
2997 caseLabelMatches.reserve(numCases);
// Case literals and target block ids alternate starting at word 3.
2998 for (uint32_t i = 0; i < numCases; i++)
3000 auto label = block.branchInstruction.word(i * 2 + 3);
3001 auto caseBlockId = Block::ID(block.branchInstruction.word(i * 2 + 4));
3002 auto caseLabelMatch = CmpEQ(sel.Int(0), SIMD::Int(label));
3003 state->addOutputActiveLaneMaskEdge(caseBlockId, caseLabelMatch);
3004 defaultLaneMask &= ~caseLabelMatch;
// Word 2 is the default target.
3007 auto defaultBlockId = Block::ID(block.branchInstruction.word(2));
3008 state->addOutputActiveLaneMaskEdge(defaultBlockId, defaultLaneMask);
3010 return EmitResult::Terminator;
3013 SpirvShader::EmitResult SpirvShader::EmitUnreachable(InsnIterator insn, EmitState *state) const
3015 // TODO: Log something in this case?
3016 state->setActiveLaneMask(SIMD::Int(0));
3017 return EmitResult::Terminator;
3020 SpirvShader::EmitResult SpirvShader::EmitReturn(InsnIterator insn, EmitState *state) const
3022 state->setActiveLaneMask(SIMD::Int(0));
3023 return EmitResult::Terminator;
// Emits code for OpPhi: the result is built by OR-merging each incoming
// value masked by the lane mask of the edge from its predecessor block to
// the current block.
3026 SpirvShader::EmitResult SpirvShader::EmitPhi(InsnIterator insn, EmitState *state) const
3028 auto routine = state->routine;
3029 auto typeId = Type::ID(insn.word(1));
3030 auto type = getType(typeId);
3031 auto objectId = Object::ID(insn.word(2));
3033 auto &dst = routine->createIntermediate(objectId, type.sizeInComponents);
// (value id, predecessor block id) pairs start at word 3.
3036 for (uint32_t w = 3; w < insn.wordCount(); w += 2)
3038 auto varId = Object::ID(insn.word(w + 0));
3039 auto blockId = Block::ID(insn.word(w + 1));
3041 auto in = GenericValue(this, routine, varId);
3042 auto mask = GetActiveLaneMaskEdge(state, blockId, state->currentBlock);
3044 for (uint32_t i = 0; i < type.sizeInComponents; i++)
3046 auto inMasked = in.Int(i) & mask;
// NOTE(review): 'first' is declared (and later cleared) on lines elided
// from this chunk — presumably a bool that is true only for the first
// incoming pair; confirm against the complete source.
3047 dst.replace(i, first ? inMasked : (dst.Int(i) | inMasked));
3052 return EmitResult::Continue;
// Copies the shader's output interface variables into the routine's flat
// 'outputs' array, one scalar per (Location, Component) slot, after the
// main body has executed.
3055 void SpirvShader::emitEpilog(SpirvRoutine *routine) const
3057 for (auto insn : *this)
3059 switch (insn.opcode())
3061 case spv::OpVariable:
3063 Object::ID resultId = insn.word(2);
3064 auto &object = getObject(resultId);
3065 auto &objectTy = getType(object.type);
// Only Output-storage-class interface variables are written back.
3066 if (object.kind == Object::Kind::InterfaceVariable && objectTy.storageClass == spv::StorageClassOutput)
3068 auto &dst = routine->getValue(resultId);
// Each scalar slot is Location * 4 + Component.
// NOTE(review): 'offset' is declared on a line elided from this chunk —
// presumably a counter starting at 0 that walks dst's scalars in
// interface-visit order; confirm against the complete source.
3070 VisitInterface(resultId,
3071 [&](Decorations const &d, AttribType type) {
3072 auto scalarSlot = d.Location << 2 | d.Component;
3073 routine->outputs[scalarSlot] = dst[offset++];
// Constructs a Block over the instruction range [begin, end): classifies
// the block (Simple / structured or unstructured conditional / switch /
// loop forms) from its last two instructions, records the branch and
// merge instructions, and collects the successor block ids into 'outs'.
// NOTE(review): several case labels, break statements and 'kind ='
// assignments between the original line numbers are elided in this chunk
// — verify the full switch structure against the complete source.
3084 SpirvShader::Block::Block(InsnIterator begin, InsnIterator end) : begin_(begin), end_(end)
3086 // Default to a Simple, this may change later.
3087 kind = Block::Simple;
3089 // Walk the instructions to find the last two of the block.
3090 InsnIterator insns[2];
3091 for (auto insn : *this)
3093 insns[0] = insns[1];
// insns[1] is the block's terminator; insns[0] is the instruction
// before it (where any merge instruction would sit).
3097 switch (insns[1].opcode())
// Unconditional branch (case label elided): single successor at word 1.
3100 branchInstruction = insns[1];
3101 outs.emplace(Block::ID(branchInstruction.word(1)));
3103 switch (insns[0].opcode())
3105 case spv::OpLoopMerge:
3107 mergeInstruction = insns[0];
3108 mergeBlock = Block::ID(mergeInstruction.word(1));
3109 continueTarget = Block::ID(mergeInstruction.word(2));
3113 kind = Block::Simple;
// Conditional branch: successors are the true (word 2) and false
// (word 3) targets; the preceding merge instruction decides whether
// it is structured.
3118 case spv::OpBranchConditional:
3119 branchInstruction = insns[1];
3120 outs.emplace(Block::ID(branchInstruction.word(2)));
3121 outs.emplace(Block::ID(branchInstruction.word(3)));
3123 switch (insns[0].opcode())
3125 case spv::OpSelectionMerge:
3126 kind = StructuredBranchConditional;
3127 mergeInstruction = insns[0];
3128 mergeBlock = Block::ID(mergeInstruction.word(1));
3131 case spv::OpLoopMerge:
3133 mergeInstruction = insns[0];
3134 mergeBlock = Block::ID(mergeInstruction.word(1));
3135 continueTarget = Block::ID(mergeInstruction.word(2));
3139 kind = UnstructuredBranchConditional;
// Switch (case label elided): default target at word 2, then case
// targets at every second word from 4.
3145 branchInstruction = insns[1];
3146 outs.emplace(Block::ID(branchInstruction.word(2)));
3147 for (uint32_t w = 4; w < branchInstruction.wordCount(); w += 2)
3149 outs.emplace(Block::ID(branchInstruction.word(w)));
3152 switch (insns[0].opcode())
3154 case spv::OpSelectionMerge:
3155 kind = StructuredSwitch;
3156 mergeInstruction = insns[0];
3157 mergeBlock = Block::ID(mergeInstruction.word(1));
3161 kind = UnstructuredSwitch;
// Breadth-first search over block successor edges: returns true if 'to'
// is reachable from 'from' without passing through 'notPassingThrough'
// (which is pre-seeded into the visited set).
// NOTE(review): the declaration of 'seen' (presumably a set of Block::ID),
// the pending.pop() after front(), and the final 'return false;' are on
// lines elided from this chunk — confirm against the complete source.
3171 bool SpirvShader::existsPath(Block::ID from, Block::ID to, Block::ID notPassingThrough) const
3173 // TODO: Optimize: This can be cached on the block.
3175 seen.emplace(notPassingThrough);
3177 std::queue<Block::ID> pending;
3178 pending.emplace(from);
3180 while (pending.size() > 0)
3182 auto id = pending.front();
3184 for (auto out : getBlock(id).outs)
3186 if (seen.count(out) != 0) { continue; }
3187 if (out == to) { return true; }
3188 pending.emplace(out);
3196 void SpirvShader::EmitState::addOutputActiveLaneMaskEdge(Block::ID to, RValue<SIMD::Int> mask)
3198 addActiveLaneMaskEdge(currentBlock, to, mask & activeLaneMask());
3201 void SpirvShader::EmitState::addActiveLaneMaskEdge(Block::ID from, Block::ID to, RValue<SIMD::Int> mask)
3203 auto edge = Block::Edge{from, to};
3204 auto it = edgeActiveLaneMasks.find(edge);
3205 if (it == edgeActiveLaneMasks.end())
3207 edgeActiveLaneMasks.emplace(edge, mask);
3211 auto combined = it->second | mask;
3212 edgeActiveLaneMasks.erase(edge);
3213 edgeActiveLaneMasks.emplace(edge, combined);
3217 RValue<SIMD::Int> SpirvShader::GetActiveLaneMaskEdge(EmitState *state, Block::ID from, Block::ID to) const
3219 auto edge = Block::Edge{from, to};
3220 auto it = state->edgeActiveLaneMasks.find(edge);
3221 ASSERT_MSG(it != state->edgeActiveLaneMasks.end(), "Could not find edge %d -> %d", from.value(), to.value());
3225 SpirvRoutine::SpirvRoutine(vk::PipelineLayout const *pipelineLayout) :
3226 pipelineLayout(pipelineLayout)