OSDN Git Service

Detect loop index modifications in loop body.
authorNicolas Capens <capn@google.com>
Mon, 28 May 2018 16:25:57 +0000 (12:25 -0400)
committerNicolas Capens <nicolascapens@google.com>
Tue, 29 May 2018 13:59:43 +0000 (13:59 +0000)
Loops can only be unrolled if their loop index variable is not being
modified in the loop body.

Also check that the increment step of the loop operates on the initial
index variable.

Also remove some UNIMPLEMENTED's that were benign.

Bug chromium:845103
Bug chromium:843867
Bug skia:7846

Change-Id: Ib2b39f2d58763f0299ce7f6f75a8a75e6bdc7963
Reviewed-on: https://swiftshader-review.googlesource.com/18988
Reviewed-by: Alexis Hétu <sugoi@google.com>
Tested-by: Nicolas Capens <nicolascapens@google.com>
src/OpenGL/compiler/OutputASM.cpp
src/OpenGL/compiler/OutputASM.h

index b6b0eef..cbfbf56 100644 (file)
@@ -1824,20 +1824,14 @@ namespace glsl
                        return false;
                }
 
-               unsigned int iterations = loopCount(node);
+               LoopInfo loop(node);
 
-               if(iterations == 0)
+               if(loop.iterations == 0)
                {
                        return false;
                }
 
-               bool unroll = (iterations <= 4);
-
-               if(unroll)
-               {
-                       LoopUnrollable loopUnrollable;
-                       unroll = loopUnrollable.traverse(node);
-               }
+               bool unroll = (loop.iterations <= 4);
 
                TIntermNode *init = node->getInit();
                TIntermTyped *condition = node->getCondition();
@@ -1873,7 +1867,7 @@ namespace glsl
 
                        if(unroll)
                        {
-                               for(unsigned int i = 0; i < iterations; i++)
+                               for(unsigned int i = 0; i < loop.iterations; i++)
                                {
                                //      condition->traverse(this);   // Condition could contain statements, but not in an unrollable loop
 
@@ -3716,16 +3710,11 @@ namespace glsl
                return matrix->getSecondarySize();
        }
 
-       // Returns ~0u if no loop count could be determined
-       unsigned int OutputASM::loopCount(TIntermLoop *node)
+       // Sets iterations to ~0u if no loop count could be statically determined.
+       OutputASM::LoopInfo::LoopInfo(TIntermLoop *node)
        {
                // Parse loops of the form:
-               // for(int index = initial; index [comparator] limit; index += increment)
-               TIntermSymbol *index = 0;
-               TOperator comparator = EOpNull;
-               int initial = 0;
-               int limit = 0;
-               int increment = 0;
+               // for(int index = initial; index [comparator] limit; index [op] increment)
 
                // Parse index name and intial value
                if(node->getInit())
@@ -3788,41 +3777,61 @@ namespace glsl
 
                        if(binaryTerminal)
                        {
-                               TOperator op = binaryTerminal->getOp();
-                               TIntermConstantUnion *constant = binaryTerminal->getRight()->getAsConstantUnion();
+                               TIntermSymbol *operand = binaryTerminal->getLeft()->getAsSymbolNode();
 
-                               if(constant)
+                               if(operand && operand->getId() == index->getId())
                                {
-                                       if(constant->getBasicType() == EbtInt && constant->getNominalSize() == 1)
-                                       {
-                                               int value = constant->getUnionArrayPointer()[0].getIConst();
+                                       TOperator op = binaryTerminal->getOp();
+                                       TIntermConstantUnion *constant = binaryTerminal->getRight()->getAsConstantUnion();
 
-                                               switch(op)
+                                       if(constant)
+                                       {
+                                               if(constant->getBasicType() == EbtInt && constant->getNominalSize() == 1)
                                                {
-                                               case EOpAddAssign: increment = value;  break;
-                                               case EOpSubAssign: increment = -value; break;
-                                               default: UNIMPLEMENTED();
+                                                       int value = constant->getUnionArrayPointer()[0].getIConst();
+
+                                                       switch(op)
+                                                       {
+                                                       case EOpAddAssign: increment = value;  break;
+                                                       case EOpSubAssign: increment = -value; break;
+                                                       default:           increment = 0;      break;   // Rare cases left unhandled. Treated as non-deterministic.
+                                                       }
                                                }
                                        }
                                }
                        }
                        else if(unaryTerminal)
                        {
-                               TOperator op = unaryTerminal->getOp();
+                               TIntermSymbol *operand = unaryTerminal->getOperand()->getAsSymbolNode();
 
-                               switch(op)
+                               if(operand && operand->getId() == index->getId())
                                {
-                               case EOpPostIncrement: increment = 1;  break;
-                               case EOpPostDecrement: increment = -1; break;
-                               case EOpPreIncrement:  increment = 1;  break;
-                               case EOpPreDecrement:  increment = -1; break;
-                               default: UNIMPLEMENTED();
+                                       TOperator op = unaryTerminal->getOp();
+
+                                       switch(op)
+                                       {
+                                       case EOpPostIncrement: increment = 1;  break;
+                                       case EOpPostDecrement: increment = -1; break;
+                                       case EOpPreIncrement:  increment = 1;  break;
+                                       case EOpPreDecrement:  increment = -1; break;
+                                       default:               increment = 0;  break;   // Rare cases left unhandled. Treated as non-deterministic.
+                                       }
                                }
                        }
                }
 
                if(index && comparator != EOpNull && increment != 0)
                {
+                       // Check the loop body for return statements or changes to the index variable that make it non-deterministic.
+                       LoopUnrollable loopUnrollable;
+                       bool unrollable = loopUnrollable.traverse(node, index->getId());
+
+                       if(!unrollable)
+                       {
+                               iterations = ~0u;
+                               return;
+                       }
+
                        if(comparator == EOpLessThanEqual)
                        {
                                comparator = EOpLessThan;
@@ -3846,29 +3855,32 @@ namespace glsl
                        {
                                if(!(initial < limit))   // Never loops
                                {
-                                       return 0;
+                                       iterations = 0;
                                }
-
-                               int iterations = (limit - initial + abs(increment) - 1) / increment;   // Ceiling division
-
-                               if(iterations < 0)
+                               else if(increment < 0)
                                {
-                                       return ~0u;
+                                       iterations = ~0u;
+                               }
+                               else
+                               {
+                                       iterations = (limit - initial + abs(increment) - 1) / increment;   // Ceiling division
                                }
-
-                               return iterations;
                        }
-                       else UNIMPLEMENTED();   // Falls through
+                       else
+                       {
+                               // Rare cases left unhandled. Treated as non-deterministic.
+                               iterations = ~0u;
+                       }
                }
-
-               return ~0u;
        }
 
-       bool LoopUnrollable::traverse(TIntermNode *node)
+       bool LoopUnrollable::traverse(TIntermNode *node, int indexId)
        {
-               loopDepth = 0;
                loopUnrollable = true;
 
+               loopDepth = 0;
+               loopIndexId = indexId;
+
                node->traverse(this);
 
                return loopUnrollable;
@@ -3888,6 +3900,46 @@ namespace glsl
                return true;
        }
 
+       void LoopUnrollable::visitSymbol(TIntermSymbol *node)
+       {
+               // Check that the loop index is not used as the argument to a function out or inout parameter.
+               if(node->getId() == loopIndexId)
+               {
+                       if(node->getQualifier() == EvqOut || node->getQualifier() == EvqInOut)
+                       {
+                               loopUnrollable = false;
+                       }
+               }
+       }
+
+       bool LoopUnrollable::visitBinary(Visit visit, TIntermBinary *node)
+       {
+               if(!loopUnrollable)
+               {
+                       return false;
+               }
+
+               // Check that the loop index is not statically assigned to.
+               TIntermSymbol *symbol = node->getLeft()->getAsSymbolNode();
+               loopUnrollable = node->modifiesState() && symbol && (symbol->getId() == loopIndexId);
+
+               return loopUnrollable;
+       }
+
+       bool LoopUnrollable::visitUnary(Visit visit, TIntermUnary *node)
+       {
+               if(!loopUnrollable)
+               {
+                       return false;
+               }
+
+               // Check that the loop index is not statically assigned to.
+               TIntermSymbol *symbol = node->getOperand()->getAsSymbolNode();
+               loopUnrollable = node->modifiesState() && symbol && (symbol->getId() == loopIndexId);
+
+               return loopUnrollable;
+       }
+
        bool LoopUnrollable::visitBranch(Visit visit, TIntermBranch *node)
        {
                if(!loopUnrollable)
index 3391035..4480e2a 100644 (file)
@@ -316,7 +316,24 @@ namespace glsl
 
                static int dim(TIntermNode *v);
                static int dim2(TIntermNode *m);
-               static unsigned int loopCount(TIntermLoop *node);
+
+               struct LoopInfo
+               {
+                       LoopInfo(TIntermLoop *node);
+
+                       bool isDeterministic()
+                       {
+                               return (iterations != ~0u);
+                       }
+
+                       unsigned int iterations = ~0u;
+
+                       TIntermSymbol *index = nullptr;
+                       TOperator comparator = EOpNull;
+                       int initial = 0;
+                       int limit = 0;
+                       int increment = 0;
+               };
 
                Shader *const shaderObject;
                sw::Shader *shader;
@@ -363,15 +380,20 @@ namespace glsl
        class LoopUnrollable : public TIntermTraverser
        {
        public:
-               bool traverse(TIntermNode *node);
+               bool traverse(TIntermNode *node, int loopIndexId);
 
        private:
-               bool visitBranch(Visit visit, TIntermBranch *node);
-               bool visitLoop(Visit visit, TIntermLoop *loop);
-               bool visitAggregate(Visit visit, TIntermAggregate *node);
+               void visitSymbol(TIntermSymbol *node) override;
+               bool visitBinary(Visit visit, TIntermBinary *node) override;
+               bool visitUnary(Visit visit, TIntermUnary *node) override;
+               bool visitBranch(Visit visit, TIntermBranch *node) override;
+               bool visitLoop(Visit visit, TIntermLoop *loop) override;
+               bool visitAggregate(Visit visit, TIntermAggregate *node) override;
 
-               int loopDepth;
                bool loopUnrollable;
+
+               int loopDepth;
+               int loopIndexId;
        };
 }