From dfe81adbcebb61d762d5c72771741dee7dcc1c4c Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Mon, 13 Oct 2014 21:48:30 +0000 Subject: [PATCH] InstCombine: Don't miscompile (x lshr C1) udiv C2 We have a transform that changes: (x lshr C1) udiv C2 into: x udiv (C2 << C1) However, it is unsafe to do so if C2 << C1 discards any of C2's bits. This fixes PR21255. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@219634 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/ADT/APInt.h | 3 ++- lib/Support/APInt.cpp | 20 +++++++++++++++----- lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 14 ++++++++++---- test/Transforms/InstCombine/div.ll | 16 +++++++++++++--- 4 files changed, 40 insertions(+), 13 deletions(-) diff --git a/include/llvm/ADT/APInt.h b/include/llvm/ADT/APInt.h index f815628f30c..4d19bab13f4 100644 --- a/include/llvm/ADT/APInt.h +++ b/include/llvm/ADT/APInt.h @@ -945,7 +945,8 @@ public: APInt sdiv_ov(const APInt &RHS, bool &Overflow) const; APInt smul_ov(const APInt &RHS, bool &Overflow) const; APInt umul_ov(const APInt &RHS, bool &Overflow) const; - APInt sshl_ov(unsigned Amt, bool &Overflow) const; + APInt sshl_ov(const APInt &Amt, bool &Overflow) const; + APInt ushl_ov(const APInt &Amt, bool &Overflow) const; /// \brief Array-indexing support. /// diff --git a/lib/Support/APInt.cpp b/lib/Support/APInt.cpp index 02778b2fc7c..c20eeb26948 100644 --- a/lib/Support/APInt.cpp +++ b/lib/Support/APInt.cpp @@ -2064,19 +2064,29 @@ APInt APInt::umul_ov(const APInt &RHS, bool &Overflow) const { return Res; } -APInt APInt::sshl_ov(unsigned ShAmt, bool &Overflow) const { - Overflow = ShAmt >= getBitWidth(); +APInt APInt::sshl_ov(const APInt &ShAmt, bool &Overflow) const { + Overflow = ShAmt.uge(getBitWidth()); if (Overflow) - ShAmt = getBitWidth()-1; + return APInt(BitWidth, 0); if (isNonNegative()) // Don't allow sign change. - Overflow = ShAmt >= countLeadingZeros(); + Overflow = ShAmt.uge(countLeadingZeros()); else - Overflow = ShAmt >= countLeadingOnes(); + Overflow = ShAmt.uge(countLeadingOnes()); return *this << ShAmt; } +APInt APInt::ushl_ov(const APInt &ShAmt, bool &Overflow) const { + Overflow = ShAmt.uge(getBitWidth()); + if (Overflow) + return APInt(BitWidth, 0); + + Overflow = ShAmt.ugt(countLeadingZeros()); + + return *this << ShAmt; +} + diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index dad2c2d256d..846a3640930 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -965,11 +965,17 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { return Common; // (x lshr C1) udiv C2 --> x udiv (C2 << C1) - if (Constant *C2 = dyn_cast(Op1)) { + { Value *X; - Constant *C1; - if (match(Op0, m_LShr(m_Value(X), m_Constant(C1)))) - return BinaryOperator::CreateUDiv(X, ConstantExpr::getShl(C2, C1)); + const APInt *C1, *C2; + if (match(Op0, m_LShr(m_Value(X), m_APInt(C1))) && + match(Op1, m_APInt(C2))) { + bool Overflow; + APInt C2ShlC1 = C2->ushl_ov(*C1, Overflow); + if (!Overflow) + return BinaryOperator::CreateUDiv( + X, ConstantInt::get(X->getType(), C2ShlC1)); + } } // (zext A) udiv (zext B) --> zext (A udiv B) diff --git a/test/Transforms/InstCombine/div.ll b/test/Transforms/InstCombine/div.ll index 5a884ac671d..f2a70fd0f0d 100644 --- a/test/Transforms/InstCombine/div.ll +++ b/test/Transforms/InstCombine/div.ll @@ -132,11 +132,11 @@ define i32 @test15(i32 %a, i32 %b) nounwind { } define <2 x i64> @test16(<2 x i64> %x) nounwind { - %shr = lshr <2 x i64> %x, - %div = udiv <2 x i64> %shr, + %shr = lshr <2 x i64> %x, + %div = udiv <2 x i64> %shr, ret <2 x i64> %div ; CHECK-LABEL: @test16( -; CHECK-NEXT: udiv <2 x i64> %x, +; CHECK-NEXT: udiv <2 x i64> %x, ; CHECK-NEXT: ret <2 x i64> } @@ -264,3 +264,13 @@ define i32 @test30(i32 %a) { ; CHECK-LABEL: @test30( ; CHECK-NEXT: ret i32 %a } + +define <2 x i32> @test31(<2 x i32> %x) nounwind { + %shr = lshr <2 x i32> %x, + %div = udiv <2 x i32> %shr, + ret <2 x i32> %div +; CHECK-LABEL: @test31( +; CHECK-NEXT: %[[shr:.*]] = lshr <2 x i32> %x, +; CHECK-NEXT: udiv <2 x i32> %[[shr]], +; CHECK-NEXT: ret <2 x i32> +} -- 2.11.0