return false;
}
- unsigned int iterations = loopCount(node);
+ LoopInfo loop(node);
- if(iterations == 0)
+ if(loop.iterations == 0)
{
return false;
}
- bool unroll = (iterations <= 4);
-
- if(unroll)
- {
- LoopUnrollable loopUnrollable;
- unroll = loopUnrollable.traverse(node);
- }
+ bool unroll = (loop.iterations <= 4);
TIntermNode *init = node->getInit();
TIntermTyped *condition = node->getCondition();
if(unroll)
{
- for(unsigned int i = 0; i < iterations; i++)
+ for(unsigned int i = 0; i < loop.iterations; i++)
{
// condition->traverse(this); // Condition could contain statements, but not in an unrollable loop
return matrix->getSecondarySize();
}
- // Returns ~0u if no loop count could be determined
- unsigned int OutputASM::loopCount(TIntermLoop *node)
+ // Sets iterations to ~0u if no loop count could be statically determined.
+ OutputASM::LoopInfo::LoopInfo(TIntermLoop *node)
{
// Parse loops of the form:
- // for(int index = initial; index [comparator] limit; index += increment)
- TIntermSymbol *index = 0;
- TOperator comparator = EOpNull;
- int initial = 0;
- int limit = 0;
- int increment = 0;
+ // for(int index = initial; index [comparator] limit; index [op] increment)
// Parse index name and intial value
if(node->getInit())
if(binaryTerminal)
{
- TOperator op = binaryTerminal->getOp();
- TIntermConstantUnion *constant = binaryTerminal->getRight()->getAsConstantUnion();
+ TIntermSymbol *operand = binaryTerminal->getLeft()->getAsSymbolNode();
- if(constant)
+ if(operand && operand->getId() == index->getId())
{
- if(constant->getBasicType() == EbtInt && constant->getNominalSize() == 1)
- {
- int value = constant->getUnionArrayPointer()[0].getIConst();
+ TOperator op = binaryTerminal->getOp();
+ TIntermConstantUnion *constant = binaryTerminal->getRight()->getAsConstantUnion();
- switch(op)
+ if(constant)
+ {
+ if(constant->getBasicType() == EbtInt && constant->getNominalSize() == 1)
{
- case EOpAddAssign: increment = value; break;
- case EOpSubAssign: increment = -value; break;
- default: UNIMPLEMENTED();
+ int value = constant->getUnionArrayPointer()[0].getIConst();
+
+ switch(op)
+ {
+ case EOpAddAssign: increment = value; break;
+ case EOpSubAssign: increment = -value; break;
+ default: increment = 0; break; // Rare cases left unhandled. Treated as non-deterministic.
+ }
}
}
}
}
else if(unaryTerminal)
{
- TOperator op = unaryTerminal->getOp();
+ TIntermSymbol *operand = unaryTerminal->getOperand()->getAsSymbolNode();
- switch(op)
+ if(operand && operand->getId() == index->getId())
{
- case EOpPostIncrement: increment = 1; break;
- case EOpPostDecrement: increment = -1; break;
- case EOpPreIncrement: increment = 1; break;
- case EOpPreDecrement: increment = -1; break;
- default: UNIMPLEMENTED();
+ TOperator op = unaryTerminal->getOp();
+
+ switch(op)
+ {
+ case EOpPostIncrement: increment = 1; break;
+ case EOpPostDecrement: increment = -1; break;
+ case EOpPreIncrement: increment = 1; break;
+ case EOpPreDecrement: increment = -1; break;
+ default: increment = 0; break; // Rare cases left unhandled. Treated as non-deterministic.
+ }
}
}
}
if(index && comparator != EOpNull && increment != 0)
{
+ // Check the loop body for return statements or changes to the index variable that make it non-deterministic.
+ LoopUnrollable loopUnrollable;
+ bool unrollable = loopUnrollable.traverse(node, index->getId());
+
+ if(!unrollable)
+ {
+ iterations = ~0u;
+ return;
+ }
+
if(comparator == EOpLessThanEqual)
{
comparator = EOpLessThan;
{
if(!(initial < limit)) // Never loops
{
- return 0;
+ iterations = 0;
}
-
- int iterations = (limit - initial + abs(increment) - 1) / increment; // Ceiling division
-
- if(iterations < 0)
+ else if(increment < 0)
{
- return ~0u;
+ iterations = ~0u;
+ }
+ else
+ {
+ iterations = (limit - initial + abs(increment) - 1) / increment; // Ceiling division
}
-
- return iterations;
}
- else UNIMPLEMENTED(); // Falls through
+ else
+ {
+ // Rare cases left unhandled. Treated as non-deterministic.
+ iterations = ~0u;
+ }
}
-
- return ~0u;
}
- bool LoopUnrollable::traverse(TIntermNode *node)
+ bool LoopUnrollable::traverse(TIntermNode *node, int indexId)
{
- loopDepth = 0;
loopUnrollable = true;
+ loopDepth = 0;
+ loopIndexId = indexId;
+
node->traverse(this);
return loopUnrollable;
return true;
}
+ void LoopUnrollable::visitSymbol(TIntermSymbol *node)
+ {
+ // Check that the loop index is not used as the argument to a function out or inout parameter.
+ if(node->getId() == loopIndexId)
+ {
+ if(node->getQualifier() == EvqOut || node->getQualifier() == EvqInOut)
+ {
+ loopUnrollable = false;
+ }
+ }
+ }
+
+ bool LoopUnrollable::visitBinary(Visit visit, TIntermBinary *node)
+ {
+ if(!loopUnrollable)
+ {
+ return false;
+ }
+
+ // Check that the loop index is not statically assigned to.
+ TIntermSymbol *symbol = node->getLeft()->getAsSymbolNode();
+ loopUnrollable = node->modifiesState() && symbol && (symbol->getId() == loopIndexId);
+
+ return loopUnrollable;
+ }
+
+ bool LoopUnrollable::visitUnary(Visit visit, TIntermUnary *node)
+ {
+ if(!loopUnrollable)
+ {
+ return false;
+ }
+
+ // Check that the loop index is not statically assigned to.
+ TIntermSymbol *symbol = node->getOperand()->getAsSymbolNode();
+ loopUnrollable = node->modifiesState() && symbol && (symbol->getId() == loopIndexId);
+
+ return loopUnrollable;
+ }
+
bool LoopUnrollable::visitBranch(Visit visit, TIntermBranch *node)
{
if(!loopUnrollable)
static int dim(TIntermNode *v);
static int dim2(TIntermNode *m);
- static unsigned int loopCount(TIntermLoop *node);
+
+ struct LoopInfo
+ {
+ LoopInfo(TIntermLoop *node);
+
+ bool isDeterministic()
+ {
+ return (iterations != ~0u);
+ }
+
+ unsigned int iterations = ~0u;
+
+ TIntermSymbol *index = nullptr;
+ TOperator comparator = EOpNull;
+ int initial = 0;
+ int limit = 0;
+ int increment = 0;
+ };
Shader *const shaderObject;
sw::Shader *shader;
class LoopUnrollable : public TIntermTraverser
{
public:
- bool traverse(TIntermNode *node);
+ bool traverse(TIntermNode *node, int loopIndexId);
private:
- bool visitBranch(Visit visit, TIntermBranch *node);
- bool visitLoop(Visit visit, TIntermLoop *loop);
- bool visitAggregate(Visit visit, TIntermAggregate *node);
+ void visitSymbol(TIntermSymbol *node) override;
+ bool visitBinary(Visit visit, TIntermBinary *node) override;
+ bool visitUnary(Visit visit, TIntermUnary *node) override;
+ bool visitBranch(Visit visit, TIntermBranch *node) override;
+ bool visitLoop(Visit visit, TIntermLoop *loop) override;
+ bool visitAggregate(Visit visit, TIntermAggregate *node) override;
- int loopDepth;
bool loopUnrollable;
+
+ int loopDepth;
+ int loopIndexId;
};
}