From 3334995891e808518824f2fe8bbc434be7e67307 Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Tue, 13 Nov 2018 22:47:24 +0000 Subject: [PATCH] [InstCombine] canonicalize rotate patterns with cmp/select The cmp+branch variant of this pattern is shown in: https://bugs.llvm.org/show_bug.cgi?id=34924 ...and as discussed there, we probably can't transform that without a rotate intrinsic. We do have that now via funnel shift, but we're not quite ready to canonicalize IR to that form yet. The case with 'select' should already be transformed though, so that's this patch. The sequence with negation followed by masking is what we use in the backend and partly in clang (though that part should be updated). https://rise4fun.com/Alive/TplC %cmp = icmp eq i32 %shamt, 0 %sub = sub i32 32, %shamt %shr = lshr i32 %x, %shamt %shl = shl i32 %x, %sub %or = or i32 %shr, %shl %r = select i1 %cmp, i32 %x, i32 %or => %neg = sub i32 0, %shamt %masked = and i32 %shamt, 31 %maskedneg = and i32 %neg, 31 %shl2 = lshr i32 %x, %masked %shr2 = shl i32 %x, %maskedneg %r = or i32 %shl2, %shr2 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@346807 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Transforms/InstCombine/InstCombineSelect.cpp | 63 ++++++++++++++++ test/Transforms/InstCombine/rotate.ll | 95 +++++++++++++++--------- 2 files changed, 121 insertions(+), 37 deletions(-) diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 88a72bb8eb5..26d0b522f01 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1546,6 +1546,66 @@ static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS, return SelectInst::Create(CmpABC, MinMaxOp, ThirdOp); } +/// Try to reduce a rotate pattern that includes a compare and select into a +/// sequence of ALU ops only. Example: +/// rotl32(a, b) --> (b == 0 ? 
a : ((a >> (32 - b)) | (a << b)))
+///              --> (a >> (-b & 31)) | (a << (b & 31))
+static Instruction *foldSelectRotate(SelectInst &Sel,
+                                     InstCombiner::BuilderTy &Builder) {
+  // The false value of the select must be a rotate of the true value.
+  Value *Or0, *Or1;
+  if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1)))))
+    return nullptr;
+
+  Value *TVal = Sel.getTrueValue();
+  Value *SA0, *SA1;
+  if (!match(Or0, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA0)))) ||
+      !match(Or1, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA1)))))
+    return nullptr;
+
+  auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode();
+  auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode();
+  if (ShiftOpcode0 == ShiftOpcode1)
+    return nullptr;
+
+  // We have one of these patterns so far:
+  // select ?, TVal, (or (lshr TVal, SA0), (shl TVal, SA1))
+  // select ?, TVal, (or (shl TVal, SA0), (lshr TVal, SA1))
+  // This must be a power-of-2 rotate for a bitmasking transform to be valid.
+  unsigned Width = Sel.getType()->getScalarSizeInBits();
+  if (!isPowerOf2_32(Width))
+    return nullptr;
+
+  // Check the shift amounts to see if they are an opposite pair.
+  Value *ShAmt;
+  if (match(SA1, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(SA0)))))
+    ShAmt = SA0;
+  else if (match(SA0, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(SA1)))))
+    ShAmt = SA1;
+  else
+    return nullptr;
+
+  // Finally, see if the select is filtering out a shift-by-zero.
+  Value *Cond = Sel.getCondition();
+  ICmpInst::Predicate Pred;
+  if (!match(Cond, m_OneUse(m_ICmp(Pred, m_Specific(ShAmt), m_ZeroInt()))) ||
+      Pred != ICmpInst::ICMP_EQ)
+    return nullptr;
+
+  // This is a rotate that avoids shift-by-bitwidth UB in a suboptimal way.
+  // Convert to safely bitmasked shifts.
+  // TODO: When we can canonicalize to funnel shift intrinsics without risk of
+  // performance regressions, replace this sequence with that call.
+ Value *NegShAmt = Builder.CreateNeg(ShAmt); + Value *MaskedShAmt = Builder.CreateAnd(ShAmt, Width - 1); + Value *MaskedNegShAmt = Builder.CreateAnd(NegShAmt, Width - 1); + Value *NewSA0 = ShAmt == SA0 ? MaskedShAmt : MaskedNegShAmt; + Value *NewSA1 = ShAmt == SA1 ? MaskedShAmt : MaskedNegShAmt; + Value *NewSh0 = Builder.CreateBinOp(ShiftOpcode0, TVal, NewSA0); + Value *NewSh1 = Builder.CreateBinOp(ShiftOpcode1, TVal, NewSA1); + return BinaryOperator::CreateOr(NewSh0, NewSh1); +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -2010,5 +2070,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *Select = foldSelectBinOpIdentity(SI, TLI)) return Select; + if (Instruction *Rot = foldSelectRotate(SI, Builder)) + return Rot; + return nullptr; } diff --git a/test/Transforms/InstCombine/rotate.ll b/test/Transforms/InstCombine/rotate.ll index 4401539220a..6150063ab72 100644 --- a/test/Transforms/InstCombine/rotate.ll +++ b/test/Transforms/InstCombine/rotate.ll @@ -309,16 +309,16 @@ define i8 @rotateleft_8_neg_mask_wide_amount_commute(i8 %v, i32 %shamt) { ret i8 %ret } -; TODO: Convert select pattern to masked shift that ends in 'or'. +; Convert select pattern to masked shift that ends in 'or'. 
define i32 @rotr_select(i32 %x, i32 %shamt) { ; CHECK-LABEL: @rotr_select( -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[SHAMT:%.*]], 0 -; CHECK-NEXT: [[SUB:%.*]] = sub i32 32, [[SHAMT]] -; CHECK-NEXT: [[SHR:%.*]] = lshr i32 [[X:%.*]], [[SHAMT]] -; CHECK-NEXT: [[SHL:%.*]] = shl i32 [[X]], [[SUB]] -; CHECK-NEXT: [[OR:%.*]] = or i32 [[SHR]], [[SHL]] -; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i32 [[X]], i32 [[OR]] +; CHECK-NEXT: [[TMP1:%.*]] = sub i32 0, [[SHAMT:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[SHAMT]], 31 +; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP1]], 31 +; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[X:%.*]], [[TMP2]] +; CHECK-NEXT: [[TMP5:%.*]] = shl i32 [[X]], [[TMP3]] +; CHECK-NEXT: [[R:%.*]] = or i32 [[TMP4]], [[TMP5]] ; CHECK-NEXT: ret i32 [[R]] ; %cmp = icmp eq i32 %shamt, 0 @@ -330,16 +330,16 @@ define i32 @rotr_select(i32 %x, i32 %shamt) { ret i32 %r } -; TODO: Convert select pattern to masked shift that ends in 'or'. +; Convert select pattern to masked shift that ends in 'or'. define i8 @rotr_select_commute(i8 %x, i8 %shamt) { ; CHECK-LABEL: @rotr_select_commute( -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[SHAMT:%.*]], 0 -; CHECK-NEXT: [[SUB:%.*]] = sub i8 8, [[SHAMT]] -; CHECK-NEXT: [[SHR:%.*]] = lshr i8 [[X:%.*]], [[SHAMT]] -; CHECK-NEXT: [[SHL:%.*]] = shl i8 [[X]], [[SUB]] -; CHECK-NEXT: [[OR:%.*]] = or i8 [[SHL]], [[SHR]] -; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i8 [[X]], i8 [[OR]] +; CHECK-NEXT: [[TMP1:%.*]] = sub i8 0, [[SHAMT:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[SHAMT]], 7 +; CHECK-NEXT: [[TMP3:%.*]] = and i8 [[TMP1]], 7 +; CHECK-NEXT: [[TMP4:%.*]] = shl i8 [[X:%.*]], [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = lshr i8 [[X]], [[TMP2]] +; CHECK-NEXT: [[R:%.*]] = or i8 [[TMP4]], [[TMP5]] ; CHECK-NEXT: ret i8 [[R]] ; %cmp = icmp eq i8 %shamt, 0 @@ -351,16 +351,16 @@ define i8 @rotr_select_commute(i8 %x, i8 %shamt) { ret i8 %r } -; TODO: Convert select pattern to masked shift that ends in 'or'. 
+; Convert select pattern to masked shift that ends in 'or'. define i16 @rotl_select(i16 %x, i16 %shamt) { ; CHECK-LABEL: @rotl_select( -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[SHAMT:%.*]], 0 -; CHECK-NEXT: [[SUB:%.*]] = sub i16 16, [[SHAMT]] -; CHECK-NEXT: [[SHR:%.*]] = lshr i16 [[X:%.*]], [[SUB]] -; CHECK-NEXT: [[SHL:%.*]] = shl i16 [[X]], [[SHAMT]] -; CHECK-NEXT: [[OR:%.*]] = or i16 [[SHR]], [[SHL]] -; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i16 [[X]], i16 [[OR]] +; CHECK-NEXT: [[TMP1:%.*]] = sub i16 0, [[SHAMT:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = and i16 [[SHAMT]], 15 +; CHECK-NEXT: [[TMP3:%.*]] = and i16 [[TMP1]], 15 +; CHECK-NEXT: [[TMP4:%.*]] = lshr i16 [[X:%.*]], [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = shl i16 [[X]], [[TMP2]] +; CHECK-NEXT: [[R:%.*]] = or i16 [[TMP4]], [[TMP5]] ; CHECK-NEXT: ret i16 [[R]] ; %cmp = icmp eq i16 %shamt, 0 @@ -372,24 +372,45 @@ define i16 @rotl_select(i16 %x, i16 %shamt) { ret i16 %r } -; TODO: Convert select pattern to masked shift that ends in 'or'. +; Convert select pattern to masked shift that ends in 'or'. 
-define i64 @rotl_select_commute(i64 %x, i64 %shamt) {
+define <2 x i64> @rotl_select_commute(<2 x i64> %x, <2 x i64> %shamt) {
 ; CHECK-LABEL: @rotl_select_commute(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i64 [[SHAMT:%.*]], 0
-; CHECK-NEXT:    [[SUB:%.*]] = sub i64 64, [[SHAMT]]
-; CHECK-NEXT:    [[SHR:%.*]] = lshr i64 [[X:%.*]], [[SUB]]
-; CHECK-NEXT:    [[SHL:%.*]] = shl i64 [[X]], [[SHAMT]]
-; CHECK-NEXT:    [[OR:%.*]] = or i64 [[SHL]], [[SHR]]
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[CMP]], i64 [[X]], i64 [[OR]]
-; CHECK-NEXT:    ret i64 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = sub <2 x i64> zeroinitializer, [[SHAMT:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and <2 x i64> [[SHAMT]], <i64 63, i64 63>
+; CHECK-NEXT:    [[TMP3:%.*]] = and <2 x i64> [[TMP1]], <i64 63, i64 63>
+; CHECK-NEXT:    [[TMP4:%.*]] = shl <2 x i64> [[X:%.*]], [[TMP2]]
+; CHECK-NEXT:    [[TMP5:%.*]] = lshr <2 x i64> [[X]], [[TMP3]]
+; CHECK-NEXT:    [[R:%.*]] = or <2 x i64> [[TMP4]], [[TMP5]]
+; CHECK-NEXT:    ret <2 x i64> [[R]]
+;
+  %cmp = icmp eq <2 x i64> %shamt, zeroinitializer
+  %sub = sub <2 x i64> <i64 64, i64 64>, %shamt
+  %shr = lshr <2 x i64> %x, %sub
+  %shl = shl <2 x i64> %x, %shamt
+  %or = or <2 x i64> %shl, %shr
+  %r = select <2 x i1> %cmp, <2 x i64> %x, <2 x i64> %or
+  ret <2 x i64> %r
+}
+
+; Negative test - the transform is only valid with power-of-2 types.
+ +define i24 @rotl_select_weird_type(i24 %x, i24 %shamt) { +; CHECK-LABEL: @rotl_select_weird_type( +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i24 [[SHAMT:%.*]], 0 +; CHECK-NEXT: [[SUB:%.*]] = sub i24 24, [[SHAMT]] +; CHECK-NEXT: [[SHR:%.*]] = lshr i24 [[X:%.*]], [[SUB]] +; CHECK-NEXT: [[SHL:%.*]] = shl i24 [[X]], [[SHAMT]] +; CHECK-NEXT: [[OR:%.*]] = or i24 [[SHL]], [[SHR]] +; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i24 [[X]], i24 [[OR]] +; CHECK-NEXT: ret i24 [[R]] ; - %cmp = icmp eq i64 %shamt, 0 - %sub = sub i64 64, %shamt - %shr = lshr i64 %x, %sub - %shl = shl i64 %x, %shamt - %or = or i64 %shl, %shr - %r = select i1 %cmp, i64 %x, i64 %or - ret i64 %r + %cmp = icmp eq i24 %shamt, 0 + %sub = sub i24 24, %shamt + %shr = lshr i24 %x, %sub + %shl = shl i24 %x, %shamt + %or = or i24 %shl, %shr + %r = select i1 %cmp, i24 %x, i24 %or + ret i24 %r } -- 2.11.0