OSDN Git Service

Turn floating point IVs into integer IVs where possible.
authorDevang Patel <dpatel@apple.com>
Mon, 3 Nov 2008 18:32:19 +0000 (18:32 +0000)
committerDevang Patel <dpatel@apple.com>
Mon, 3 Nov 2008 18:32:19 +0000 (18:32 +0000)
This allows SCEV users to effectively calculate trip count.
LSR later on transforms back integer IVs to floating point IVs
later on to avoid int-to-float casts inside the loop.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@58625 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/IndVarSimplify.cpp
test/Transforms/IndVarsSimplify/2008-11-03-Floating.ll [new file with mode: 0644]

index 4dfd8b9..56829bd 100644 (file)
@@ -95,6 +95,7 @@ namespace {
     void DeleteTriviallyDeadInstructions(std::set<Instruction*> &Insts);
 
     void OptimizeCanonicalIVType(Loop *L);
+    void HandleFloatingPointIV(Loop *L);
   };
 }
 
@@ -466,6 +467,7 @@ bool IndVarSimplify::runOnLoop(Loop *L, LPPassManager &LPM) {
   // auxillary induction variables.
   std::vector<std::pair<PHINode*, SCEVHandle> > IndVars;
 
+  HandleFloatingPointIV(L);
   for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) {
     PHINode *PN = cast<PHINode>(I);
     if (PN->getType()->isInteger()) { // FIXME: when we have fast-math, enable!
@@ -718,3 +720,151 @@ void IndVarSimplify::OptimizeCanonicalIVType(Loop *L) {
   Incr->eraseFromParent();
 }
 
+/// HandleFloatingPointIV - If the loop has floating induction variable
+/// then insert corresponding integer induction variable if possible.
+void IndVarSimplify::HandleFloatingPointIV(Loop *L) {
+  BasicBlock *Header = L->getHeader();
+  SmallVector <PHINode *, 4> FPHIs;
+  Instruction *NonPHIInsn = NULL;
+
+  // Collect all floating point IVs first.
+  BasicBlock::iterator I = Header->begin();
+  while(true) {
+    if (!isa<PHINode>(I)) {
+      NonPHIInsn = I;
+      break;
+    }
+    PHINode *PH = cast<PHINode>(I);
+    if (PH->getType()->isFloatingPoint())
+      FPHIs.push_back(PH);
+    ++I;
+  }
+   
+  for (SmallVector<PHINode *, 4>::iterator I = FPHIs.begin(), E = FPHIs.end();
+       I != E; ++I) {
+    PHINode *PH = *I;
+    unsigned IncomingEdge = L->contains(PH->getIncomingBlock(0));
+    unsigned BackEdge     = IncomingEdge^1;
+
+    // Check incoming value.
+    ConstantFP *CZ = dyn_cast<ConstantFP>(PH->getIncomingValue(IncomingEdge));
+    if (!CZ) continue;
+    APFloat PHInit = CZ->getValueAPF();
+    if (!PHInit.isPosZero()) continue;
+
+    // Check IV increment.
+    BinaryOperator *Incr = 
+      dyn_cast<BinaryOperator>(PH->getIncomingValue(BackEdge));
+    if (!Incr) continue;
+    if (Incr->getOpcode() != Instruction::Add) continue;
+    ConstantFP *IncrValue = NULL;
+    unsigned IncrVIndex = 1;
+    if (Incr->getOperand(1) == PH)
+      IncrVIndex = 0;
+    IncrValue = dyn_cast<ConstantFP>(Incr->getOperand(IncrVIndex));
+    if (!IncrValue) continue;
+    APFloat IVAPF = IncrValue->getValueAPF();
+    APFloat One = APFloat(IVAPF.getSemantics(), 1);
+    if (!IVAPF.bitwiseIsEqual(One)) continue;
+
+    // Check Incr uses.
+    Value::use_iterator IncrUse = Incr->use_begin();
+    Instruction *U1 = cast<Instruction>(IncrUse++);
+    if (IncrUse == Incr->use_end()) continue;
+    Instruction *U2 = cast<Instruction>(IncrUse++);
+    if (IncrUse != Incr->use_end()) continue;
+
+    // Find exict condition.
+    FCmpInst *EC = dyn_cast<FCmpInst>(U1);
+    if (!EC)
+      EC = dyn_cast<FCmpInst>(U2);
+    if (!EC) continue;
+    bool skip = false;
+    Instruction *Terminator = EC->getParent()->getTerminator();
+    for(Value::use_iterator ECUI = EC->use_begin(), ECUE = EC->use_end();
+        ECUI != ECUE; ++ECUI) {
+      Instruction *U = cast<Instruction>(ECUI);
+      if (U != Terminator) { 
+        skip = true;
+        break;
+      }
+    }
+    if (skip) continue;
+
+    // Find exit value.
+    ConstantFP *EV = NULL;
+    unsigned EVIndex = 1;
+    if (EC->getOperand(1) == Incr)
+      EVIndex = 0;
+    EV = dyn_cast<ConstantFP>(EC->getOperand(EVIndex));
+    if (!EV) continue;
+    APFloat EVAPF = EV->getValueAPF();
+    if (EVAPF.isNegative()) continue;
+
+    // Find corresponding integer exit value.
+    uint64_t integerVal = Type::Int32Ty->getPrimitiveSizeInBits();
+    bool isExact = false;
+    if (EVAPF.convertToInteger(&integerVal, 32, false, APFloat::rmTowardZero, &isExact)
+        != APFloat::opOK)
+      continue;
+    if (!isExact) continue;
+
+    // Find new predicate for integer comparison.
+    CmpInst::Predicate NewPred = CmpInst::BAD_ICMP_PREDICATE;
+    switch (EC->getPredicate()) {
+    case CmpInst::FCMP_OEQ:
+    case CmpInst::FCMP_UEQ:
+      NewPred = CmpInst::ICMP_EQ;
+      break;
+    case CmpInst::FCMP_OGT:
+    case CmpInst::FCMP_UGT:
+      NewPred = CmpInst::ICMP_UGT;
+      break;
+    case CmpInst::FCMP_OGE:
+    case CmpInst::FCMP_UGE:
+      NewPred = CmpInst::ICMP_UGE;
+      break;
+    case CmpInst::FCMP_OLT:
+    case CmpInst::FCMP_ULT:
+      NewPred = CmpInst::ICMP_ULT;
+      break;
+    case CmpInst::FCMP_OLE:
+    case CmpInst::FCMP_ULE:
+      NewPred = CmpInst::ICMP_ULE;
+      break;
+    default:
+      break;
+    }
+    if (NewPred == CmpInst::BAD_ICMP_PREDICATE) continue;
+
+    // Insert new integer induction variable.
+    SCEVExpander Rewriter(*SE, *LI);
+    PHINode *NewIV = 
+      cast<PHINode>(Rewriter.getOrInsertCanonicalInductionVariable(L,Type::Int32Ty));
+    ConstantInt *NewEV = ConstantInt::get(Type::Int32Ty, integerVal);
+    Value *LHS = (EVIndex == 1 ? NewIV->getIncomingValue(BackEdge) : NewEV);
+    Value *RHS = (EVIndex == 1 ? NewEV : NewIV->getIncomingValue(BackEdge));
+    ICmpInst *NewEC = new ICmpInst(NewPred, LHS, RHS, EC->getNameStart(), 
+                                   EC->getParent()->getTerminator());
+
+    // Delete old, floating point, exit comparision instruction.
+    SE->deleteValueFromRecords(EC);
+    EC->replaceAllUsesWith(NewEC);
+    EC->eraseFromParent();
+
+    // Delete old, floating point, increment instruction.
+    SE->deleteValueFromRecords(Incr);
+    Incr->replaceAllUsesWith(UndefValue::get(Incr->getType()));
+    Incr->eraseFromParent();
+
+    // Replace floating induction variable.
+    UIToFPInst *Conv = new UIToFPInst(NewIV, PH->getType(), "indvar.conv", 
+                                      NonPHIInsn);
+    PH->replaceAllUsesWith(Conv);
+
+    SE->deleteValueFromRecords(PH);
+    PH->removeIncomingValue((unsigned)0);
+    PH->removeIncomingValue((unsigned)0);
+  }
+}
+
diff --git a/test/Transforms/IndVarsSimplify/2008-11-03-Floating.ll b/test/Transforms/IndVarsSimplify/2008-11-03-Floating.ll
new file mode 100644 (file)
index 0000000..b7574fe
--- /dev/null
@@ -0,0 +1,17 @@
+; RUN: llvm-as < %s | opt -indvars | llvm-dis | grep icmp | count 1
+define void @bar() nounwind {
+entry:
+       br label %bb
+
+bb:            ; preds = %bb, %entry
+       %x.0.reg2mem.0 = phi double [ 0.000000e+00, %entry ], [ %1, %bb ]               ; <double> [#uses=2]
+       %0 = tail call i32 @foo(double %x.0.reg2mem.0) nounwind         ; <i32> [#uses=0]
+       %1 = add double %x.0.reg2mem.0, 1.000000e+00            ; <double> [#uses=2]
+       %2 = fcmp olt double %1, 1.000000e+04           ; <i1> [#uses=1]
+       br i1 %2, label %bb, label %return
+
+return:                ; preds = %bb
+       ret void
+}
+
+declare i32 @foo(double)