diff --git a/lib/Analysis/InstructionSimplify.cpp b/lib/Analysis/InstructionSimplify.cpp
index 78ae0abf2a1..6ebae37a6a8 100644
--- a/lib/Analysis/InstructionSimplify.cpp
+++ b/lib/Analysis/InstructionSimplify.cpp
@@ -27,7 +27,6 @@
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/LoopAnalysisManager.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
-#include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/ConstantRange.h"
@@ -63,6 +62,50 @@ static Value *SimplifyOrInst(Value *, Value *, const SimplifyQuery &, unsigned);
 static Value *SimplifyXorInst(Value *, Value *, const SimplifyQuery &,
                               unsigned);
 static Value *SimplifyCastInst(unsigned, Value *, Type *, const SimplifyQuery &,
                                unsigned);
+static Value *SimplifyGEPInst(Type *, ArrayRef<Value *>, const SimplifyQuery &,
+                              unsigned);
+
+static Value *foldSelectWithBinaryOp(Value *Cond, Value *TrueVal,
+                                     Value *FalseVal) {
+  BinaryOperator::BinaryOps BinOpCode;
+  if (auto *BO = dyn_cast<BinaryOperator>(Cond))
+    BinOpCode = BO->getOpcode();
+  else
+    return nullptr;
+
+  CmpInst::Predicate ExpectedPred, Pred1, Pred2;
+  if (BinOpCode == BinaryOperator::Or) {
+    ExpectedPred = ICmpInst::ICMP_NE;
+  } else if (BinOpCode == BinaryOperator::And) {
+    ExpectedPred = ICmpInst::ICMP_EQ;
+  } else
+    return nullptr;
+
+  // %A = icmp eq %TV, %FV
+  // %B = icmp eq %X, %Y (and one of these is a select operand)
+  // %C = and %A, %B
+  // %D = select %C, %TV, %FV
+  // -->
+  // %FV
+
+  // %A = icmp ne %TV, %FV
+  // %B = icmp ne %X, %Y (and one of these is a select operand)
+  // %C = or %A, %B
+  // %D = select %C, %TV, %FV
+  // -->
+  // %TV
+  Value *X, *Y;
+  if (!match(Cond, m_c_BinOp(m_c_ICmp(Pred1, m_Specific(TrueVal),
+                                      m_Specific(FalseVal)),
+                             m_ICmp(Pred2, m_Value(X), m_Value(Y)))) ||
+      Pred1 != Pred2 || Pred1 != ExpectedPred)
+    return nullptr;
+
+  if (X == TrueVal || X == FalseVal || Y == TrueVal || Y == FalseVal)
+    return BinOpCode == BinaryOperator::Or ? TrueVal : FalseVal;
+
+  return nullptr;
+}
 
 /// For a boolean type or a vector of boolean type, return false or a vector
 /// with every element false.
@@ -91,7 +134,7 @@ static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS,
 }
 
 /// Does the given value dominate the specified phi node?
-static bool ValueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) {
+static bool valueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) {
   Instruction *I = dyn_cast<Instruction>(V);
   if (!I)
     // Arguments and constants dominate all instructions.
    return true;
@@ -100,7 +143,7 @@ static bool ValueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) {
   // If we are processing instructions (and/or basic blocks) that have not been
   // fully added to a function, the parent nodes may still be null. Simply
   // return the conservative answer in these cases.
-  if (!I->getParent() || !P->getParent() || !I->getParent()->getParent())
+  if (!I->getParent() || !P->getParent() || !I->getFunction())
    return false;
 
   // If we have a DominatorTree then do a precise test.
@@ -109,7 +152,7 @@ static bool ValueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) {
 
   // Otherwise, if the instruction is in the entry block and is not an invoke,
   // then it obviously dominates all phi nodes.
-  if (I->getParent() == &I->getParent()->getParent()->getEntryBlock() &&
+  if (I->getParent() == &I->getFunction()->getEntryBlock() &&
       !isa<InvokeInst>(I))
     return true;
 
@@ -328,7 +371,7 @@ static Value *ThreadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS,
     // Check that the simplified value has the form "X op Y" where "op" is the
     // same as the original operation.
     Instruction *Simplified = dyn_cast<Instruction>(FV ? FV : TV);
-    if (Simplified && Simplified->getOpcode() == Opcode) {
+    if (Simplified && Simplified->getOpcode() == unsigned(Opcode)) {
       // The value that didn't simplify is "UnsimplifiedLHS op UnsimplifiedRHS".
       // We already know that "op" is the same as for the simplified value. See
       // if the operands match too. If so, return the simplified value.
@@ -444,13 +487,13 @@ static Value *ThreadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS,
   if (isa<PHINode>(LHS)) {
     PI = cast<PHINode>(LHS);
     // Bail out if RHS and the phi may be mutually interdependent due to a loop.
-    if (!ValueDominatesPHI(RHS, PI, Q.DT))
+    if (!valueDominatesPHI(RHS, PI, Q.DT))
       return nullptr;
   } else {
     assert(isa<PHINode>(RHS) && "No PHI instruction operand!");
     PI = cast<PHINode>(RHS);
     // Bail out if LHS and the phi may be mutually interdependent due to a loop.
-    if (!ValueDominatesPHI(LHS, PI, Q.DT))
+    if (!valueDominatesPHI(LHS, PI, Q.DT))
       return nullptr;
   }
 
@@ -491,7 +534,7 @@ static Value *ThreadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS,
   PHINode *PI = cast<PHINode>(LHS);
 
   // Bail out if RHS and the phi may be mutually interdependent due to a loop.
-  if (!ValueDominatesPHI(RHS, PI, Q.DT))
+  if (!valueDominatesPHI(RHS, PI, Q.DT))
     return nullptr;
 
   // Evaluate the BinOp on the incoming phi values.
@@ -526,7 +569,7 @@ static Constant *foldOrCommuteConstant(Instruction::BinaryOps Opcode,
 
 /// Given operands for an Add, see if we can fold the result.
 /// If not, this returns null.
-static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
+static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
                               const SimplifyQuery &Q, unsigned MaxRecurse) {
   if (Constant *C = foldOrCommuteConstant(Instruction::Add, Op0, Op1, Q))
     return C;
@@ -539,6 +582,10 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
   if (match(Op1, m_Zero()))
     return Op0;
 
+  // If two operands are negated, return 0.
+  if (isKnownNegation(Op0, Op1))
+    return Constant::getNullValue(Op0->getType());
+
   // X + (Y - X) -> Y
   // (Y - X) + X -> Y
   // Eg: X + -X -> 0
@@ -556,10 +603,14 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
   // add nsw/nuw (xor Y, signmask), signmask --> Y
   // The no-wrapping add guarantees that the top bit will be set by the add.
   // Therefore, the xor must be clearing the already set sign bit of Y.
-  if ((isNSW || isNUW) && match(Op1, m_SignMask()) &&
+  if ((IsNSW || IsNUW) && match(Op1, m_SignMask()) &&
       match(Op0, m_Xor(m_Value(Y), m_SignMask())))
     return Y;
 
+  // add nuw %x, -1 -> -1, because %x can only be 0.
+  if (IsNUW && match(Op1, m_AllOnes()))
+    return Op1; // Which is -1.
+
   /// i1 add -> xor.
   if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1))
     if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1))
       return V;
 
@@ -582,12 +633,12 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
   return nullptr;
 }
 
-Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
+Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
                              const SimplifyQuery &Query) {
-  return ::SimplifyAddInst(Op0, Op1, isNSW, isNUW, Query, RecursionLimit);
+  return ::SimplifyAddInst(Op0, Op1, IsNSW, IsNUW, Query, RecursionLimit);
 }
 
-/// \brief Compute the base pointer and cumulative constant offsets for V.
+/// Compute the base pointer and cumulative constant offsets for V.
 ///
 /// This strips all constant offsets off of V, leaving it the base pointer, and
 /// accumulates the total constant offset applied in the returned constant. It
@@ -638,7 +689,7 @@ static Constant *stripAndComputeConstantOffsets(const DataLayout &DL, Value *&V,
   return OffsetIntPtr;
 }
 
-/// \brief Compute the constant difference between two pointer values.
+/// Compute the constant difference between two pointer values.
 /// If the difference is not a constant, returns zero.
 static Constant *computePointerDifference(const DataLayout &DL, Value *LHS,
                                           Value *RHS) {
@@ -681,14 +732,14 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
   if (match(Op0, m_Zero())) {
     // 0 - X -> 0 if the sub is NUW.
     if (isNUW)
-      return Op0;
+      return Constant::getNullValue(Op0->getType());
 
     KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
     if (Known.Zero.isMaxSignedValue()) {
       // Op1 is either 0 or the minimum signed value. If the sub is NSW, then
       // Op1 must be 0 because negating the minimum signed value is undefined.
       if (isNSW)
-        return Op0;
+        return Constant::getNullValue(Op0->getType());
 
       // 0 - X -> X if X is 0 or the minimum signed value.
       return Op1;
@@ -800,12 +851,9 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
     return C;
 
   // X * undef -> 0
-  if (match(Op1, m_Undef()))
-    return Constant::getNullValue(Op0->getType());
-
   // X * 0 -> 0
-  if (match(Op1, m_Zero()))
-    return Op1;
+  if (match(Op1, m_CombineOr(m_Undef(), m_Zero())))
+    return Constant::getNullValue(Op0->getType());
 
   // X * 1 -> X
   if (match(Op1, m_One()))
@@ -827,7 +875,7 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
                                         MaxRecurse))
     return V;
 
-  // Mul distributes over Add.  Try some generic simplifications based on this.
+  // Mul distributes over Add. Try some generic simplifications based on this.
   if (Value *V = ExpandBinOp(Instruction::Mul, Op0, Op1, Instruction::Add,
                              Q, MaxRecurse))
     return V;
@@ -869,13 +917,14 @@ static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) {
   if (match(Op1, m_Zero()))
     return UndefValue::get(Ty);
 
-  // If any element of a constant divisor vector is zero, the whole op is undef.
+  // If any element of a constant divisor vector is zero or undef, the whole op
+  // is undef.
   auto *Op1C = dyn_cast<Constant>(Op1);
   if (Op1C && Ty->isVectorTy()) {
     unsigned NumElts = Ty->getVectorNumElements();
     for (unsigned i = 0; i != NumElts; ++i) {
       Constant *Elt = Op1C->getAggregateElement(i);
-      if (Elt && Elt->isNullValue())
+      if (Elt && (Elt->isNullValue() || isa<UndefValue>(Elt)))
         return UndefValue::get(Ty);
     }
   }
@@ -888,7 +937,7 @@ static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) {
   // 0 / X -> 0
   // 0 % X -> 0
   if (match(Op0, m_Zero()))
-    return Op0;
+    return Constant::getNullValue(Op0->getType());
 
   // X / X -> 1
   // X % X -> 0
@@ -899,7 +948,10 @@ static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) {
   // X % 1 -> 0
   // If this is a boolean op (single-bit element type), we can't have
   // division-by-zero or remainder-by-zero, so assume the divisor is 1.
-  if (match(Op1, m_One()) || Ty->isIntOrIntVectorTy(1))
+  // Similarly, if we're zero-extending a boolean divisor, then assume it's a 1.
+  Value *X;
+  if (match(Op1, m_One()) || Ty->isIntOrIntVectorTy(1) ||
+      (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)))
     return IsDiv ? Op0 : Constant::getNullValue(Ty);
 
   return nullptr;
@@ -979,18 +1031,17 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
   bool IsSigned = Opcode == Instruction::SDiv;
 
   // (X * Y) / Y -> X if the multiplication does not overflow.
-  Value *X = nullptr, *Y = nullptr;
-  if (match(Op0, m_Mul(m_Value(X), m_Value(Y))) && (X == Op1 || Y == Op1)) {
-    if (Y != Op1) std::swap(X, Y); // Ensure expression is (X * Y) / Y, Y = Op1
-    OverflowingBinaryOperator *Mul = cast<OverflowingBinaryOperator>(Op0);
-    // If the Mul knows it does not overflow, then we are good to go.
+  Value *X;
+  if (match(Op0, m_c_Mul(m_Value(X), m_Specific(Op1)))) {
+    auto *Mul = cast<OverflowingBinaryOperator>(Op0);
+    // If the Mul does not overflow, then we are good to go.
     if ((IsSigned && Mul->hasNoSignedWrap()) ||
         (!IsSigned && Mul->hasNoUnsignedWrap()))
       return X;
-    // If X has the form X = A / Y then X * Y cannot overflow.
-    if (BinaryOperator *Div = dyn_cast<BinaryOperator>(X))
-      if (Div->getOpcode() == Opcode && Div->getOperand(1) == Y)
-        return X;
+    // If X has the form X = A / Y, then X * Y cannot overflow.
+    if ((IsSigned && match(X, m_SDiv(m_Value(), m_Specific(Op1)))) ||
+        (!IsSigned && match(X, m_UDiv(m_Value(), m_Specific(Op1)))))
+      return X;
   }
 
   // (X rem Y) / Y -> 0
@@ -1042,6 +1093,13 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
       match(Op0, m_URem(m_Value(), m_Specific(Op1)))))
     return Op0;
 
+  // (X << Y) % X -> 0
+  if ((Opcode == Instruction::SRem &&
+       match(Op0, m_NSWShl(m_Specific(Op1), m_Value()))) ||
+      (Opcode == Instruction::URem &&
+       match(Op0, m_NUWShl(m_Specific(Op1), m_Value()))))
+    return Constant::getNullValue(Op0->getType());
+
   // If the operation is with the result of a select instruction, check whether
   // operating on either branch of the select always yields the same value.
   if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1))
@@ -1065,6 +1123,10 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
 /// If not, this returns null.
 static Value *SimplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
                                unsigned MaxRecurse) {
+  // If two operands are negated and no signed overflow, return -1.
+  if (isKnownNegation(Op0, Op1, /*NeedNSW=*/true))
+    return Constant::getAllOnesValue(Op0->getType());
+
   return simplifyDiv(Instruction::SDiv, Op0, Op1, Q, MaxRecurse);
 }
 
@@ -1087,6 +1149,16 @@ Value *llvm::SimplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
 /// If not, this returns null.
 static Value *SimplifySRemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
                                unsigned MaxRecurse) {
+  // If the divisor is 0, the result is undefined, so assume the divisor is -1.
+  // srem Op0, (sext i1 X) --> srem Op0, -1 --> 0
+  Value *X;
+  if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))
+    return ConstantInt::getNullValue(Op0->getType());
+
+  // If the two operands are negated, return 0.
+  if (isKnownNegation(Op0, Op1))
+    return ConstantInt::getNullValue(Op0->getType());
+
   return simplifyRem(Instruction::SRem, Op0, Op1, Q, MaxRecurse);
 }
 
@@ -1141,10 +1213,14 @@ static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
 
   // 0 shift by X -> 0
   if (match(Op0, m_Zero()))
-    return Op0;
+    return Constant::getNullValue(Op0->getType());
 
   // X shift by 0 -> X
-  if (match(Op1, m_Zero()))
+  // Shift-by-sign-extended bool must be shift-by-0 because shift-by-all-ones
+  // would be poison.
+  Value *X;
+  if (match(Op1, m_Zero()) ||
+      (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)))
     return Op0;
 
   // Fold undefined shifts.
@@ -1178,7 +1254,7 @@ static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
   return nullptr;
 }
 
-/// \brief Given operands for an Shl, LShr or AShr, see if we can
+/// Given operands for an Shl, LShr or AShr, see if we can
 /// fold the result.  If not, this returns null.
 static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
                                  Value *Op1, bool isExact, const SimplifyQuery &Q,
@@ -1221,6 +1297,13 @@ static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
   Value *X;
   if (match(Op0, m_Exact(m_Shr(m_Value(X), m_Specific(Op1)))))
     return X;
+
+  // shl nuw i8 C, %x -> C iff C has sign bit set.
+  if (isNUW && match(Op0, m_Negative()))
+    return Op0;
+  // NOTE: could use computeKnownBits() / LazyValueInfo,
+  // but the cost-benefit analysis suggests it isn't worth it.
+
   return nullptr;
 }
 
@@ -1258,9 +1341,10 @@ static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact,
                                     MaxRecurse))
     return V;
 
-  // all ones >>a X -> all ones
+  // all ones >>a X -> -1
+  // Do not return Op0 because it may contain undef elements if it's a vector.
   if (match(Op0, m_AllOnes()))
-    return Op0;
+    return Constant::getAllOnesValue(Op0->getType());
 
   // (X << A) >> A -> X
   Value *X;
@@ -1296,7 +1380,7 @@ static Value *simplifyUnsignedRangeCheck(ICmpInst *ZeroICmp,
       ICmpInst::isUnsigned(UnsignedPred))
     ;
   else if (match(UnsignedICmp,
-                 m_ICmp(UnsignedPred, m_Value(Y), m_Specific(X))) &&
+                 m_ICmp(UnsignedPred, m_Specific(Y), m_Value(X))) &&
            ICmpInst::isUnsigned(UnsignedPred))
     UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred);
   else
    return nullptr;
@@ -1414,6 +1498,43 @@ static Value *simplifyAndOrOfICmpsWithConstants(ICmpInst *Cmp0, ICmpInst *Cmp1,
   return nullptr;
 }
 
+static Value *simplifyAndOrOfICmpsWithZero(ICmpInst *Cmp0, ICmpInst *Cmp1,
+                                           bool IsAnd) {
+  ICmpInst::Predicate P0 = Cmp0->getPredicate(), P1 = Cmp1->getPredicate();
+  if (!match(Cmp0->getOperand(1), m_Zero()) ||
+      !match(Cmp1->getOperand(1), m_Zero()) || P0 != P1)
+    return nullptr;
+
+  if ((IsAnd && P0 != ICmpInst::ICMP_NE) || (!IsAnd && P1 != ICmpInst::ICMP_EQ))
+    return nullptr;
+
+  // We have either "(X == 0 || Y == 0)" or "(X != 0 && Y != 0)".
+  Value *X = Cmp0->getOperand(0);
+  Value *Y = Cmp1->getOperand(0);
+
+  // If one of the compares is a masked version of a (not) null check, then
+  // that compare implies the other, so we eliminate the other. Optionally, look
+  // through a pointer-to-int cast to match a null check of a pointer type.
+
+  // (X == 0) || (([ptrtoint] X & ?) == 0) --> ([ptrtoint] X & ?) == 0
+  // (X == 0) || ((? & [ptrtoint] X) == 0) --> (? & [ptrtoint] X) == 0
+  // (X != 0) && (([ptrtoint] X & ?) != 0) --> ([ptrtoint] X & ?) != 0
+  // (X != 0) && ((? & [ptrtoint] X) != 0) --> (? & [ptrtoint] X) != 0
+  if (match(Y, m_c_And(m_Specific(X), m_Value())) ||
+      match(Y, m_c_And(m_PtrToInt(m_Specific(X)), m_Value())))
+    return Cmp1;
+
+  // (([ptrtoint] Y & ?) == 0) || (Y == 0) --> ([ptrtoint] Y & ?) == 0
+  // ((? & [ptrtoint] Y) == 0) || (Y == 0) --> (? & [ptrtoint] Y) == 0
+  // (([ptrtoint] Y & ?) != 0) && (Y != 0) --> ([ptrtoint] Y & ?) != 0
+  // ((? & [ptrtoint] Y) != 0) && (Y != 0) --> (? & [ptrtoint] Y) != 0
+  if (match(X, m_c_And(m_Specific(Y), m_Value())) ||
+      match(X, m_c_And(m_PtrToInt(m_Specific(Y)), m_Value())))
+    return Cmp0;
+
+  return nullptr;
+}
+
 static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) {
   // (icmp (add V, C0), C1) & (icmp V, C0)
   ICmpInst::Predicate Pred0, Pred1;
@@ -1474,6 +1595,9 @@ static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) {
   if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, true))
     return X;
 
+  if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, true))
+    return X;
+
   if (Value *X = simplifyAndOfICmpsWithAdd(Op0, Op1))
     return X;
   if (Value *X = simplifyAndOfICmpsWithAdd(Op1, Op0))
@@ -1542,6 +1666,9 @@ static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) {
   if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, false))
     return X;
 
+  if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, false))
+    return X;
+
   if (Value *X = simplifyOrOfICmpsWithAdd(Op0, Op1))
     return X;
   if (Value *X = simplifyOrOfICmpsWithAdd(Op1, Op0))
@@ -1550,7 +1677,44 @@ static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) {
   return nullptr;
 }
 
-static Value *simplifyAndOrOfICmps(Value *Op0, Value *Op1, bool IsAnd) {
+static Value *simplifyAndOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd) {
+  Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1);
+  Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1);
+  if (LHS0->getType() != RHS0->getType())
+    return nullptr;
+
+  FCmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
+  if ((PredL == FCmpInst::FCMP_ORD && PredR == FCmpInst::FCMP_ORD && IsAnd) ||
+      (PredL == FCmpInst::FCMP_UNO && PredR == FCmpInst::FCMP_UNO && !IsAnd)) {
+    // (fcmp ord NNAN, X) & (fcmp ord X, Y) --> fcmp ord X, Y
+    // (fcmp ord NNAN, X) & (fcmp ord Y, X) --> fcmp ord Y, X
+    // (fcmp ord X, NNAN) & (fcmp ord X, Y) --> fcmp ord X, Y
+    // (fcmp ord X, NNAN) & (fcmp ord Y, X) --> fcmp ord Y, X
+    // (fcmp uno NNAN, X) | (fcmp uno X, Y) --> fcmp uno X, Y
+    // (fcmp uno NNAN, X) | (fcmp uno Y, X) --> fcmp uno Y, X
+    // (fcmp uno X, NNAN) | (fcmp uno X, Y) --> fcmp uno X, Y
+    // (fcmp uno X, NNAN) | (fcmp uno Y, X) --> fcmp uno Y, X
+    if ((isKnownNeverNaN(LHS0) && (LHS1 == RHS0 || LHS1 == RHS1)) ||
+        (isKnownNeverNaN(LHS1) && (LHS0 == RHS0 || LHS0 == RHS1)))
+      return RHS;
+
+    // (fcmp ord X, Y) & (fcmp ord NNAN, X) --> fcmp ord X, Y
+    // (fcmp ord Y, X) & (fcmp ord NNAN, X) --> fcmp ord Y, X
+    // (fcmp ord X, Y) & (fcmp ord X, NNAN) --> fcmp ord X, Y
+    // (fcmp ord Y, X) & (fcmp ord X, NNAN) --> fcmp ord Y, X
+    // (fcmp uno X, Y) | (fcmp uno NNAN, X) --> fcmp uno X, Y
+    // (fcmp uno Y, X) | (fcmp uno NNAN, X) --> fcmp uno Y, X
+    // (fcmp uno X, Y) | (fcmp uno X, NNAN) --> fcmp uno X, Y
+    // (fcmp uno Y, X) | (fcmp uno X, NNAN) --> fcmp uno Y, X
+    if ((isKnownNeverNaN(RHS0) && (RHS1 == LHS0 || RHS1 == LHS1)) ||
+        (isKnownNeverNaN(RHS1) && (RHS0 == LHS0 || RHS0 == LHS1)))
+      return LHS;
+  }
+
+  return nullptr;
+}
+
+static Value *simplifyAndOrOfCmps(Value *Op0, Value *Op1, bool IsAnd) {
   // Look through casts of the 'and' operands to find compares.
   auto *Cast0 = dyn_cast<CastInst>(Op0);
   auto *Cast1 = dyn_cast<CastInst>(Op1);
@@ -1560,13 +1724,18 @@ static Value *simplifyAndOrOfICmps(Value *Op0, Value *Op1, bool IsAnd) {
     Op1 = Cast1->getOperand(0);
   }
 
-  auto *Cmp0 = dyn_cast<ICmpInst>(Op0);
-  auto *Cmp1 = dyn_cast<ICmpInst>(Op1);
-  if (!Cmp0 || !Cmp1)
-    return nullptr;
+  Value *V = nullptr;
+  auto *ICmp0 = dyn_cast<ICmpInst>(Op0);
+  auto *ICmp1 = dyn_cast<ICmpInst>(Op1);
+  if (ICmp0 && ICmp1)
+    V = IsAnd ? simplifyAndOfICmps(ICmp0, ICmp1) :
+                simplifyOrOfICmps(ICmp0, ICmp1);
+
+  auto *FCmp0 = dyn_cast<FCmpInst>(Op0);
+  auto *FCmp1 = dyn_cast<FCmpInst>(Op1);
+  if (FCmp0 && FCmp1)
+    V = simplifyAndOrOfFCmps(FCmp0, FCmp1, IsAnd);
 
-  Value *V =
-      IsAnd ? simplifyAndOfICmps(Cmp0, Cmp1) : simplifyOrOfICmps(Cmp0, Cmp1);
   if (!V)
     return nullptr;
   if (!Cast0)
@@ -1597,7 +1766,7 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
 
   // X & 0 = 0
   if (match(Op1, m_Zero()))
-    return Op1;
+    return Constant::getNullValue(Op0->getType());
 
   // X & -1 = X
   if (match(Op1, m_AllOnes()))
@@ -1645,7 +1814,7 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
     return Op1;
   }
 
-  if (Value *V = simplifyAndOrOfICmps(Op0, Op1, true))
+  if (Value *V = simplifyAndOrOfCmps(Op0, Op1, true))
     return V;
 
   // Try some generic simplifications for associative operations.
@@ -1692,21 +1861,16 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
     return C;
 
   // X | undef -> -1
-  if (match(Op1, m_Undef()))
+  // X | -1 = -1
+  // Do not return Op1 because it may contain undef elements if it's a vector.
+  if (match(Op1, m_Undef()) || match(Op1, m_AllOnes()))
     return Constant::getAllOnesValue(Op0->getType());
 
   // X | X = X
-  if (Op0 == Op1)
-    return Op0;
-
   // X | 0 = X
-  if (match(Op1, m_Zero()))
+  if (Op0 == Op1 || match(Op1, m_Zero()))
     return Op0;
 
-  // X | -1 = -1
-  if (match(Op1, m_AllOnes()))
-    return Op1;
-
   // A | ~A = ~A | A = -1
   if (match(Op0, m_Not(m_Specific(Op1))) ||
       match(Op1, m_Not(m_Specific(Op0))))
@@ -1766,7 +1930,7 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
       match(Op0, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B)))))
     return Op0;
 
-  if (Value *V = simplifyAndOrOfICmps(Op0, Op1, false))
+  if (Value *V = simplifyAndOrOfCmps(Op0, Op1, false))
     return V;
 
   // Try some generic simplifications for associative operations.
@@ -2010,9 +2174,12 @@ computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI,
       ConstantInt *LHSOffsetCI = dyn_cast<ConstantInt>(LHSOffset);
       ConstantInt *RHSOffsetCI = dyn_cast<ConstantInt>(RHSOffset);
       uint64_t LHSSize, RHSSize;
+      ObjectSizeOpts Opts;
+      Opts.NullIsUnknownSize =
+          NullPointerIsDefined(cast<AllocaInst>(LHS)->getFunction());
       if (LHSOffsetCI && RHSOffsetCI &&
-          getObjectSize(LHS, LHSSize, DL, TLI) &&
-          getObjectSize(RHS, RHSSize, DL, TLI)) {
+          getObjectSize(LHS, LHSSize, DL, TLI, Opts) &&
+          getObjectSize(RHS, RHSSize, DL, TLI, Opts)) {
         const APInt &LHSOffsetValue = LHSOffsetCI->getValue();
         const APInt &RHSOffsetValue = RHSOffsetCI->getValue();
         if (!LHSOffsetValue.isNegative() &&
@@ -2401,6 +2568,20 @@ static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper) {
 
 static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
                                        Value *RHS) {
+  Type *ITy = GetCompareTy(RHS); // The return type.
+
+  Value *X;
+  // Sign-bit checks can be optimized to true/false after unsigned
+  // floating-point casts:
+  // icmp slt (bitcast (uitofp X)), 0 --> false
+  // icmp sgt (bitcast (uitofp X)), -1 --> true
+  if (match(LHS, m_BitCast(m_UIToFP(m_Value(X))))) {
+    if (Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero()))
+      return ConstantInt::getFalse(ITy);
+    if (Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes()))
+      return ConstantInt::getTrue(ITy);
+  }
+
   const APInt *C;
   if (!match(RHS, m_APInt(C)))
     return nullptr;
@@ -2408,9 +2589,9 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
   // Rule out tautological comparisons (eg., ult 0 or uge 0).
   ConstantRange RHS_CR = ConstantRange::makeExactICmpRegion(Pred, *C);
   if (RHS_CR.isEmptySet())
-    return ConstantInt::getFalse(GetCompareTy(RHS));
+    return ConstantInt::getFalse(ITy);
   if (RHS_CR.isFullSet())
-    return ConstantInt::getTrue(GetCompareTy(RHS));
+    return ConstantInt::getTrue(ITy);
 
   // Find the range of possible values for binary operators.
   unsigned Width = C->getBitWidth();
@@ -2428,9 +2609,9 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
 
   if (!LHS_CR.isFullSet()) {
     if (RHS_CR.contains(LHS_CR))
-      return ConstantInt::getTrue(GetCompareTy(RHS));
+      return ConstantInt::getTrue(ITy);
     if (RHS_CR.inverse().contains(LHS_CR))
-      return ConstantInt::getFalse(GetCompareTy(RHS));
+      return ConstantInt::getFalse(ITy);
   }
 
   return nullptr;
@@ -2967,8 +3148,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
   Type *ITy = GetCompareTy(LHS); // The return type.
 
   // icmp X, X -> true/false
-  // X icmp undef -> true/false. For example, icmp ugt %X, undef -> false
-  // because X could be 0.
+  // icmp X, undef -> true/false because undef could be X.
   if (LHS == RHS || isa<UndefValue>(RHS))
     return ConstantInt::get(ITy, CmpInst::isTrueWhenEqual(Pred));
 
@@ -3268,6 +3448,12 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
       return getTrue(RetTy);
   }
 
+  // NaN is unordered; NaN is not ordered.
+  assert((FCmpInst::isOrdered(Pred) || FCmpInst::isUnordered(Pred)) &&
+         "Comparison must be either ordered or unordered");
+  if (match(RHS, m_NaN()))
+    return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred));
+
   // fcmp pred x, undef and fcmp pred undef, x
   // fold to true if unordered, false if ordered
   if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) {
@@ -3284,27 +3470,12 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
       return getFalse(RetTy);
   }
 
-  // Handle fcmp with constant RHS
-  const ConstantFP *CFP = nullptr;
-  if (const auto *RHSC = dyn_cast<Constant>(RHS)) {
-    if (RHS->getType()->isVectorTy())
-      CFP = dyn_cast_or_null<ConstantFP>(RHSC->getSplatValue());
-    else
-      CFP = dyn_cast<ConstantFP>(RHSC);
-  }
-  if (CFP) {
-    // If the constant is a nan, see if we can fold the comparison based on it.
-    if (CFP->getValueAPF().isNaN()) {
-      if (FCmpInst::isOrdered(Pred)) // True "if ordered and foo"
-        return getFalse(RetTy);
-      assert(FCmpInst::isUnordered(Pred) &&
-             "Comparison must be either ordered or unordered!");
-      // True if unordered.
-      return getTrue(RetTy);
-    }
+  // Handle fcmp with constant RHS.
+  const APFloat *C;
+  if (match(RHS, m_APFloat(C))) {
     // Check whether the constant is an infinity.
-    if (CFP->getValueAPF().isInfinity()) {
-      if (CFP->getValueAPF().isNegative()) {
+    if (C->isInfinity()) {
+      if (C->isNegative()) {
         switch (Pred) {
         case FCmpInst::FCMP_OLT:
           // No value is ordered and less than negative infinity.
@@ -3328,7 +3499,7 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
         }
       }
     }
-    if (CFP->getValueAPF().isZero()) {
+    if (C->isZero()) {
       switch (Pred) {
       case FCmpInst::FCMP_UGE:
         if (CannotBeOrderedLessThanZero(LHS, Q.TLI))
@@ -3342,6 +3513,28 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
       default:
         break;
       }
+    } else if (C->isNegative()) {
+      assert(!C->isNaN() && "Unexpected NaN constant!");
+      // TODO: We can catch more cases by using a range check rather than
+      //       relying on CannotBeOrderedLessThanZero.
+      switch (Pred) {
+      case FCmpInst::FCMP_UGE:
+      case FCmpInst::FCMP_UGT:
+      case FCmpInst::FCMP_UNE:
+        // (X >= 0) implies (X > C) when (C < 0)
+        if (CannotBeOrderedLessThanZero(LHS, Q.TLI))
+          return getTrue(RetTy);
+        break;
+      case FCmpInst::FCMP_OEQ:
+      case FCmpInst::FCMP_OLE:
+      case FCmpInst::FCMP_OLT:
+        // (X >= 0) implies !(X < C) when (C < 0)
+        if (CannotBeOrderedLessThanZero(LHS, Q.TLI))
+          return getFalse(RetTy);
+        break;
+      default:
+        break;
+      }
     }
   }
 
@@ -3418,6 +3611,17 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
     }
   }
 
+  // Same for GEPs.
+  if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) {
+    if (MaxRecurse) {
+      SmallVector<Value *, 8> NewOps(GEP->getNumOperands());
+      transform(GEP->operands(), NewOps.begin(),
+                [&](Value *V) { return V == Op ? RepOp : V; });
+      return SimplifyGEPInst(GEP->getSourceElementType(), NewOps, Q,
+                             MaxRecurse - 1);
+    }
+  }
+
   // TODO: We could hand off more cases to instsimplify here.
 
   // If all operands are constant after substituting Op for RepOp then we can
@@ -3524,24 +3728,6 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
                                  TrueVal, FalseVal))
     return V;
 
-  if (CondVal->hasOneUse()) {
-    const APInt *C;
-    if (match(CmpRHS, m_APInt(C))) {
-      // X < MIN ? T : F --> F
-      if (Pred == ICmpInst::ICMP_SLT && C->isMinSignedValue())
-        return FalseVal;
-      // X < MIN ? T : F --> F
-      if (Pred == ICmpInst::ICMP_ULT && C->isMinValue())
-        return FalseVal;
-      // X > MAX ? T : F --> F
-      if (Pred == ICmpInst::ICMP_SGT && C->isMaxSignedValue())
-        return FalseVal;
-      // X > MAX ? T : F --> F
-      if (Pred == ICmpInst::ICMP_UGT && C->isMaxValue())
-        return FalseVal;
-    }
-  }
-
   // If we have an equality comparison, then we know the value in one of the
   // arms of the select. See if substituting this value into the arm and
   // simplifying the result yields the same value as the other arm.
@@ -3574,37 +3760,41 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
 
 /// Given operands for a SelectInst, see if we can fold the result.
 /// If not, this returns null.
-static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal,
-                                 Value *FalseVal, const SimplifyQuery &Q,
-                                 unsigned MaxRecurse) {
-  // select true, X, Y  -> X
-  // select false, X, Y -> Y
-  if (Constant *CB = dyn_cast<Constant>(CondVal)) {
-    if (Constant *CT = dyn_cast<Constant>(TrueVal))
-      if (Constant *CF = dyn_cast<Constant>(FalseVal))
-        return ConstantFoldSelectInstruction(CB, CT, CF);
-    if (CB->isAllOnesValue())
+static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
+                                 const SimplifyQuery &Q, unsigned MaxRecurse) {
+  if (auto *CondC = dyn_cast<Constant>(Cond)) {
+    if (auto *TrueC = dyn_cast<Constant>(TrueVal))
+      if (auto *FalseC = dyn_cast<Constant>(FalseVal))
+        return ConstantFoldSelectInstruction(CondC, TrueC, FalseC);
+
+    // select undef, X, Y -> X or Y
+    if (isa<UndefValue>(CondC))
+      return isa<UndefValue>(FalseVal) ? FalseVal : TrueVal;
+
+    // TODO: Vector constants with undef elements don't simplify.
+
+    // select true, X, Y -> X
+    if (CondC->isAllOnesValue())
       return TrueVal;
-    if (CB->isNullValue())
+    // select false, X, Y -> Y
+    if (CondC->isNullValue())
       return FalseVal;
   }
 
-  // select C, X, X -> X
+  // select ?, X, X -> X
   if (TrueVal == FalseVal)
     return TrueVal;
 
-  if (isa<UndefValue>(CondVal)) {  // select undef, X, Y -> X or Y
-    if (isa<UndefValue>(FalseVal))
-      return FalseVal;
-    return TrueVal;
-  }
-  if (isa<UndefValue>(TrueVal))   // select C, undef, X -> X
+  if (isa<UndefValue>(TrueVal))   // select ?, undef, X -> X
     return FalseVal;
-  if (isa<UndefValue>(FalseVal))  // select C, X, undef -> X
+  if (isa<UndefValue>(FalseVal))  // select ?, X, undef -> X
     return TrueVal;
 
   if (Value *V =
-          simplifySelectWithICmpCond(CondVal, TrueVal, FalseVal, Q, MaxRecurse))
+          simplifySelectWithICmpCond(Cond, TrueVal, FalseVal, Q, MaxRecurse))
+    return V;
+
+  if (Value *V = foldSelectWithBinaryOp(Cond, TrueVal, FalseVal))
     return V;
 
   return nullptr;
@@ -3640,7 +3830,7 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops,
 
   if (Ops.size() == 2) {
     // getelementptr P, 0 -> P.
-    if (match(Ops[1], m_Zero()))
+    if (match(Ops[1], m_Zero()) && Ops[0]->getType() == GEPTy)
      return Ops[0];
 
     Type *Ty = SrcTy;
@@ -3649,13 +3839,13 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops,
       uint64_t C;
       uint64_t TyAllocSize = Q.DL.getTypeAllocSize(Ty);
       // getelementptr P, N -> P if P points to a type of zero size.
-      if (TyAllocSize == 0)
+      if (TyAllocSize == 0 && Ops[0]->getType() == GEPTy)
        return Ops[0];
 
       // The following transforms are only safe if the ptrtoint cast
       // doesn't truncate the pointers.
       if (Ops[1]->getType()->getScalarSizeInBits() ==
-          Q.DL.getPointerSizeInBits(AS)) {
+          Q.DL.getIndexSizeInBits(AS)) {
         auto PtrToIntOrZero = [GEPTy](Value *P) -> Value * {
           if (match(P, m_Zero()))
             return Constant::getNullValue(GEPTy);
@@ -3695,10 +3885,10 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops,
   if (Q.DL.getTypeAllocSize(LastType) == 1 &&
       all_of(Ops.slice(1).drop_back(1),
              [](Value *Idx) { return match(Idx, m_Zero()); })) {
-    unsigned PtrWidth =
-        Q.DL.getPointerSizeInBits(Ops[0]->getType()->getPointerAddressSpace());
-    if (Q.DL.getTypeSizeInBits(Ops.back()->getType()) == PtrWidth) {
-      APInt BasePtrOffset(PtrWidth, 0);
+    unsigned IdxWidth =
+        Q.DL.getIndexSizeInBits(Ops[0]->getType()->getPointerAddressSpace());
+    if (Q.DL.getTypeSizeInBits(Ops.back()->getType()) == IdxWidth) {
+      APInt BasePtrOffset(IdxWidth, 0);
       Value *StrippedBasePtr =
           Ops[0]->stripAndAccumulateInBoundsConstantOffsets(Q.DL,
                                                             BasePtrOffset);
@@ -3769,6 +3959,29 @@ Value *llvm::SimplifyInsertValueInst(Value *Agg, Value *Val,
   return ::SimplifyInsertValueInst(Agg, Val, Idxs, Q, RecursionLimit);
 }
 
+Value *llvm::SimplifyInsertElementInst(Value *Vec, Value *Val, Value *Idx,
+                                       const SimplifyQuery &Q) {
+  // Try to constant fold.
+  auto *VecC = dyn_cast<Constant>(Vec);
+  auto *ValC = dyn_cast<Constant>(Val);
+  auto *IdxC = dyn_cast<Constant>(Idx);
+  if (VecC && ValC && IdxC)
+    return ConstantFoldInsertElementInstruction(VecC, ValC, IdxC);
+
+  // Fold into undef if index is out of bounds.
+  if (auto *CI = dyn_cast<ConstantInt>(Idx)) {
+    uint64_t NumElements = cast<VectorType>(Vec->getType())->getNumElements();
+    if (CI->uge(NumElements))
+      return UndefValue::get(Vec->getType());
+  }
+
+  // If index is undef, it might be out of bounds (see above case)
+  if (isa<UndefValue>(Idx))
+    return UndefValue::get(Vec->getType());
+
+  return nullptr;
+}
+
 /// Given operands for an ExtractValueInst, see if we can fold the result.
 /// If not, this returns null.
 static Value *SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs,
@@ -3817,9 +4030,18 @@ static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const SimplifyQ
 
   // If extracting a specified index from the vector, see if we can recursively
   // find a previously computed scalar that was inserted into the vector.
-  if (auto *IdxC = dyn_cast<ConstantInt>(Idx))
+  if (auto *IdxC = dyn_cast<ConstantInt>(Idx)) {
+    if (IdxC->getValue().uge(Vec->getType()->getVectorNumElements()))
+      // definitely out of bounds, thus undefined result
+      return UndefValue::get(Vec->getType()->getVectorElementType());
     if (Value *Elt = findScalarElement(Vec, IdxC->getZExtValue()))
       return Elt;
+  }
+
+  // An undef extract index can be arbitrarily chosen to be an out-of-range
+  // index value, which would result in the instruction being undef.
+  if (isa<UndefValue>(Idx))
+    return UndefValue::get(Vec->getType()->getVectorElementType());
 
   return nullptr;
 }
@@ -3857,7 +4079,7 @@ static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) {
   // instruction, we cannot return X as the result of the PHI node unless it
   // dominates the PHI block.
   if (HasUndefInput)
-    return ValueDominatesPHI(CommonValue, PN, Q.DT) ? CommonValue : nullptr;
+    return valueDominatesPHI(CommonValue, PN, Q.DT) ? CommonValue : nullptr;
 
   return CommonValue;
 }
@@ -4034,6 +4256,28 @@ Value *llvm::SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask,
   return ::SimplifyShuffleVectorInst(Op0, Op1, Mask, RetTy, Q, RecursionLimit);
 }
 
+static Constant *propagateNaN(Constant *In) {
+  // If the input is a vector with undef elements, just return a default NaN.
+  if (!In->isNaN())
+    return ConstantFP::getNaN(In->getType());
+
+  // Propagate the existing NaN constant when possible.
+  // TODO: Should we quiet a signaling NaN?
+  return In;
+}
+
+static Constant *simplifyFPBinop(Value *Op0, Value *Op1) {
+  if (isa<UndefValue>(Op0) || isa<UndefValue>(Op1))
+    return ConstantFP::getNaN(Op0->getType());
+
+  if (match(Op0, m_NaN()))
+    return propagateNaN(cast<Constant>(Op0));
+  if (match(Op1, m_NaN()))
+    return propagateNaN(cast<Constant>(Op1));
+
+  return nullptr;
+}
+
 /// Given operands for an FAdd, see if we can fold the result.  If not, this
 /// returns null.
 static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
@@ -4041,29 +4285,28 @@ static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
   if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q))
     return C;
 
+  if (Constant *C = simplifyFPBinop(Op0, Op1))
+    return C;
+
   // fadd X, -0 ==> X
-  if (match(Op1, m_NegZero()))
+  if (match(Op1, m_NegZeroFP()))
     return Op0;
 
   // fadd X, 0 ==> X, when we know X is not -0
-  if (match(Op1, m_Zero()) &&
+  if (match(Op1, m_PosZeroFP()) &&
       (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI)))
     return Op0;
 
-  // fadd [nnan ninf] X, (fsub [nnan ninf] 0, X) ==> 0
-  //   where nnan and ninf have to occur at least once somewhere in this
-  //   expression
-  Value *SubOp = nullptr;
-  if (match(Op1, m_FSub(m_AnyZero(), m_Specific(Op0))))
-    SubOp = Op1;
-  else if (match(Op0, m_FSub(m_AnyZero(), m_Specific(Op1))))
-    SubOp = Op0;
-  if (SubOp) {
-    Instruction *FSub = cast<Instruction>(SubOp);
-    if ((FMF.noNaNs() || FSub->hasNoNaNs()) &&
-        (FMF.noInfs() || FSub->hasNoInfs()))
-      return Constant::getNullValue(Op0->getType());
-  }
+  // With nnan: (+/-0.0 - X) + X --> 0.0 (and commuted variant)
+  // We don't have to explicitly exclude infinities (ninf): INF + -INF == NaN.
+  // Negative zeros are allowed because we always end up with positive zero:
+  //   X = -0.0: (-0.0 - (-0.0)) + (-0.0) == ( 0.0) + (-0.0) == 0.0
+  //   X = -0.0: ( 0.0 - (-0.0)) + (-0.0) == ( 0.0) + (-0.0) == 0.0
+  //   X =  0.0: (-0.0 - ( 0.0)) + ( 0.0) == (-0.0) + ( 0.0) == 0.0
+  //   X =  0.0: ( 0.0 - ( 0.0)) + ( 0.0) == ( 0.0) + ( 0.0) == 0.0
+  if (FMF.noNaNs() && (match(Op0, m_FSub(m_AnyZeroFP(), m_Specific(Op1))) ||
+                       match(Op1, m_FSub(m_AnyZeroFP(), m_Specific(Op0)))))
+    return ConstantFP::getNullValue(Op0->getType());
 
   return nullptr;
 }
@@ -4075,23 +4318,27 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,
   if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q))
     return C;
 
-  // fsub X, 0 ==> X
-  if (match(Op1, m_Zero()))
+  if (Constant *C = simplifyFPBinop(Op0, Op1))
+    return C;
+
+  // fsub X, +0 ==> X
+  if (match(Op1, m_PosZeroFP()))
     return Op0;
 
   // fsub X, -0 ==> X, when we know X is not -0
-  if (match(Op1, m_NegZero()) &&
+  if (match(Op1, m_NegZeroFP()) &&
       (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI)))
     return Op0;
 
   // fsub -0.0, (fsub -0.0, X) ==> X
   Value *X;
-  if (match(Op0, m_NegZero()) && match(Op1, m_FSub(m_NegZero(), m_Value(X))))
+  if (match(Op0, m_NegZeroFP()) &&
+      match(Op1, m_FSub(m_NegZeroFP(), m_Value(X))))
    return X;
 
   // fsub 0.0, (fsub 0.0, X) ==> X if signed zeros are ignored.
-  if (FMF.noSignedZeros() && match(Op0, m_AnyZero()) &&
-      match(Op1, m_FSub(m_AnyZero(), m_Value(X))))
+  if (FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()) &&
+      match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X))))
     return X;
 
   // fsub nnan x, x ==> 0.0
@@ -4107,13 +4354,25 @@ static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
   if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q))
     return C;
 
+  if (Constant *C = simplifyFPBinop(Op0, Op1))
+    return C;
+
   // fmul X, 1.0 ==> X
   if (match(Op1, m_FPOne()))
     return Op0;
 
   // fmul nnan nsz X, 0 ==> 0
-  if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op1, m_AnyZero()))
-    return Op1;
+  if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op1, m_AnyZeroFP()))
+    return ConstantFP::getNullValue(Op0->getType());
+
+  // sqrt(X) * sqrt(X) --> X, if we can:
+  // 1. Remove the intermediate rounding (reassociate).
+  // 2. Ignore non-zero negative numbers because sqrt would produce NAN.
+  // 3. Ignore -0.0 because sqrt(-0.0) == -0.0, but -0.0 * -0.0 == 0.0.
+  Value *X;
+  if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::sqrt>(m_Value(X))) &&
+      FMF.allowReassoc() && FMF.noNaNs() && FMF.noSignedZeros())
+    return X;
 
   return nullptr;
 }
@@ -4139,13 +4398,8 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
   if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q))
     return C;
 
-  // undef / X -> undef    (the undef could be a snan).
-  if (match(Op0, m_Undef()))
-    return Op0;
-
-  // X / undef -> undef
-  if (match(Op1, m_Undef()))
-    return Op1;
+  if (Constant *C = simplifyFPBinop(Op0, Op1))
+    return C;
 
   // X / 1.0 -> X
   if (match(Op1, m_FPOne()))
@@ -4154,14 +4408,20 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
   // 0 / X -> 0
   // Requires that NaNs are off (X could be zero) and signed zeroes are
   // ignored (X could be positive or negative, so the output sign is unknown).
-  if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op0, m_AnyZero()))
-    return Op0;
+  if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()))
+    return ConstantFP::getNullValue(Op0->getType());
 
   if (FMF.noNaNs()) {
     // X / X -> 1.0 is legal when NaNs are ignored.
+    // We can ignore infinities because INF/INF is NaN.
     if (Op0 == Op1)
       return ConstantFP::get(Op0->getType(), 1.0);
 
+    // (X * Y) / Y --> X if we can reassociate to the above form.
+    Value *X;
+    if (FMF.allowReassoc() && match(Op0, m_c_FMul(m_Value(X), m_Specific(Op1))))
+      return X;
+
     // -X /  X -> -1.0 and
     //  X / -X -> -1.0 are legal when NaNs are ignored.
     // We can ignore signed zeros because +-0.0/+-0.0 is NaN and ignored.
@@ -4185,19 +4445,20 @@ static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF,
   if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q))
     return C;
 
-  // undef % X -> undef    (the undef could be a snan).
-  if (match(Op0, m_Undef()))
-    return Op0;
-
-  // X % undef -> undef
-  if (match(Op1, m_Undef()))
-    return Op1;
+  if (Constant *C = simplifyFPBinop(Op0, Op1))
+    return C;
 
-  // 0 % X -> 0
-  // Requires that NaNs are off (X could be zero) and signed zeroes are
-  // ignored (X could be positive or negative, so the output sign is unknown).
-  if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op0, m_AnyZero()))
-    return Op0;
+  // Unlike fdiv, the result of frem always matches the sign of the dividend.
+  // The constant match may include undef elements in a vector, so return a full
+  // zero constant as the result.
+  if (FMF.noNaNs()) {
+    // +0 % X -> 0
+    if (match(Op0, m_PosZeroFP()))
+      return ConstantFP::getNullValue(Op0->getType());
+    // -0 % X -> -0
+    if (match(Op0, m_NegZeroFP()))
+      return ConstantFP::getNegativeZero(Op0->getType());
+  }
 
   return nullptr;
 }
@@ -4388,88 +4649,131 @@ static bool maskIsAllZeroOrUndef(Value *Mask) {
   return true;
 }
 
-template <typename IterTy>
-static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd,
-                                const SimplifyQuery &Q, unsigned MaxRecurse) {
+static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0,
+                                     const SimplifyQuery &Q) {
+  // Idempotent functions return the same result when called repeatedly.
   Intrinsic::ID IID = F->getIntrinsicID();
-  unsigned NumOperands = std::distance(ArgBegin, ArgEnd);
+  if (IsIdempotent(IID))
+    if (auto *II = dyn_cast<IntrinsicInst>(Op0))
+      if (II->getIntrinsicID() == IID)
+        return II;
 
-  // Unary Ops
-  if (NumOperands == 1) {
-    // Perform idempotent optimizations
-    if (IsIdempotent(IID)) {
-      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(*ArgBegin)) {
-        if (II->getIntrinsicID() == IID)
-          return II;
-      }
-    }
-
-    switch (IID) {
-    case Intrinsic::fabs: {
-      if (SignBitMustBeZero(*ArgBegin, Q.TLI))
-        return *ArgBegin;
-      return nullptr;
-    }
-    default:
-      return nullptr;
-    }
+  Value *X;
+  switch (IID) {
+  case Intrinsic::fabs:
+    if (SignBitMustBeZero(Op0, Q.TLI)) return Op0;
+    break;
+  case Intrinsic::bswap:
+    // bswap(bswap(x)) -> x
+    if (match(Op0, m_BSwap(m_Value(X)))) return X;
+    break;
+  case Intrinsic::bitreverse:
+    // bitreverse(bitreverse(x)) -> x
+    if (match(Op0, m_BitReverse(m_Value(X)))) return X;
+    break;
+  case Intrinsic::exp:
+    // exp(log(x)) -> x
+    if (Q.CxtI->hasAllowReassoc() &&
+        match(Op0, m_Intrinsic<Intrinsic::log>(m_Value(X)))) return X;
+    break;
+  case Intrinsic::exp2:
+    // exp2(log2(x)) -> x
+    if (Q.CxtI->hasAllowReassoc() &&
+        match(Op0, m_Intrinsic<Intrinsic::log2>(m_Value(X)))) return X;
+    break;
+  case Intrinsic::log:
+    // log(exp(x)) -> x
+    if (Q.CxtI->hasAllowReassoc() &&
+        match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X)))) return X;
+    break;
+  case Intrinsic::log2:
+    // log2(exp2(x)) -> x
+    if (Q.CxtI->hasAllowReassoc() &&
+        match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X)))) return X;
+    break;
+  default:
+    break;
   }
 
-  // Binary Ops
-  if (NumOperands == 2) {
-    Value *LHS = *ArgBegin;
-    Value *RHS = *(ArgBegin + 1);
-    Type *ReturnType = F->getReturnType();
+  return nullptr;
+}
 
-    switch (IID) {
-    case Intrinsic::usub_with_overflow:
-    case Intrinsic::ssub_with_overflow: {
-      // X - X -> { 0, false }
-      if (LHS == RHS)
-        return Constant::getNullValue(ReturnType);
+static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
+                                      const SimplifyQuery &Q) {
+  Intrinsic::ID IID = F->getIntrinsicID();
+  Type *ReturnType = F->getReturnType();
+  switch (IID) {
+  case Intrinsic::usub_with_overflow:
+  case Intrinsic::ssub_with_overflow:
+    // X - X -> { 0, false }
+    if (Op0 == Op1)
+      return Constant::getNullValue(ReturnType);
+    // X - undef -> undef
+    // undef - X -> undef
+    if (isa<UndefValue>(Op0) || isa<UndefValue>(Op1))
+      return UndefValue::get(ReturnType);
+    break;
+  case Intrinsic::uadd_with_overflow:
+  case Intrinsic::sadd_with_overflow:
+    // X + undef -> undef
+    if (isa<UndefValue>(Op0) || isa<UndefValue>(Op1))
+      return UndefValue::get(ReturnType);
+    break;
+  case Intrinsic::umul_with_overflow:
+  case Intrinsic::smul_with_overflow:
+    // 0 * X -> { 0, false }
+    // X * 0 -> { 0, false }
+    if (match(Op0, m_Zero()) || match(Op1, m_Zero()))
+      return Constant::getNullValue(ReturnType);
+    // undef * X -> { 0, false }
+    // X * undef -> { 0, false }
+    if (match(Op0, m_Undef()) || match(Op1, m_Undef()))
+      return Constant::getNullValue(ReturnType);
+    break;
+  case Intrinsic::load_relative:
+    if (auto *C0 = dyn_cast<Constant>(Op0))
+      if (auto *C1 = dyn_cast<Constant>(Op1))
+        return SimplifyRelativeLoad(C0, C1, Q.DL);
+    break;
+  case Intrinsic::powi:
+    if (auto *Power = dyn_cast<ConstantInt>(Op1)) {
+      // powi(x, 0) -> 1.0
+      if (Power->isZero())
+        return ConstantFP::get(Op0->getType(), 1.0);
+      // powi(x, 1) -> x
+      if (Power->isOne())
+        return Op0;
    }
+    break;
+  case Intrinsic::maxnum:
+  case Intrinsic::minnum:
+    // If one argument is NaN, return the other argument.
+    if (match(Op0, m_NaN())) return Op1;
+    if (match(Op1, m_NaN())) return Op0;
+    break;
+  default:
+    break;
+  }
 
-      // X - undef -> undef
-      // undef - X -> undef
-      if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS))
-        return UndefValue::get(ReturnType);
+  return nullptr;
+}
 
-      return nullptr;
-    }
-    case Intrinsic::uadd_with_overflow:
-    case Intrinsic::sadd_with_overflow: {
-      // X + undef -> undef
-      if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS))
-        return UndefValue::get(ReturnType);
+template <typename IterTy>
+static Value *simplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd,
+                                const SimplifyQuery &Q) {
+  // Intrinsics with no operands have some kind of side effect. Don't simplify.
+  unsigned NumOperands = std::distance(ArgBegin, ArgEnd);
+  if (NumOperands == 0)
+    return nullptr;
 
-      return nullptr;
-    }
-    case Intrinsic::umul_with_overflow:
-    case Intrinsic::smul_with_overflow: {
-      // 0 * X -> { 0, false }
-      // X * 0 -> { 0, false }
-      if (match(LHS, m_Zero()) || match(RHS, m_Zero()))
-        return Constant::getNullValue(ReturnType);
-
-      // undef * X -> { 0, false }
-      // X * undef -> { 0, false }
-      if (match(LHS, m_Undef()) || match(RHS, m_Undef()))
-        return Constant::getNullValue(ReturnType);
+  Intrinsic::ID IID = F->getIntrinsicID();
+  if (NumOperands == 1)
+    return simplifyUnaryIntrinsic(F, ArgBegin[0], Q);
 
-      return nullptr;
-    }
-    case Intrinsic::load_relative: {
-      Constant *C0 = dyn_cast<Constant>(LHS);
-      Constant *C1 = dyn_cast<Constant>(RHS);
-      if (C0 && C1)
-        return SimplifyRelativeLoad(C0, C1, Q.DL);
-      return nullptr;
-    }
-    default:
-      return nullptr;
-    }
-  }
+  if (NumOperands == 2)
+    return simplifyBinaryIntrinsic(F, ArgBegin[0], ArgBegin[1], Q);
 
-  // Simplify calls to llvm.masked.load.*
+  // Handle intrinsics with 3 or more arguments.
   switch (IID) {
   case Intrinsic::masked_load: {
     Value *MaskArg = ArgBegin[2];
@@ -4479,6 +4783,19 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd,
       return PassthruArg;
     return nullptr;
  }
+  case Intrinsic::fshl:
+  case Intrinsic::fshr: {
+    Value *ShAmtArg = ArgBegin[2];
+    const APInt *ShAmtC;
+    if (match(ShAmtArg, m_APInt(ShAmtC))) {
+      // If there's effectively no shift, return the 1st arg or 2nd arg.
+      // TODO: For vectors, we could check each element of a non-splat constant.
+      APInt BitWidth = APInt(ShAmtC->getBitWidth(), ShAmtC->getBitWidth());
+      if (ShAmtC->urem(BitWidth).isNullValue())
+        return ArgBegin[IID == Intrinsic::fshl ? 0 : 1];
+    }
+    return nullptr;
+  }
   default:
     return nullptr;
   }
@@ -4503,7 +4820,7 @@ static Value *SimplifyCall(ImmutableCallSite CS, Value *V, IterTy ArgBegin,
     return nullptr;
 
   if (F->isIntrinsic())
-    if (Value *Ret = SimplifyIntrinsic(F, ArgBegin, ArgEnd, Q, MaxRecurse))
+    if (Value *Ret = simplifyIntrinsic(F, ArgBegin, ArgEnd, Q))
      return Ret;
 
  if (!canConstantFoldCallTo(CS, F))
@@ -4532,6 +4849,12 @@ Value *llvm::SimplifyCall(ImmutableCallSite CS, Value *V,
   return ::SimplifyCall(CS, V, Args.begin(), Args.end(), Q, RecursionLimit);
 }
 
+Value *llvm::SimplifyCall(ImmutableCallSite ICS, const SimplifyQuery &Q) {
+  CallSite CS(const_cast<Instruction *>(ICS.getInstruction()));
+  return ::SimplifyCall(CS, CS.getCalledValue(), CS.arg_begin(), CS.arg_end(),
+                        Q, RecursionLimit);
+}
+
 /// See if we can compute a simplified version of this instruction.
 /// If not, this returns null.
@@ -4637,6 +4960,12 @@ Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ,
                                       IV->getIndices(), Q);
     break;
   }
+  case Instruction::InsertElement: {
+    auto *IE = cast<InsertElementInst>(I);
+    Result = SimplifyInsertElementInst(IE->getOperand(0), IE->getOperand(1),
+                                       IE->getOperand(2), Q);
+    break;
+  }
   case Instruction::ExtractValue: {
     auto *EVI = cast<ExtractValueInst>(I);
     Result = SimplifyExtractValueInst(EVI->getAggregateOperand(),
@@ -4660,8 +4989,7 @@ Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ,
     break;
   case Instruction::Call: {
     CallSite CS(cast<CallInst>(I));
-    Result = SimplifyCall(CS, CS.getCalledValue(), CS.arg_begin(), CS.arg_end(),
-                          Q);
+    Result = SimplifyCall(CS, Q);
     break;
   }
 #define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc:
@@ -4690,7 +5018,7 @@ Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ,
   return Result == I ? UndefValue::get(I->getType()) : Result;
 }
 
-/// \brief Implementation of recursive simplification through an instruction's
+/// Implementation of recursive simplification through an instruction's
 /// uses.
 ///
 /// This is the common implementation of the recursive simplification routines.