From abb2e449360224c78960fed7298b9ccb4d26debe Mon Sep 17 00:00:00 2001 From: Joseph Tremoulet Date: Thu, 13 Jun 2019 15:24:11 +0000 Subject: [PATCH] [EarlyCSE] Ensure equal keys have the same hash value Summary: The logic in EarlyCSE that looks through 'not' operations in the predicate recognizes e.g. that `select (not (cmp sgt X, Y)), X, Y` is equivalent to `select (cmp sgt X, Y), Y, X`. Without this change, however, only the latter is recognized as a form of `smin X, Y`, so the two expressions receive different hash codes. This leads to missed optimization opportunities when the quadratic probing for the two hashes doesn't happen to collide, and assertion failures when probing doesn't collide on insertion but does collide on a subsequent table grow operation. This change inverts the order of some of the pattern matching, checking first for the optional `not` and then for the min/max/abs patterns, so that e.g. both expressions above are recognized as a form of `smin X, Y`. It also adds an assertion to isEqual verifying that it implies equal hash codes; this fires when there's a collision during insertion, not just grow, and so will make it easier to notice if these functions fall out of sync again. A new flag --earlycse-debug-hash is added which can be used when changing the hash function; it forces hash collisions so that any pair of values inserted which compare as equal but hash differently will be caught by the isEqual assertion. Reviewers: spatel, nikic Reviewed By: spatel, nikic Subscribers: lebedev.ri, arsenm, craig.topper, efriedma, hiraditya, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D62644 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@363274 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/llvm/Analysis/ValueTracking.h | 6 ++ lib/Analysis/ValueTracking.cpp | 12 ++- lib/Transforms/Scalar/EarlyCSE.cpp | 174 +++++++++++++++++++++------------- test/Transforms/EarlyCSE/commute.ll | 53 ++++++++--- 4 files changed, 166 insertions(+), 79 deletions(-) diff --git a/include/llvm/Analysis/ValueTracking.h b/include/llvm/Analysis/ValueTracking.h index d14da32ae8b..f14c2a4f322 100644 --- a/include/llvm/Analysis/ValueTracking.h +++ b/include/llvm/Analysis/ValueTracking.h @@ -606,6 +606,12 @@ class Value; return Result; } + /// Determine the pattern that a select with the given compare as its + /// predicate and given values as its true/false operands would match. + SelectPatternResult matchDecomposedSelectPattern( + CmpInst *CmpI, Value *TrueVal, Value *FalseVal, Value *&LHS, Value *&RHS, + Instruction::CastOps *CastOp = nullptr, unsigned Depth = 0); + /// Return the canonical comparison predicate for the specified /// minimum/maximum flavor. CmpInst::Predicate getMinMaxPred(SelectPatternFlavor SPF, diff --git a/lib/Analysis/ValueTracking.cpp b/lib/Analysis/ValueTracking.cpp index e1326548f85..be6efba3db0 100644 --- a/lib/Analysis/ValueTracking.cpp +++ b/lib/Analysis/ValueTracking.cpp @@ -5073,11 +5073,19 @@ SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS, CmpInst *CmpI = dyn_cast(SI->getCondition()); if (!CmpI) return {SPF_UNKNOWN, SPNB_NA, false}; + Value *TrueVal = SI->getTrueValue(); + Value *FalseVal = SI->getFalseValue(); + + return llvm::matchDecomposedSelectPattern(CmpI, TrueVal, FalseVal, LHS, RHS, + CastOp, Depth); +} + +SelectPatternResult llvm::matchDecomposedSelectPattern( + CmpInst *CmpI, Value *TrueVal, Value *FalseVal, Value *&LHS, Value *&RHS, + Instruction::CastOps *CastOp, unsigned Depth) { CmpInst::Predicate Pred = CmpI->getPredicate(); Value *CmpLHS = CmpI->getOperand(0); Value *CmpRHS = CmpI->getOperand(1); - Value *TrueVal = SI->getTrueValue(); - Value *FalseVal = SI->getFalseValue(); FastMathFlags FMF; if (isa(CmpI)) FMF = CmpI->getFastMathFlags(); diff --git a/lib/Transforms/Scalar/EarlyCSE.cpp b/lib/Transforms/Scalar/EarlyCSE.cpp index 892a12315ce..f9c6f88b9ec 100644 --- a/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/lib/Transforms/Scalar/EarlyCSE.cpp @@ -80,6 +80,11 @@ static cl::opt EarlyCSEMssaOptCap( cl::desc("Enable imprecision in EarlyCSE in pathological cases, in exchange " "for faster compile. Caps the MemorySSA clobbering calls.")); +static cl::opt EarlyCSEDebugHash( + "earlycse-debug-hash", cl::init(false), cl::Hidden, + cl::desc("Perform extra assertion checking to verify that SimpleValue's hash " + "function is well-behaved w.r.t. its isEqual predicate")); + //===----------------------------------------------------------------------===// // SimpleValue //===----------------------------------------------------------------------===// @@ -130,22 +135,33 @@ template <> struct DenseMapInfo { } // end namespace llvm -/// Match a 'select' including an optional 'not' of the condition. -static bool matchSelectWithOptionalNotCond(Value *V, Value *&Cond, - Value *&T, Value *&F) { - if (match(V, m_Select(m_Value(Cond), m_Value(T), m_Value(F)))) { - // Look through a 'not' of the condition operand by swapping true/false. - Value *CondNot; - if (match(Cond, m_Not(m_Value(CondNot)))) { - Cond = CondNot; - std::swap(T, F); - } - return true; +/// Match a 'select' including an optional 'not's of the condition. +static bool matchSelectWithOptionalNotCond(Value *V, Value *&Cond, Value *&A, + Value *&B, + SelectPatternFlavor &Flavor) { + // Return false if V is not even a select. + if (!match(V, m_Select(m_Value(Cond), m_Value(A), m_Value(B)))) + return false; + + // Look through a 'not' of the condition operand by swapping A/B. + Value *CondNot; + if (match(Cond, m_Not(m_Value(CondNot)))) { + Cond = CondNot; + std::swap(A, B); } - return false; + + // Set flavor if we find a match, or set it to unknown otherwise; in + // either case, return true to indicate that this is a select we can + // process. + if (auto *CmpI = dyn_cast(Cond)) + Flavor = matchDecomposedSelectPattern(CmpI, A, B, A, B).Flavor; + else + Flavor = SPF_UNKNOWN; + + return true; } -unsigned DenseMapInfo::getHashValue(SimpleValue Val) { +static unsigned getHashValueImpl(SimpleValue Val) { Instruction *Inst = Val.Inst; // Hash in all of the operands as pointers. if (BinaryOperator *BinOp = dyn_cast(Inst)) { @@ -168,40 +184,41 @@ unsigned DenseMapInfo::getHashValue(SimpleValue Val) { return hash_combine(Inst->getOpcode(), Pred, LHS, RHS); } - // Hash min/max/abs (cmp + select) to allow for commuted operands. - // Min/max may also have non-canonical compare predicate (eg, the compare for - // smin may use 'sgt' rather than 'slt'), and non-canonical operands in the - // compare. - Value *A, *B; - SelectPatternFlavor SPF = matchSelectPattern(Inst, A, B).Flavor; - // TODO: We should also detect FP min/max. - if (SPF == SPF_SMIN || SPF == SPF_SMAX || - SPF == SPF_UMIN || SPF == SPF_UMAX) { - if (A > B) - std::swap(A, B); - return hash_combine(Inst->getOpcode(), SPF, A, B); - } - if (SPF == SPF_ABS || SPF == SPF_NABS) { - // ABS/NABS always puts the input in A and its negation in B. - return hash_combine(Inst->getOpcode(), SPF, A, B); - } - // Hash general selects to allow matching commuted true/false operands. - Value *Cond, *TVal, *FVal; - if (matchSelectWithOptionalNotCond(Inst, Cond, TVal, FVal)) { + SelectPatternFlavor SPF; + Value *Cond, *A, *B; + if (matchSelectWithOptionalNotCond(Inst, Cond, A, B, SPF)) { + // Hash min/max/abs (cmp + select) to allow for commuted operands. + // Min/max may also have non-canonical compare predicate (eg, the compare for + // smin may use 'sgt' rather than 'slt'), and non-canonical operands in the + // compare. + // TODO: We should also detect FP min/max. + if (SPF == SPF_SMIN || SPF == SPF_SMAX || + SPF == SPF_UMIN || SPF == SPF_UMAX) { + if (A > B) + std::swap(A, B); + return hash_combine(Inst->getOpcode(), SPF, A, B); + } + if (SPF == SPF_ABS || SPF == SPF_NABS) { + // ABS/NABS always puts the input in A and its negation in B. + return hash_combine(Inst->getOpcode(), SPF, A, B); + } + + // Hash general selects to allow matching commuted true/false operands. + // If we do not have a compare as the condition, just hash in the condition. CmpInst::Predicate Pred; Value *X, *Y; if (!match(Cond, m_Cmp(Pred, m_Value(X), m_Value(Y)))) - return hash_combine(Inst->getOpcode(), Cond, TVal, FVal); + return hash_combine(Inst->getOpcode(), Cond, A, B); // Similar to cmp normalization (above) - canonicalize the predicate value: - // select (icmp Pred, X, Y), T, F --> select (icmp InvPred, X, Y), F, T + // select (icmp Pred, X, Y), A, B --> select (icmp InvPred, X, Y), B, A if (CmpInst::getInversePredicate(Pred) < Pred) { Pred = CmpInst::getInversePredicate(Pred); - std::swap(TVal, FVal); + std::swap(A, B); } - return hash_combine(Inst->getOpcode(), Pred, X, Y, TVal, FVal); + return hash_combine(Inst->getOpcode(), Pred, X, Y, A, B); } if (CastInst *CI = dyn_cast(Inst)) @@ -227,7 +244,19 @@ unsigned DenseMapInfo::getHashValue(SimpleValue Val) { hash_combine_range(Inst->value_op_begin(), Inst->value_op_end())); } -bool DenseMapInfo::isEqual(SimpleValue LHS, SimpleValue RHS) { +unsigned DenseMapInfo::getHashValue(SimpleValue Val) { +#ifndef NDEBUG + // If -earlycse-debug-hash was specified, return a constant -- this + // will force all hashing to collide, so we'll exhaustively search + // the table for a match, and the assertion in isEqual will fire if + // there's a bug causing equal keys to hash differently. + if (EarlyCSEDebugHash) + return 0; +#endif + return getHashValueImpl(Val); +} + +static bool isEqualImpl(SimpleValue LHS, SimpleValue RHS) { Instruction *LHSI = LHS.Inst, *RHSI = RHS.Inst; if (LHS.isSentinel() || RHS.isSentinel()) @@ -263,39 +292,47 @@ bool DenseMapInfo::isEqual(SimpleValue LHS, SimpleValue RHS) { // Min/max/abs can occur with commuted operands, non-canonical predicates, // and/or non-canonical operands. - Value *LHSA, *LHSB; - SelectPatternFlavor LSPF = matchSelectPattern(LHSI, LHSA, LHSB).Flavor; - // TODO: We should also detect FP min/max. - if (LSPF == SPF_SMIN || LSPF == SPF_SMAX || - LSPF == SPF_UMIN || LSPF == SPF_UMAX || - LSPF == SPF_ABS || LSPF == SPF_NABS) { - Value *RHSA, *RHSB; - SelectPatternFlavor RSPF = matchSelectPattern(RHSI, RHSA, RHSB).Flavor; + // Selects can be non-trivially equivalent via inverted conditions and swaps. + SelectPatternFlavor LSPF, RSPF; + Value *CondL, *CondR, *LHSA, *RHSA, *LHSB, *RHSB; + if (matchSelectWithOptionalNotCond(LHSI, CondL, LHSA, LHSB, LSPF) && + matchSelectWithOptionalNotCond(RHSI, CondR, RHSA, RHSB, RSPF)) { if (LSPF == RSPF) { - // Abs results are placed in a defined order by matchSelectPattern. - if (LSPF == SPF_ABS || LSPF == SPF_NABS) + // TODO: We should also detect FP min/max. + if (LSPF == SPF_SMIN || LSPF == SPF_SMAX || + LSPF == SPF_UMIN || LSPF == SPF_UMAX) + return ((LHSA == RHSA && LHSB == RHSB) || + (LHSA == RHSB && LHSB == RHSA)); + + if (LSPF == SPF_ABS || LSPF == SPF_NABS) { + // Abs results are placed in a defined order by matchSelectPattern. return LHSA == RHSA && LHSB == RHSB; - return ((LHSA == RHSA && LHSB == RHSB) || - (LHSA == RHSB && LHSB == RHSA)); - } - } + } - // Selects can be non-trivially equivalent via inverted conditions and swaps. - Value *CondL, *CondR, *TrueL, *TrueR, *FalseL, *FalseR; - if (matchSelectWithOptionalNotCond(LHSI, CondL, TrueL, FalseL) && - matchSelectWithOptionalNotCond(RHSI, CondR, TrueR, FalseR)) { - // select Cond, T, F <--> select not(Cond), F, T - if (CondL == CondR && TrueL == TrueR && FalseL == FalseR) - return true; + // select Cond, A, B <--> select not(Cond), B, A + if (CondL == CondR && LHSA == RHSA && LHSB == RHSB) + return true; + } // If the true/false operands are swapped and the conditions are compares // with inverted predicates, the selects are equal: - // select (icmp Pred, X, Y), T, F <--> select (icmp InvPred, X, Y), F, T + // select (icmp Pred, X, Y), A, B <--> select (icmp InvPred, X, Y), B, A // - // This also handles patterns with a double-negation because we looked - // through a 'not' in the matching function and swapped T/F: - // select (cmp Pred, X, Y), T, F <--> select (not (cmp InvPred, X, Y)), T, F - if (TrueL == FalseR && FalseL == TrueR) { + // This also handles patterns with a double-negation in the sense of not + + // inverse, because we looked through a 'not' in the matching function and + // swapped A/B: + // select (cmp Pred, X, Y), A, B <--> select (not (cmp InvPred, X, Y)), B, A + // + // This intentionally does NOT handle patterns with a double-negation in + // the sense of not + not, because doing so could result in values + // comparing + // as equal that hash differently in the min/max/abs cases like: + // select (cmp slt, X, Y), X, Y <--> select (not (not (cmp slt, X, Y))), X, Y + // ^ hashes as min ^ would not hash as min + // In the context of the EarlyCSE pass, however, such cases never reach + // this code, as we simplify the double-negation before hashing the second + // select (and so still succeed at CSEing them). + if (LHSA == RHSB && LHSB == RHSA) { CmpInst::Predicate PredL, PredR; Value *X, *Y; if (match(CondL, m_Cmp(PredL, m_Value(X), m_Value(Y))) && @@ -308,6 +345,15 @@ bool DenseMapInfo::isEqual(SimpleValue LHS, SimpleValue RHS) { return false; } +bool DenseMapInfo::isEqual(SimpleValue LHS, SimpleValue RHS) { + // These comparisons are nontrivial, so assert that equality implies + // hash equality (DenseMap demands this as an invariant). + bool Result = isEqualImpl(LHS, RHS); + assert(!Result || (LHS.isSentinel() && LHS.Inst == RHS.Inst) || + getHashValueImpl(LHS) == getHashValueImpl(RHS)); + return Result; +} + //===----------------------------------------------------------------------===// // CallValue //===----------------------------------------------------------------------===// diff --git a/test/Transforms/EarlyCSE/commute.ll b/test/Transforms/EarlyCSE/commute.ll index 32dd55b5bce..572336c4ec4 100644 --- a/test/Transforms/EarlyCSE/commute.ll +++ b/test/Transforms/EarlyCSE/commute.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py -; RUN: opt < %s -S -early-cse | FileCheck %s +; RUN: opt < %s -S -early-cse -earlycse-debug-hash | FileCheck %s ; RUN: opt < %s -S -basicaa -early-cse-memssa | FileCheck %s define void @test1(float %A, float %B, float* %PA, float* %PB) { @@ -108,14 +108,13 @@ define i1 @smin_swapped(i8 %a, i8 %b) { } ; Min/max can also have an inverted predicate and select operands. -; TODO: Ensure we always recognize this (currently depends on hash collision) define i1 @smin_inverted(i8 %a, i8 %b) { ; CHECK-LABEL: @smin_inverted( ; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i8 [[A:%.*]], [[B:%.*]] ; CHECK-NEXT: [[CMP2:%.*]] = xor i1 [[CMP1]], true ; CHECK-NEXT: [[M1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 [[B]] -; CHECK: ret i1 +; CHECK-NEXT: ret i1 true ; %cmp1 = icmp slt i8 %a, %b %cmp2 = xor i1 %cmp1, -1 @@ -155,13 +154,12 @@ define i8 @smax_swapped(i8 %a, i8 %b) { ret i8 %r } -; TODO: Ensure we always recognize this (currently depends on hash collision) define i1 @smax_inverted(i8 %a, i8 %b) { ; CHECK-LABEL: @smax_inverted( ; CHECK-NEXT: [[CMP1:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]] ; CHECK-NEXT: [[CMP2:%.*]] = xor i1 [[CMP1]], true ; CHECK-NEXT: [[M1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 [[B]] -; CHECK: ret i1 +; CHECK-NEXT: ret i1 true ; %cmp1 = icmp sgt i8 %a, %b %cmp2 = xor i1 %cmp1, -1 @@ -203,13 +201,12 @@ define <2 x i8> @umin_swapped(<2 x i8> %a, <2 x i8> %b) { ret <2 x i8> %r } -; TODO: Ensure we always recognize this (currently depends on hash collision) define i1 @umin_inverted(i8 %a, i8 %b) { ; CHECK-LABEL: @umin_inverted( ; CHECK-NEXT: [[CMP1:%.*]] = icmp ult i8 [[A:%.*]], [[B:%.*]] ; CHECK-NEXT: [[CMP2:%.*]] = xor i1 [[CMP1]], true ; CHECK-NEXT: [[M1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 [[B]] -; CHECK: ret i1 +; CHECK-NEXT: ret i1 true ; %cmp1 = icmp ult i8 %a, %b %cmp2 = xor i1 %cmp1, -1 @@ -250,13 +247,12 @@ define i8 @umax_swapped(i8 %a, i8 %b) { ret i8 %r } -; TODO: Ensure we always recognize this (currently depends on hash collision) define i1 @umax_inverted(i8 %a, i8 %b) { ; CHECK-LABEL: @umax_inverted( ; CHECK-NEXT: [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], [[B:%.*]] ; CHECK-NEXT: [[CMP2:%.*]] = xor i1 [[CMP1]], true ; CHECK-NEXT: [[M1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 [[B]] -; CHECK: ret i1 +; CHECK-NEXT: ret i1 true ; %cmp1 = icmp ugt i8 %a, %b %cmp2 = xor i1 %cmp1, -1 @@ -302,14 +298,13 @@ define i8 @abs_swapped(i8 %a) { ret i8 %r } -; TODO: Ensure we always recognize this (currently depends on hash collision) define i8 @abs_inverted(i8 %a) { ; CHECK-LABEL: @abs_inverted( ; CHECK-NEXT: [[NEG:%.*]] = sub i8 0, [[A:%.*]] ; CHECK-NEXT: [[CMP1:%.*]] = icmp sgt i8 [[A]], 0 ; CHECK-NEXT: [[CMP2:%.*]] = xor i1 [[CMP1]], true ; CHECK-NEXT: [[M1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 [[NEG]] -; CHECK: ret i8 +; CHECK-NEXT: ret i8 [[M1]] ; %neg = sub i8 0, %a %cmp1 = icmp sgt i8 %a, 0 @@ -337,14 +332,13 @@ define i8 @nabs_swapped(i8 %a) { ret i8 %r } -; TODO: Ensure we always recognize this (currently depends on hash collision) define i8 @nabs_inverted(i8 %a) { ; CHECK-LABEL: @nabs_inverted( ; CHECK-NEXT: [[NEG:%.*]] = sub i8 0, [[A:%.*]] ; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i8 [[A]], 0 ; CHECK-NEXT: [[CMP2:%.*]] = xor i1 [[CMP1]], true ; CHECK-NEXT: [[M1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 [[NEG]] -; CHECK: ret i8 +; CHECK-NEXT: ret i8 0 ; %neg = sub i8 0, %a %cmp1 = icmp slt i8 %a, 0 @@ -646,3 +640,36 @@ define i32 @select_not_invert_pred_cond_wrong_select_op(i8 %x, i8 %y, i32 %t, i3 %r = sub i32 %m2, %m1 ret i32 %r } + + +; This test is a reproducer for a bug involving inverted min/max selects +; hashing differently but comparing as equal. It exhibits such a pair of +; values, and we run this test with -earlycse-debug-hash which would catch +; the disagreement and fail if it regressed. This test also includes a +; negation of each negation to check for the same issue one level deeper. +define void @not_not_min(i32* %px, i32* %py, i32* %pout) { +; CHECK-LABEL: @not_not_min( +; CHECK-NEXT: [[X:%.*]] = load volatile i32, i32* [[PX:%.*]] +; CHECK-NEXT: [[Y:%.*]] = load volatile i32, i32* [[PY:%.*]] +; CHECK-NEXT: [[CMPA:%.*]] = icmp slt i32 [[X]], [[Y]] +; CHECK-NEXT: [[CMPB:%.*]] = xor i1 [[CMPA]], true +; CHECK-NEXT: [[RA:%.*]] = select i1 [[CMPA]], i32 [[X]], i32 [[Y]] +; CHECK-NEXT: store volatile i32 [[RA]], i32* [[POUT:%.*]] +; CHECK-NEXT: store volatile i32 [[RA]], i32* [[POUT]] +; CHECK-NEXT: store volatile i32 [[RA]], i32* [[POUT]] +; CHECK-NEXT: ret void +; + %x = load volatile i32, i32* %px + %y = load volatile i32, i32* %py + %cmpa = icmp slt i32 %x, %y + %cmpb = xor i1 %cmpa, -1 + %cmpc = xor i1 %cmpb, -1 + %ra = select i1 %cmpa, i32 %x, i32 %y + %rb = select i1 %cmpb, i32 %y, i32 %x + %rc = select i1 %cmpc, i32 %x, i32 %y + store volatile i32 %ra, i32* %pout + store volatile i32 %rb, i32* %pout + store volatile i32 %rc, i32* %pout + + ret void +} -- 2.11.0