OSDN Git Service

[LIR] Teach LIR to avoid extending the BE count prior to adding one to
authorChandler Carruth <chandlerc@gmail.com>
Tue, 25 Jul 2017 10:48:32 +0000 (10:48 +0000)
committerChandler Carruth <chandlerc@gmail.com>
Tue, 25 Jul 2017 10:48:32 +0000 (10:48 +0000)
it when safe.

Very often the BE count is the trip count minus one, and the plus one
here should fold with that minus one. But because the BE count might in
theory be UINT_MAX or some such, adding one before we extend could in
some cases wrap to zero and break when we scale things.

This patch checks to see if it would be safe to add one because the
specific case that would cause this is guarded for prior to entering the
preheader. This should handle essentially all of the common loop idioms
coming out of C/C++ code once canonicalized by LLVM.

Before this patch, both forms of loop in the added test cases ended up
subtracting one from the size, extending it, scaling it up by 8 and then
adding 8 back onto it. This is really silly, and it turns out made it
all the way into generated code very often, so this is a surprisingly
important cleanup to do.

Many thanks to Sanjoy for showing me how to do this with SCEV.

Differential Revision: https://reviews.llvm.org/D35758

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@308968 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/LoopIdiomRecognize.cpp
test/Transforms/LoopIdiom/basic.ll

index 4a6a35c..9051b7c 100644 (file)
@@ -780,6 +780,41 @@ static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
   return SE->getMinusSCEV(Start, Index);
 }
 
+/// Compute the number of bytes as a SCEV from the backedge taken count.
+///
+/// This also maps the SCEV into the provided type and tries to handle the
+/// computation in a way that will fold cleanly.
+static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
+                               unsigned StoreSize, Loop *CurLoop,
+                               const DataLayout *DL, ScalarEvolution *SE) {
+  const SCEV *NumBytesS;
+  // The # stored bytes is (BECount+1)*Size.  Expand the trip count out to
+  // pointer size if it isn't already.
+  //
+  // If we're going to need to zero extend the BE count, check if we can add
+  // one to it prior to zero extending without overflow. Provided this is safe,
+  // it allows better simplification of the +1.
+  if (DL->getTypeSizeInBits(BECount->getType()) <
+          DL->getTypeSizeInBits(IntPtr) &&
+      SE->isLoopEntryGuardedByCond(
+          CurLoop, ICmpInst::ICMP_NE, BECount,
+          SE->getNegativeSCEV(SE->getOne(BECount->getType())))) {
+    NumBytesS = SE->getZeroExtendExpr(
+        SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW),
+        IntPtr);
+  } else {
+    NumBytesS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
+                               SE->getOne(IntPtr), SCEV::FlagNUW);
+  }
+
+  // And scale it based on the store size.
+  if (StoreSize != 1) {
+    NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize),
+                               SCEV::FlagNUW);
+  }
+  return NumBytesS;
+}
+
 /// processLoopStridedStore - We see a strided store of some value.  If we can
 /// transform this into a memset or memset_pattern in the loop preheader, do so.
 bool LoopIdiomRecognize::processLoopStridedStore(
@@ -837,16 +872,8 @@ bool LoopIdiomRecognize::processLoopStridedStore(
 
   // Okay, everything looks good, insert the memset.
 
-  // The # stored bytes is (BECount+1)*Size.  Expand the trip count out to
-  // pointer size if it isn't already.
-  BECount = SE->getTruncateOrZeroExtend(BECount, IntPtr);
-
   const SCEV *NumBytesS =
-      SE->getAddExpr(BECount, SE->getOne(IntPtr), SCEV::FlagNUW);
-  if (StoreSize != 1) {
-    NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize),
-                               SCEV::FlagNUW);
-  }
+      getNumBytes(BECount, IntPtr, StoreSize, CurLoop, DL, SE);
 
   // TODO: ideally we should still be able to generate memset if SCEV expander
   // is taught to generate the dependencies at the latest point.
@@ -976,16 +1003,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,
 
   // Okay, everything is safe, we can transform this!
 
-  // The # stored bytes is (BECount+1)*Size.  Expand the trip count out to
-  // pointer size if it isn't already.
-  BECount = SE->getTruncateOrZeroExtend(BECount, IntPtrTy);
-
   const SCEV *NumBytesS =
-      SE->getAddExpr(BECount, SE->getOne(IntPtrTy), SCEV::FlagNUW);
-
-  if (StoreSize != 1)
-    NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtrTy, StoreSize),
-                               SCEV::FlagNUW);
+      getNumBytes(BECount, IntPtrTy, StoreSize, CurLoop, DL, SE);
 
   Value *NumBytes =
       Expander.expandCodeFor(NumBytesS, IntPtrTy, Preheader->getTerminator());
index 270de2e..ba3e8a0 100644 (file)
@@ -563,6 +563,75 @@ for.end6:                                         ; preds = %for.inc4
 ; CHECK: ret void
 }
 
+; Handle loops where the trip count is a narrow integer that needs to be
+; extended.
+define void @form_memset_narrow_size(i64* %ptr, i32 %size) {
+; CHECK-LABEL: @form_memset_narrow_size(
+entry:
+  %cmp1 = icmp sgt i32 %size, 0
+  br i1 %cmp1, label %loop.ph, label %exit
+; CHECK:       entry:
+; CHECK:         %[[C1:.*]] = icmp sgt i32 %size, 0
+; CHECK-NEXT:    br i1 %[[C1]], label %loop.ph, label %exit
+
+loop.ph:
+  br label %loop.body
+; CHECK:       loop.ph:
+; CHECK-NEXT:    %[[ZEXT_SIZE:.*]] = zext i32 %size to i64
+; CHECK-NEXT:    %[[SCALED_SIZE:.*]] = shl i64 %[[ZEXT_SIZE]], 3
+; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* %{{.*}}, i8 0, i64 %[[SCALED_SIZE]], i32 8, i1 false)
+
+loop.body:
+  %storemerge4 = phi i32 [ 0, %loop.ph ], [ %inc, %loop.body ]
+  %idxprom = sext i32 %storemerge4 to i64
+  %arrayidx = getelementptr inbounds i64, i64* %ptr, i64 %idxprom
+  store i64 0, i64* %arrayidx, align 8
+  %inc = add nsw i32 %storemerge4, 1
+  %cmp2 = icmp slt i32 %inc, %size
+  br i1 %cmp2, label %loop.body, label %loop.exit
+
+loop.exit:
+  br label %exit
+
+exit:
+  ret void
+}
+
+define void @form_memcpy_narrow_size(i64* noalias %dst, i64* noalias %src, i32 %size) {
+; CHECK-LABEL: @form_memcpy_narrow_size(
+entry:
+  %cmp1 = icmp sgt i32 %size, 0
+  br i1 %cmp1, label %loop.ph, label %exit
+; CHECK:       entry:
+; CHECK:         %[[C1:.*]] = icmp sgt i32 %size, 0
+; CHECK-NEXT:    br i1 %[[C1]], label %loop.ph, label %exit
+
+loop.ph:
+  br label %loop.body
+; CHECK:       loop.ph:
+; CHECK-NEXT:    %[[ZEXT_SIZE:.*]] = zext i32 %size to i64
+; CHECK-NEXT:    %[[SCALED_SIZE:.*]] = shl i64 %[[ZEXT_SIZE]], 3
+; CHECK-NEXT:    call void @llvm.memcpy.p0i8.p0i8.i64(i8* %{{.*}}, i8* %{{.*}}, i64 %[[SCALED_SIZE]], i32 8, i1 false)
+
+loop.body:
+  %storemerge4 = phi i32 [ 0, %loop.ph ], [ %inc, %loop.body ]
+  %idxprom1 = sext i32 %storemerge4 to i64
+  %arrayidx1 = getelementptr inbounds i64, i64* %src, i64 %idxprom1
+  %v = load i64, i64* %arrayidx1, align 8
+  %idxprom2 = sext i32 %storemerge4 to i64
+  %arrayidx2 = getelementptr inbounds i64, i64* %dst, i64 %idxprom2
+  store i64 %v, i64* %arrayidx2, align 8
+  %inc = add nsw i32 %storemerge4, 1
+  %cmp2 = icmp slt i32 %inc, %size
+  br i1 %cmp2, label %loop.body, label %loop.exit
+
+loop.exit:
+  br label %exit
+
+exit:
+  ret void
+}
+
 ; Validate that "memset_pattern" has the proper attributes.
 ; CHECK: declare void @memset_pattern16(i8* nocapture, i8* nocapture readonly, i64) [[ATTRS:#[0-9]+]]
 ; CHECK: [[ATTRS]] = { argmemonly }