OSDN Git Service

[DAGCombiner] Better support for shifting large value type by constants
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Tue, 9 Aug 2016 17:39:11 +0000 (17:39 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Tue, 9 Aug 2016 17:39:11 +0000 (17:39 +0000)
As detailed on D22726, much of the shift combining code assume constant values will fit into a uint64_t value and calls ConstantSDNode::getZExtValue where it probably shouldn't (leading to asserts). Using APInt directly avoids this problem but we encounter other assertions if we attempt to compare/operate on 2 APInt of different bitwidths.

This patch adds a helper function to ensure that 2 APInt values are zero extended as required so that they can be safely used together. I've only added an initial example use for this to the '(SHIFT (SHIFT x, c1), c2) --> (SHIFT x, (ADD c1, c2))' combines. Further cases can easily be added as required.

Differential Revision: https://reviews.llvm.org/D23007

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@278141 91177308-0d34-0410-b5e6-96231b3b80d8

lib/CodeGen/SelectionDAG/DAGCombiner.cpp
test/CodeGen/X86/shift-i128.ll

index 5bcea64..b32737f 100644 (file)
@@ -726,6 +726,15 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG,
   }
 }
 
+// APInts must be the same size for most operations, this helper
+// function zero extends the shorter of the pair so that they match.
+// We provide an Offset so that we can create bitwidths that won't overflow.
+static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
+  unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth());
+  LHS = LHS.zextOrSelf(Bits);
+  RHS = RHS.zextOrSelf(Bits);
+}
+
 // Return true if this node is a setcc, or is a select_cc
 // that selects between the target values used for true and false, making it
 // equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
@@ -4464,13 +4473,18 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
   // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
   if (N1C && N0.getOpcode() == ISD::SHL) {
     if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
-      uint64_t c1 = N0C1->getZExtValue();
-      uint64_t c2 = N1C->getZExtValue();
       SDLoc DL(N);
-      if (c1 + c2 >= OpSizeInBits)
+      APInt c1 = N0C1->getAPIntValue();
+      APInt c2 = N1C->getAPIntValue();
+      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+
+      APInt Sum = c1 + c2;
+      if (Sum.uge(OpSizeInBits))
         return DAG.getConstant(0, DL, VT);
-      return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0),
-                         DAG.getConstant(c1 + c2, DL, N1.getValueType()));
+
+      return DAG.getNode(
+          ISD::SHL, DL, VT, N0.getOperand(0),
+          DAG.getConstant(Sum.getZExtValue(), DL, N1.getValueType()));
     }
   }
 
@@ -4656,13 +4670,19 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
 
   // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
   if (N1C && N0.getOpcode() == ISD::SRA) {
-    if (ConstantSDNode *C1 = isConstOrConstSplat(N0.getOperand(1))) {
-      unsigned Sum = N1C->getZExtValue() + C1->getZExtValue();
-      if (Sum >= OpSizeInBits)
-        Sum = OpSizeInBits - 1;
+    if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
       SDLoc DL(N);
-      return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0),
-                         DAG.getConstant(Sum, DL, N1.getValueType()));
+      APInt c1 = N0C1->getAPIntValue();
+      APInt c2 = N1C->getAPIntValue();
+      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+
+      APInt Sum = c1 + c2;
+      if (Sum.uge(OpSizeInBits))
+        Sum = APInt(OpSizeInBits, OpSizeInBits - 1);
+
+      return DAG.getNode(
+          ISD::SRA, DL, VT, N0.getOperand(0),
+          DAG.getConstant(Sum.getZExtValue(), DL, N1.getValueType()));
     }
   }
 
@@ -4790,14 +4810,19 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
 
   // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
   if (N1C && N0.getOpcode() == ISD::SRL) {
-    if (ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1))) {
-      uint64_t c1 = N01C->getZExtValue();
-      uint64_t c2 = N1C->getZExtValue();
+    if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
       SDLoc DL(N);
-      if (c1 + c2 >= OpSizeInBits)
+      APInt c1 = N0C1->getAPIntValue();
+      APInt c2 = N1C->getAPIntValue();
+      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+
+      APInt Sum = c1 + c2;
+      if (Sum.uge(OpSizeInBits))
         return DAG.getConstant(0, DL, VT);
-      return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0),
-                         DAG.getConstant(c1 + c2, DL, N1.getValueType()));
+
+      return DAG.getNode(
+          ISD::SRL, DL, VT, N0.getOperand(0),
+          DAG.getConstant(Sum.getZExtValue(), DL, N1.getValueType()));
     }
   }
 
index f7dce6c..aef923f 100644 (file)
@@ -92,3 +92,27 @@ entry:
        store <2 x i128> %0, <2 x i128>* %r, align 16
        ret void
 }
+
+define void @test_lshr_v2i128_outofrange_sum(<2 x i128> %x, <2 x i128>* nocapture %r) nounwind {
+entry:
+       %0 = lshr <2 x i128> %x, <i128 -1, i128 -1>
+       %1 = lshr <2 x i128> %0, <i128  1, i128  1>
+       store <2 x i128> %1, <2 x i128>* %r, align 16
+       ret void
+}
+
+define void @test_ashr_v2i128_outofrange_sum(<2 x i128> %x, <2 x i128>* nocapture %r) nounwind {
+entry:
+       %0 = ashr <2 x i128> %x, <i128 -1, i128 -1>
+       %1 = ashr <2 x i128> %0, <i128  1, i128  1>
+       store <2 x i128> %1, <2 x i128>* %r, align 16
+       ret void
+}
+
+define void @test_shl_v2i128_outofrange_sum(<2 x i128> %x, <2 x i128>* nocapture %r) nounwind {
+entry:
+       %0 = shl <2 x i128> %x, <i128 -1, i128 -1>
+       %1 = shl <2 x i128> %0, <i128  1, i128  1>
+       store <2 x i128> %1, <2 x i128>* %r, align 16
+       ret void
+}