if (ShiftOpcode0 == ShiftOpcode1)
return nullptr;
- // The shift amounts must add up to the narrow bit width.
- Value *ShAmt;
- bool SubIsOnLHS;
+ // Match the shift amount operands for a rotate pattern. This always matches
+ // a subtraction on the R operand.
+ auto matchShiftAmount = [](Value *L, Value *R, unsigned Width) -> Value * {
+ // The shift amounts may add up to the narrow bit width:
+ // (shl ShVal, L) | (lshr ShVal, Width - L)
+ if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L)))))
+ return L;
+
+ return nullptr;
+ };
+
Type *DestTy = Trunc.getType();
unsigned NarrowWidth = DestTy->getScalarSizeInBits();
- if (match(ShAmt0,
- m_OneUse(m_Sub(m_SpecificInt(NarrowWidth), m_Specific(ShAmt1))))) {
- ShAmt = ShAmt1;
+ Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, NarrowWidth);
+ bool SubIsOnLHS = false;
+ if (!ShAmt) {
+ ShAmt = matchShiftAmount(ShAmt1, ShAmt0, NarrowWidth);
SubIsOnLHS = true;
- } else if (match(ShAmt1, m_OneUse(m_Sub(m_SpecificInt(NarrowWidth),
- m_Specific(ShAmt0))))) {
- ShAmt = ShAmt0;
- SubIsOnLHS = false;
- } else {
- return nullptr;
}
+ if (!ShAmt)
+ return nullptr;
// The shifted value must have high zeros in the wide type. Typically, this
// will be a zext, but it could also be the result of an 'and' or 'shift'.