OSDN Git Service

[SLP] Try a bit harder to find reduction PHIs
authorCharlie Turner <charlie.turner@arm.com>
Tue, 27 Oct 2015 17:54:16 +0000 (17:54 +0000)
committerCharlie Turner <charlie.turner@arm.com>
Tue, 27 Oct 2015 17:54:16 +0000 (17:54 +0000)
Summary:
Currently, when the SLP vectorizer considers whether a phi is part of a reduction, it dismisses phi's whose incoming blocks are not the same as the block containing the phi. For the patterns I'm looking at, extending this rule to allow phis whose incoming block is a containing loop latch allows me to vectorize certain workloads.

There is no significant compile-time impact, and combined with D13949, no performance improvement measured in ARM/AArch64 in any of SPEC2000, SPEC2006 or LNT.

Reviewers: jmolloy, mcrosier, nadav

Subscribers: mssimpso, nadav, aemerson, llvm-commits

Differential Revision: http://reviews.llvm.org/D14063

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@251425 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Vectorize/SLPVectorizer.cpp
test/Transforms/SLPVectorizer/AArch64/horizontal.ll

index 649f2cf..dcaa240 100644 (file)
@@ -3933,6 +3933,46 @@ static bool PhiTypeSorterFunc(Value *V, Value *V2) {
   return V->getType() < V2->getType();
 }
 
+/// \brief Try and get a reduction value from a phi node.
+///
+/// Given a phi node \p P in a block \p ParentBB, consider possible reductions
+/// if they come from either \p ParentBB or a containing loop latch.
+///
+/// \returns A candidate reduction value if possible, or \code nullptr \endcode
+/// if not possible.
+static Value *getReductionValue(PHINode *P, BasicBlock *ParentBB,
+                                LoopInfo *LI) {
+  Value *Rdx = nullptr;
+
+  // Return the incoming value if it comes from the same BB as the phi node.
+  if (P->getIncomingBlock(0) == ParentBB) {
+    Rdx = P->getIncomingValue(0);
+  } else if (P->getIncomingBlock(1) == ParentBB) {
+    Rdx = P->getIncomingValue(1);
+  }
+
+  if (Rdx)
+    return Rdx;
+
+  // Otherwise, check whether we have a loop latch to look at.
+  Loop *BBL = LI->getLoopFor(ParentBB);
+  if (!BBL)
+    return Rdx;
+  BasicBlock *BBLatch = BBL->getLoopLatch();
+  if (!BBLatch)
+    return Rdx;
+
+  // There is a loop latch, return the incoming value if it comes from
+  // that. This reduction pattern occassionaly turns up.
+  if (P->getIncomingBlock(0) == BBLatch) {
+    Rdx = P->getIncomingValue(0);
+  } else if (P->getIncomingBlock(1) == BBLatch) {
+    Rdx = P->getIncomingValue(1);
+  }
+
+  return Rdx;
+}
+
 bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
   bool Changed = false;
   SmallVector<Value *, 4> Incoming;
@@ -4000,11 +4040,9 @@ bool SLPVectorizer::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {
       // Check that the PHI is a reduction PHI.
       if (P->getNumIncomingValues() != 2)
         return Changed;
-      Value *Rdx =
-          (P->getIncomingBlock(0) == BB
-               ? (P->getIncomingValue(0))
-               : (P->getIncomingBlock(1) == BB ? P->getIncomingValue(1)
-                                               : nullptr));
+
+      Value *Rdx = getReductionValue(P, BB, LI);
+
       // Check if this is a Binary Operator.
       BinaryOperator *BI = dyn_cast_or_null<BinaryOperator>(Rdx);
       if (!BI)
index dca7bda..80ab421 100644 (file)
@@ -71,3 +71,77 @@ for.end:                                          ; preds = %for.end.loopexit, %
   %s.0.lcssa = phi i32 [ 0, %entry ], [ %add27, %for.end.loopexit ]
   ret i32 %s.0.lcssa
 }
+
+;; Check whether SLP can find a reduction phi whose incoming blocks are not
+;; the same as the block containing the phi.
+;;
+;; Came from code like,
+;;
+;; int s = 0;
+;; for (int j = 0; j < h; j++) {
+;;   s += p1[0] * p2[0]
+;;   s += p1[1] * p2[1];
+;;   s += p1[2] * p2[2];
+;;   s += p1[3] * p2[3];
+;;   if (s >= lim)
+;;      break;
+;;   p1 += lx;
+;;   p2 += lx;
+;; }
+define i32 @reduction_with_br(i32* noalias nocapture readonly %blk1, i32* noalias nocapture readonly %blk2, i32 %lx, i32 %h, i32 %lim) {
+; CHECK-LABEL: reduction_with_br
+; CHECK: load <4 x i32>
+; CHECK: load <4 x i32>
+; CHECK: mul nsw <4 x i32>
+entry:
+  %cmp.16 = icmp sgt i32 %h, 0
+  br i1 %cmp.16, label %for.body.lr.ph, label %for.end
+
+for.body.lr.ph:                                   ; preds = %entry
+  %idx.ext = sext i32 %lx to i64
+  br label %for.body
+
+for.body:                                         ; preds = %for.body.lr.ph, %if.end
+  %s.020 = phi i32 [ 0, %for.body.lr.ph ], [ %add13, %if.end ]
+  %j.019 = phi i32 [ 0, %for.body.lr.ph ], [ %inc, %if.end ]
+  %p2.018 = phi i32* [ %blk2, %for.body.lr.ph ], [ %add.ptr16, %if.end ]
+  %p1.017 = phi i32* [ %blk1, %for.body.lr.ph ], [ %add.ptr, %if.end ]
+  %0 = load i32, i32* %p1.017, align 4
+  %1 = load i32, i32* %p2.018, align 4
+  %mul = mul nsw i32 %1, %0
+  %add = add nsw i32 %mul, %s.020
+  %arrayidx2 = getelementptr inbounds i32, i32* %p1.017, i64 1
+  %2 = load i32, i32* %arrayidx2, align 4
+  %arrayidx3 = getelementptr inbounds i32, i32* %p2.018, i64 1
+  %3 = load i32, i32* %arrayidx3, align 4
+  %mul4 = mul nsw i32 %3, %2
+  %add5 = add nsw i32 %add, %mul4
+  %arrayidx6 = getelementptr inbounds i32, i32* %p1.017, i64 2
+  %4 = load i32, i32* %arrayidx6, align 4
+  %arrayidx7 = getelementptr inbounds i32, i32* %p2.018, i64 2
+  %5 = load i32, i32* %arrayidx7, align 4
+  %mul8 = mul nsw i32 %5, %4
+  %add9 = add nsw i32 %add5, %mul8
+  %arrayidx10 = getelementptr inbounds i32, i32* %p1.017, i64 3
+  %6 = load i32, i32* %arrayidx10, align 4
+  %arrayidx11 = getelementptr inbounds i32, i32* %p2.018, i64 3
+  %7 = load i32, i32* %arrayidx11, align 4
+  %mul12 = mul nsw i32 %7, %6
+  %add13 = add nsw i32 %add9, %mul12
+  %cmp14 = icmp slt i32 %add13, %lim
+  br i1 %cmp14, label %if.end, label %for.end.loopexit
+
+if.end:                                           ; preds = %for.body
+  %add.ptr = getelementptr inbounds i32, i32* %p1.017, i64 %idx.ext
+  %add.ptr16 = getelementptr inbounds i32, i32* %p2.018, i64 %idx.ext
+  %inc = add nuw nsw i32 %j.019, 1
+  %cmp = icmp slt i32 %inc, %h
+  br i1 %cmp, label %for.body, label %for.end.loopexit
+
+for.end.loopexit:                                 ; preds = %for.body, %if.end
+  br label %for.end
+
+for.end:                                          ; preds = %for.end.loopexit, %entry
+  %s.1 = phi i32 [ 0, %entry ], [ %add13, %for.end.loopexit ]
+  ret i32 %s.1
+}