#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
-#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DIBuilder.h"
+#include "llvm/IR/DataLayout.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
+#include "llvm/Transforms/InstCombine/InstCombiner.h"
+#include <numeric>
using namespace llvm;
using namespace PatternMatch;
/// If we find a cast of an allocation instruction, try to eliminate the cast by
/// moving the type information into the alloc.
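/// For example (illustrative):
///   %a = alloca i16, i32 8
///   %b = bitcast i16* %a to i32*
/// can be rewritten so the type information lives on the alloca:
///   %a = alloca i32, i32 4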
-Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI,
- AllocaInst &AI) {
+Instruction *InstCombinerImpl::PromoteCastOfAllocation(BitCastInst &CI,
+ AllocaInst &AI) {
PointerType *PTy = cast<PointerType>(CI.getType());
- BuilderTy AllocaBuilder(Builder);
- AllocaBuilder.SetInsertPoint(&AI);
+ IRBuilderBase::InsertPointGuard Guard(Builder);
+ Builder.SetInsertPoint(&AI);
// Get the type really allocated and the type casted to.
Type *AllocElTy = AI.getAllocatedType();
Type *CastElTy = PTy->getElementType();
if (!AllocElTy->isSized() || !CastElTy->isSized()) return nullptr;
- unsigned AllocElTyAlign = DL.getABITypeAlignment(AllocElTy);
- unsigned CastElTyAlign = DL.getABITypeAlignment(CastElTy);
+ // This optimisation does not work for cases where the cast type
+ // is scalable and the allocated type is not. This is because we need to
+ // know how many times the casted type fits into the allocated type.
+ // For the opposite case, where the allocated type is scalable and the
+ // cast type is not, this leads to poor code quality due to the
+ // introduction of 'vscale' into the calculations. It seems better to
+ // bail out for this case too until we've done a proper cost-benefit
+ // analysis.
+ bool AllocIsScalable = isa<ScalableVectorType>(AllocElTy);
+ bool CastIsScalable = isa<ScalableVectorType>(CastElTy);
+ if (AllocIsScalable != CastIsScalable) return nullptr;
+
+ Align AllocElTyAlign = DL.getABITypeAlign(AllocElTy);
+ Align CastElTyAlign = DL.getABITypeAlign(CastElTy);
if (CastElTyAlign < AllocElTyAlign) return nullptr;
// If the allocation has multiple uses, only promote it if we are strictly
// increasing the alignment of the resultant allocation. If we keep it the
// same, we open the door to infinite loops of various kinds.
if (!AI.hasOneUse() && CastElTyAlign == AllocElTyAlign) return nullptr;
- uint64_t AllocElTySize = DL.getTypeAllocSize(AllocElTy);
- uint64_t CastElTySize = DL.getTypeAllocSize(CastElTy);
+ // The alloc and cast types should be either both fixed or both scalable.
+ uint64_t AllocElTySize = DL.getTypeAllocSize(AllocElTy).getKnownMinSize();
+ uint64_t CastElTySize = DL.getTypeAllocSize(CastElTy).getKnownMinSize();
if (CastElTySize == 0 || AllocElTySize == 0) return nullptr;
// If the allocation has multiple uses, only promote it if we're not
// shrinking the amount of memory being allocated.
- uint64_t AllocElTyStoreSize = DL.getTypeStoreSize(AllocElTy);
- uint64_t CastElTyStoreSize = DL.getTypeStoreSize(CastElTy);
+ uint64_t AllocElTyStoreSize = DL.getTypeStoreSize(AllocElTy).getKnownMinSize();
+ uint64_t CastElTyStoreSize = DL.getTypeStoreSize(CastElTy).getKnownMinSize();
if (!AI.hasOneUse() && CastElTyStoreSize < AllocElTyStoreSize) return nullptr;
// See if we can satisfy the modulus by pulling a scale out of the array
// size argument.
if ((AllocElTySize*ArraySizeScale) % CastElTySize != 0 ||
(AllocElTySize*ArrayOffset ) % CastElTySize != 0) return nullptr;
+ // We don't currently support arrays of scalable types.
+ assert(!AllocIsScalable || (ArrayOffset == 1 && ArraySizeScale == 0));
+
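+ // Illustrative numbers: for 'alloca i16, i32 8' cast to i32*, a constant
+ // array size decomposes to ArraySizeScale = 0 and ArrayOffset = 8, so the
+ // computation below yields an element count of (2 * 8) / 4 = 4.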
unsigned Scale = (AllocElTySize*ArraySizeScale)/CastElTySize;
Value *Amt = nullptr;
if (Scale == 1) {
  Amt = NumElements;
} else {
Amt = ConstantInt::get(AI.getArraySize()->getType(), Scale);
// Insert before the alloca, not before the cast.
- Amt = AllocaBuilder.CreateMul(Amt, NumElements);
+ Amt = Builder.CreateMul(Amt, NumElements);
}
if (uint64_t Offset = (AllocElTySize*ArrayOffset)/CastElTySize) {
Value *Off = ConstantInt::get(AI.getArraySize()->getType(),
Offset, true);
- Amt = AllocaBuilder.CreateAdd(Amt, Off);
+ Amt = Builder.CreateAdd(Amt, Off);
}
- AllocaInst *New = AllocaBuilder.CreateAlloca(CastElTy, Amt);
- New->setAlignment(AI.getAlignment());
+ AllocaInst *New = Builder.CreateAlloca(CastElTy, Amt);
+ New->setAlignment(AI.getAlign());
New->takeName(&AI);
New->setUsedWithInAlloca(AI.isUsedWithInAlloca());
if (!AI.hasOneUse()) {
// New is the allocation instruction, pointer typed. AI is the original
// allocation instruction, also pointer typed. Thus, cast to use is BitCast.
- Value *NewCast = AllocaBuilder.CreateBitCast(New, AI.getType(), "tmpcast");
+ Value *NewCast = Builder.CreateBitCast(New, AI.getType(), "tmpcast");
replaceInstUsesWith(AI, NewCast);
+ eraseInstFromFunction(AI);
}
return replaceInstUsesWith(CI, New);
}
/// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns
/// true for, actually insert the code to evaluate the expression.
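/// For example (illustrative): asked to evaluate (add i32 %x, %y) in i16,
/// this recursively evaluates both operands in i16 and then recreates the
/// add as 'add i16 %x16, %y16'.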
-Value *InstCombiner::EvaluateInDifferentType(Value *V, Type *Ty,
- bool isSigned) {
+Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
+ bool isSigned) {
if (Constant *C = dyn_cast<Constant>(V)) {
C = ConstantExpr::getIntegerCast(C, Ty, isSigned /*Sext or ZExt*/);
// If we got a constantexpr back, try to simplify it with DL info.
- if (Constant *FoldedC = ConstantFoldConstant(C, DL, &TLI))
- C = FoldedC;
- return C;
+ return ConstantFoldConstant(C, DL, &TLI);
}
// Otherwise, it must be an instruction.
return InsertNewInstWith(Res, *I);
}
-Instruction::CastOps InstCombiner::isEliminableCastPair(const CastInst *CI1,
- const CastInst *CI2) {
+Instruction::CastOps
+InstCombinerImpl::isEliminableCastPair(const CastInst *CI1,
+ const CastInst *CI2) {
Type *SrcTy = CI1->getSrcTy();
Type *MidTy = CI1->getDestTy();
Type *DstTy = CI2->getDestTy();
}
/// Implement the transforms common to all CastInst visitors.
-Instruction *InstCombiner::commonCastTransforms(CastInst &CI) {
+Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) {
Value *Src = CI.getOperand(0);
// Try to eliminate a cast of a cast.
}
if (auto *Sel = dyn_cast<SelectInst>(Src)) {
- // We are casting a select. Try to fold the cast into the select, but only
- // if the select does not have a compare instruction with matching operand
- // types. Creating a select with operands that are different sizes than its
+ // We are casting a select. Try to fold the cast into the select if the
+ // select does not have a compare instruction with matching operand types
+ // or the select is likely better done in a narrow type.
+ // Creating a select with operands that are different sizes than its
// condition may inhibit other folds and lead to worse codegen.
auto *Cmp = dyn_cast<CmpInst>(Sel->getCondition());
- if (!Cmp || Cmp->getOperand(0)->getType() != Sel->getType())
+ if (!Cmp || Cmp->getOperand(0)->getType() != Sel->getType() ||
+ (CI.getOpcode() == Instruction::Trunc &&
+ shouldChangeType(CI.getSrcTy(), CI.getType()))) {
if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) {
replaceAllDbgUsesWith(*Sel, *NV, CI, DT);
return NV;
}
+ }
}
// If we are casting a PHI, then fold the cast into the PHI.
// Don't do this if it would create a PHI node with an illegal type from a
// legal type.
if (!Src->getType()->isIntegerTy() || !CI.getType()->isIntegerTy() ||
- shouldChangeType(CI.getType(), Src->getType()))
+ shouldChangeType(CI.getSrcTy(), CI.getType()))
if (Instruction *NV = foldOpIntoPhi(CI, PN))
return NV;
}
///
/// This function works on both vectors and scalars.
///
-static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC,
+static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
Instruction *CxtI) {
if (canAlwaysEvaluateInType(V, Ty))
return true;
break;
}
case Instruction::Shl: {
- // If we are truncating the result of this SHL, and if it's a shift of a
- // constant amount, we can always perform a SHL in a smaller type.
- const APInt *Amt;
- if (match(I->getOperand(1), m_APInt(Amt))) {
- uint32_t BitWidth = Ty->getScalarSizeInBits();
- if (Amt->getLimitedValue(BitWidth) < BitWidth)
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI);
- }
+ // If we are truncating the result of this SHL, and if it's a shift of an
+ // in-range amount, we can always perform a SHL in a smaller type.
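+ // For example (illustrative):
+ //   trunc (shl (zext i16 %x to i32), 3) to i16 --> shl i16 %x, 3
+ // because the shift amount is known to be less than 16.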
+ uint32_t BitWidth = Ty->getScalarSizeInBits();
+ KnownBits AmtKnownBits =
+ llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
+ if (AmtKnownBits.getMaxValue().ult(BitWidth))
+ return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
break;
}
case Instruction::LShr: {
// If this is a truncate of a logical shr, we can truncate it to a smaller
// lshr iff we know that the bits we would otherwise be shifting in are
// already zeros.
- const APInt *Amt;
- if (match(I->getOperand(1), m_APInt(Amt))) {
- uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
- uint32_t BitWidth = Ty->getScalarSizeInBits();
- if (Amt->getLimitedValue(BitWidth) < BitWidth &&
- IC.MaskedValueIsZero(I->getOperand(0),
- APInt::getBitsSetFrom(OrigBitWidth, BitWidth), 0, CxtI)) {
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI);
- }
+ // TODO: It is enough to check that the bits we would be shifting in are
+ // zero - use AmtKnownBits.getMaxValue().
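+ // For example (illustrative):
+ //   trunc (lshr (zext i16 %x to i32), 4) to i16 --> lshr i16 %x, 4
+ // because the bits shifted in from the high half are known zero.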
+ uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
+ uint32_t BitWidth = Ty->getScalarSizeInBits();
+ KnownBits AmtKnownBits =
+ llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
+ APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
+ if (AmtKnownBits.getMaxValue().ult(BitWidth) &&
+ IC.MaskedValueIsZero(I->getOperand(0), ShiftedBits, 0, CxtI)) {
+ return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
}
break;
}
// original type and the sign bit of the truncate type are similar.
// TODO: It is enough to check that the bits we would be shifting in are
// similar to sign bit of the truncate type.
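// For example (illustrative):
//   trunc (ashr (sext i16 %x to i32), 2) to i16 --> ashr i16 %x, 2
// because the source has at least 17 sign bits, so every bit shifted in
// matches the truncated sign bit.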
- const APInt *Amt;
- if (match(I->getOperand(1), m_APInt(Amt))) {
- uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
- uint32_t BitWidth = Ty->getScalarSizeInBits();
- if (Amt->getLimitedValue(BitWidth) < BitWidth &&
- OrigBitWidth - BitWidth <
- IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI))
- return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI);
- }
+ uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
+ uint32_t BitWidth = Ty->getScalarSizeInBits();
+ KnownBits AmtKnownBits =
+ llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
+ unsigned ShiftedBits = OrigBitWidth - BitWidth;
+ if (AmtKnownBits.getMaxValue().ult(BitWidth) &&
+ ShiftedBits < IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI))
+ return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
+ canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
break;
}
case Instruction::Trunc:
/// trunc (lshr (bitcast <4 x i32> %X to i128), 32) to i32
/// --->
/// extractelement <4 x i32> %X, 1
-static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC) {
+static Instruction *foldVecTruncToExtElt(TruncInst &Trunc,
+ InstCombinerImpl &IC) {
Value *TruncOp = Trunc.getOperand(0);
Type *DestType = Trunc.getType();
if (!TruncOp->hasOneUse() || !isa<IntegerType>(DestType))
// bitcast it to a vector type that we can extract from.
unsigned NumVecElts = VecWidth / DestWidth;
if (VecType->getElementType() != DestType) {
- VecType = VectorType::get(DestType, NumVecElts);
+ VecType = FixedVectorType::get(DestType, NumVecElts);
VecInput = IC.Builder.CreateBitCast(VecInput, VecType, "bc");
}
return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt));
}
-/// Rotate left/right may occur in a wider type than necessary because of type
-/// promotion rules. Try to narrow the inputs and convert to funnel shift.
-Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) {
+/// Funnel/Rotate left/right may occur in a wider type than necessary because of
+/// type promotion rules. Try to narrow the inputs and convert to funnel shift.
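+/// For example (illustrative, with %x and %y zero-extended from i8 to i32):
+///   trunc (or (shl %x, %c), (lshr %y, (sub 8, %c))) to i8
+/// becomes llvm.fshl.i8(trunc %x, trunc %y, trunc %c).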
+Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) {
assert((isa<VectorType>(Trunc.getSrcTy()) ||
shouldChangeType(Trunc.getSrcTy(), Trunc.getType())) &&
"Don't narrow to an illegal scalar type");
if (!isPowerOf2_32(NarrowWidth))
return nullptr;
- // First, find an or'd pair of opposite shifts with the same shifted operand:
- // trunc (or (lshr ShVal, ShAmt0), (shl ShVal, ShAmt1))
- Value *Or0, *Or1;
- if (!match(Trunc.getOperand(0), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1)))))
+ // First, find an or'd pair of opposite shifts:
+ // trunc (or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1))
+ BinaryOperator *Or0, *Or1;
+ if (!match(Trunc.getOperand(0), m_OneUse(m_Or(m_BinOp(Or0), m_BinOp(Or1)))))
return nullptr;
- Value *ShVal, *ShAmt0, *ShAmt1;
- if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal), m_Value(ShAmt0)))) ||
- !match(Or1, m_OneUse(m_LogicalShift(m_Specific(ShVal), m_Value(ShAmt1)))))
+ Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
+ if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
+ !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
+ Or0->getOpcode() == Or1->getOpcode())
return nullptr;
- auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode();
- auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode();
- if (ShiftOpcode0 == ShiftOpcode1)
- return nullptr;
+ // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
+ if (Or0->getOpcode() == BinaryOperator::LShr) {
+ std::swap(Or0, Or1);
+ std::swap(ShVal0, ShVal1);
+ std::swap(ShAmt0, ShAmt1);
+ }
+ assert(Or0->getOpcode() == BinaryOperator::Shl &&
+ Or1->getOpcode() == BinaryOperator::LShr &&
+ "Illegal or(shift,shift) pair");
- // Match the shift amount operands for a rotate pattern. This always matches
- // a subtraction on the R operand.
- auto matchShiftAmount = [](Value *L, Value *R, unsigned Width) -> Value * {
+ // Match the shift amount operands for a funnel/rotate pattern. This always
+ // matches a subtraction on the R operand.
+ auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
// The shift amounts may add up to the narrow bit width:
- // (shl ShVal, L) | (lshr ShVal, Width - L)
+ // (shl ShVal0, L) | (lshr ShVal1, Width - L)
if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L)))))
return L;
+ // The following patterns currently only work for rotation patterns.
+ // TODO: Add more general funnel-shift compatible patterns.
+ if (ShVal0 != ShVal1)
+ return nullptr;
+
// The shift amount may be masked with negation:
- // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1)))
+ // (shl ShVal0, (X & (Width - 1))) | (lshr ShVal1, ((-X) & (Width - 1)))
Value *X;
unsigned Mask = Width - 1;
if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) &&
};
Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, NarrowWidth);
- bool SubIsOnLHS = false;
+ bool IsFshl = true; // Sub on LSHR.
if (!ShAmt) {
ShAmt = matchShiftAmount(ShAmt1, ShAmt0, NarrowWidth);
- SubIsOnLHS = true;
+ IsFshl = false; // Sub on SHL.
}
if (!ShAmt)
return nullptr;
// will be a zext, but it could also be the result of an 'and' or 'shift'.
unsigned WideWidth = Trunc.getSrcTy()->getScalarSizeInBits();
APInt HiBitMask = APInt::getHighBitsSet(WideWidth, WideWidth - NarrowWidth);
- if (!MaskedValueIsZero(ShVal, HiBitMask, 0, &Trunc))
+ if (!MaskedValueIsZero(ShVal0, HiBitMask, 0, &Trunc) ||
+ !MaskedValueIsZero(ShVal1, HiBitMask, 0, &Trunc))
return nullptr;
// We have an unnecessarily wide rotate!
- // trunc (or (lshr ShVal, ShAmt), (shl ShVal, BitWidth - ShAmt))
+ // trunc (or (shl ShVal0, ShAmt), (lshr ShVal1, BitWidth - ShAmt))
// Narrow the inputs and convert to funnel shift intrinsic:
- // llvm.fshl.i8(trunc(ShVal), trunc(ShVal), trunc(ShAmt))
+ // llvm.fshl.i8(trunc(ShVal0), trunc(ShVal1), trunc(ShAmt))
Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy);
- Value *X = Builder.CreateTrunc(ShVal, DestTy);
- bool IsFshl = (!SubIsOnLHS && ShiftOpcode0 == BinaryOperator::Shl) ||
- (SubIsOnLHS && ShiftOpcode1 == BinaryOperator::Shl);
+ Value *X, *Y;
+ X = Y = Builder.CreateTrunc(ShVal0, DestTy);
+ if (ShVal0 != ShVal1)
+ Y = Builder.CreateTrunc(ShVal1, DestTy);
Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
Function *F = Intrinsic::getDeclaration(Trunc.getModule(), IID, DestTy);
- return IntrinsicInst::Create(F, { X, X, NarrowShAmt });
+ return IntrinsicInst::Create(F, {X, Y, NarrowShAmt});
}
/// Try to narrow the width of math or bitwise logic instructions by pulling a
/// truncate ahead of binary operators.
/// TODO: Transforms for truncated shifts should be moved into here.
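/// For example (illustrative):
///   trunc (and i32 %x, %y) to i8 --> and i8 (trunc %x), (trunc %y)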
-Instruction *InstCombiner::narrowBinOp(TruncInst &Trunc) {
+Instruction *InstCombinerImpl::narrowBinOp(TruncInst &Trunc) {
Type *SrcTy = Trunc.getSrcTy();
Type *DestTy = Trunc.getType();
if (!isa<VectorType>(SrcTy) && !shouldChangeType(SrcTy, DestTy))
default: break;
}
- if (Instruction *NarrowOr = narrowRotate(Trunc))
+ if (Instruction *NarrowOr = narrowFunnelShift(Trunc))
return NarrowOr;
return nullptr;
InstCombiner::BuilderTy &Builder) {
auto *Shuf = dyn_cast<ShuffleVectorInst>(Trunc.getOperand(0));
if (Shuf && Shuf->hasOneUse() && isa<UndefValue>(Shuf->getOperand(1)) &&
- Shuf->getMask()->getSplatValue() &&
+ is_splat(Shuf->getShuffleMask()) &&
Shuf->getType() == Shuf->getOperand(0)->getType()) {
// trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Undef, SplatMask
Constant *NarrowUndef = UndefValue::get(Trunc.getType());
Value *NarrowOp = Builder.CreateTrunc(Shuf->getOperand(0), Trunc.getType());
- return new ShuffleVectorInst(NarrowOp, NarrowUndef, Shuf->getMask());
+ return new ShuffleVectorInst(NarrowOp, NarrowUndef, Shuf->getShuffleMask());
}
return nullptr;
return nullptr;
}
-static Instruction *narrowLoad(TruncInst &Trunc,
- InstCombiner::BuilderTy &Builder,
- const DataLayout &DL) {
- // Check the layout to ensure we are not creating an unsupported operation.
- // TODO: Create a GEP to offset the load?
- if (!DL.isLittleEndian())
- return nullptr;
- unsigned NarrowBitWidth = Trunc.getDestTy()->getPrimitiveSizeInBits();
- if (!DL.isLegalInteger(NarrowBitWidth))
- return nullptr;
-
- // Match a truncated load with no other uses.
- Value *X;
- if (!match(Trunc.getOperand(0), m_OneUse(m_Load(m_Value(X)))))
- return nullptr;
- LoadInst *WideLoad = cast<LoadInst>(Trunc.getOperand(0));
- if (!WideLoad->isSimple())
- return nullptr;
-
- // Don't narrow this load if we would lose information about the
- // dereferenceable range.
- bool CanBeNull;
- uint64_t DerefBits = X->getPointerDereferenceableBytes(DL, CanBeNull) * 8;
- if (DerefBits < WideLoad->getType()->getPrimitiveSizeInBits())
- return nullptr;
-
- // trunc (load X) --> load (bitcast X)
- PointerType *PtrTy = PointerType::get(Trunc.getDestTy(),
- WideLoad->getPointerAddressSpace());
- Value *Bitcast = Builder.CreatePointerCast(X, PtrTy);
- LoadInst *NarrowLoad = new LoadInst(Trunc.getDestTy(), Bitcast);
- NarrowLoad->setAlignment(WideLoad->getAlignment());
- copyMetadataForLoad(*NarrowLoad, *WideLoad);
- return NarrowLoad;
-}
-
-Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
- if (Instruction *Result = commonCastTransforms(CI))
+Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
+ if (Instruction *Result = commonCastTransforms(Trunc))
return Result;
- Value *Src = CI.getOperand(0);
- Type *DestTy = CI.getType(), *SrcTy = Src->getType();
+ Value *Src = Trunc.getOperand(0);
+ Type *DestTy = Trunc.getType(), *SrcTy = Src->getType();
+ unsigned DestWidth = DestTy->getScalarSizeInBits();
+ unsigned SrcWidth = SrcTy->getScalarSizeInBits();
// Attempt to truncate the entire input expression tree to the destination
// type. Only do this if the dest type is a simple type, don't convert the
// expression tree to something weird like i93 unless the source is also
// strange.
if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) &&
- canEvaluateTruncated(Src, DestTy, *this, &CI)) {
+ canEvaluateTruncated(Src, DestTy, *this, &Trunc)) {
// If this cast is a truncate, evaluating in a different type always
// eliminates the cast, so it is always a win.
LLVM_DEBUG(
dbgs() << "ICE: EvaluateInDifferentType converting expression type"
" to avoid cast: "
- << CI << '\n');
+ << Trunc << '\n');
Value *Res = EvaluateInDifferentType(Src, DestTy, false);
assert(Res->getType() == DestTy);
- return replaceInstUsesWith(CI, Res);
+ return replaceInstUsesWith(Trunc, Res);
+ }
+
+ // For integer types, check if we can shorten the entire input expression to
+ // DestWidth * 2, which won't allow removing the truncate, but reducing the
+ // width may enable further optimizations, e.g. allowing for larger
+ // vectorization factors.
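+ // For example (illustrative): a trunc i64 %x to i8 whose input expression
+ // is also computable in i16 is rewritten to do the arithmetic in i16
+ // followed by a cheaper trunc i16 -> i8.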
+ if (auto *DestITy = dyn_cast<IntegerType>(DestTy)) {
+ if (DestWidth * 2 < SrcWidth) {
+ auto *NewDestTy = DestITy->getExtendedType();
+ if (shouldChangeType(SrcTy, NewDestTy) &&
+ canEvaluateTruncated(Src, NewDestTy, *this, &Trunc)) {
+ LLVM_DEBUG(
+ dbgs() << "ICE: EvaluateInDifferentType converting expression type"
+ " to reduce the width of operand of"
+ << Trunc << '\n');
+ Value *Res = EvaluateInDifferentType(Src, NewDestTy, false);
+ return new TruncInst(Res, DestTy);
+ }
+ }
}
// Test if the trunc is the user of a select which is part of a
// minimum or maximum operation. If so, don't do any more simplification.
// Even simplifying demanded bits can break the canonical form of a
// min/max.
Value *LHS, *RHS;
- if (SelectInst *SI = dyn_cast<SelectInst>(CI.getOperand(0)))
- if (matchSelectPattern(SI, LHS, RHS).Flavor != SPF_UNKNOWN)
+ if (SelectInst *Sel = dyn_cast<SelectInst>(Src))
+ if (matchSelectPattern(Sel, LHS, RHS).Flavor != SPF_UNKNOWN)
return nullptr;
// See if we can simplify any instructions used by the input whose sole
// purpose is to compute bits we don't care about.
- if (SimplifyDemandedInstructionBits(CI))
- return &CI;
+ if (SimplifyDemandedInstructionBits(Trunc))
+ return &Trunc;
- if (DestTy->getScalarSizeInBits() == 1) {
- Value *Zero = Constant::getNullValue(Src->getType());
+ if (DestWidth == 1) {
+ Value *Zero = Constant::getNullValue(SrcTy);
if (DestTy->isIntegerTy()) {
// Canonicalize trunc x to i1 -> icmp ne (and x, 1), 0 (scalar only).
// TODO: We canonicalize to more instructions here because we are probably
// For vectors, we do not canonicalize all truncs to icmp, so optimize
// patterns that would be covered within visitICmpInst.
Value *X;
- const APInt *C;
- if (match(Src, m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) {
+ Constant *C;
+ if (match(Src, m_OneUse(m_LShr(m_Value(X), m_Constant(C))))) {
// trunc (lshr X, C) to i1 --> icmp ne (and X, C'), 0
- APInt MaskC = APInt(SrcTy->getScalarSizeInBits(), 1).shl(*C);
- Value *And = Builder.CreateAnd(X, ConstantInt::get(SrcTy, MaskC));
+ Constant *One = ConstantInt::get(SrcTy, APInt(SrcWidth, 1));
+ Constant *MaskC = ConstantExpr::getShl(One, C);
+ Value *And = Builder.CreateAnd(X, MaskC);
return new ICmpInst(ICmpInst::ICMP_NE, And, Zero);
}
- if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_APInt(C)),
+ if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_Constant(C)),
m_Deferred(X))))) {
// trunc (or (lshr X, C), X) to i1 --> icmp ne (and X, C'), 0
- APInt MaskC = APInt(SrcTy->getScalarSizeInBits(), 1).shl(*C) | 1;
- Value *And = Builder.CreateAnd(X, ConstantInt::get(SrcTy, MaskC));
+ Constant *One = ConstantInt::get(SrcTy, APInt(SrcWidth, 1));
+ Constant *MaskC = ConstantExpr::getShl(One, C);
+ MaskC = ConstantExpr::getOr(MaskC, One);
+ Value *And = Builder.CreateAnd(X, MaskC);
return new ICmpInst(ICmpInst::ICMP_NE, And, Zero);
}
}
- // FIXME: Maybe combine the next two transforms to handle the no cast case
- // more efficiently. Support vector types. Cleanup code by using m_OneUse.
-
- // Transform trunc(lshr (zext A), Cst) to eliminate one type conversion.
- Value *A = nullptr; ConstantInt *Cst = nullptr;
- if (Src->hasOneUse() &&
- match(Src, m_LShr(m_ZExt(m_Value(A)), m_ConstantInt(Cst)))) {
- // We have three types to worry about here, the type of A, the source of
- // the truncate (MidSize), and the destination of the truncate. We know that
- // ASize < MidSize and MidSize > ResultSize, but don't know the relation
- // between ASize and ResultSize.
- unsigned ASize = A->getType()->getPrimitiveSizeInBits();
-
- // If the shift amount is larger than the size of A, then the result is
- // known to be zero because all the input bits got shifted out.
- if (Cst->getZExtValue() >= ASize)
- return replaceInstUsesWith(CI, Constant::getNullValue(DestTy));
-
- // Since we're doing an lshr and a zero extend, and know that the shift
- // amount is smaller than ASize, it is always safe to do the shift in A's
- // type, then zero extend or truncate to the result.
- Value *Shift = Builder.CreateLShr(A, Cst->getZExtValue());
- Shift->takeName(Src);
- return CastInst::CreateIntegerCast(Shift, DestTy, false);
- }
-
- // FIXME: We should canonicalize to zext/trunc and remove this transform.
- // Transform trunc(lshr (sext A), Cst) to ashr A, Cst to eliminate type
- // conversion.
- // It works because bits coming from sign extension have the same value as
- // the sign bit of the original value; performing ashr instead of lshr
- // generates bits of the same value as the sign bit.
- if (Src->hasOneUse() &&
- match(Src, m_LShr(m_SExt(m_Value(A)), m_ConstantInt(Cst)))) {
- Value *SExt = cast<Instruction>(Src)->getOperand(0);
- const unsigned SExtSize = SExt->getType()->getPrimitiveSizeInBits();
- const unsigned ASize = A->getType()->getPrimitiveSizeInBits();
- const unsigned CISize = CI.getType()->getPrimitiveSizeInBits();
- const unsigned MaxAmt = SExtSize - std::max(CISize, ASize);
- unsigned ShiftAmt = Cst->getZExtValue();
-
- // This optimization can be only performed when zero bits generated by
- // the original lshr aren't pulled into the value after truncation, so we
- // can only shift by values no larger than the number of extension bits.
- // FIXME: Instead of bailing when the shift is too large, use and to clear
- // the extra bits.
- if (ShiftAmt <= MaxAmt) {
- if (CISize == ASize)
- return BinaryOperator::CreateAShr(A, ConstantInt::get(CI.getType(),
- std::min(ShiftAmt, ASize - 1)));
- if (SExt->hasOneUse()) {
- Value *Shift = Builder.CreateAShr(A, std::min(ShiftAmt, ASize - 1));
- Shift->takeName(Src);
- return CastInst::CreateIntegerCast(Shift, CI.getType(), true);
+ Value *A;
+ Constant *C;
+ if (match(Src, m_LShr(m_SExt(m_Value(A)), m_Constant(C)))) {
+ unsigned AWidth = A->getType()->getScalarSizeInBits();
+ unsigned MaxShiftAmt = SrcWidth - std::max(DestWidth, AWidth);
+ auto *OldSh = cast<Instruction>(Src);
+ bool IsExact = OldSh->isExact();
+
+ // If the shift is small enough, all zero bits created by the shift are
+ // removed by the trunc.
+ if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
+ APInt(SrcWidth, MaxShiftAmt)))) {
+ // trunc (lshr (sext A), C) --> ashr A, C
+ if (A->getType() == DestTy) {
+ Constant *MaxAmt = ConstantInt::get(SrcTy, DestWidth - 1, false);
+ Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt);
+ ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType());
+ ShAmt = Constant::mergeUndefsWith(ShAmt, C);
+ return IsExact ? BinaryOperator::CreateExactAShr(A, ShAmt)
+ : BinaryOperator::CreateAShr(A, ShAmt);
}
+ // The types are mismatched, so create a cast after shifting:
+ // trunc (lshr (sext A), C) --> sext/trunc (ashr A, C)
+ if (Src->hasOneUse()) {
+ Constant *MaxAmt = ConstantInt::get(SrcTy, AWidth - 1, false);
+ Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt);
+ ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType());
+ Value *Shift = Builder.CreateAShr(A, ShAmt, "", IsExact);
+ return CastInst::CreateIntegerCast(Shift, DestTy, true);
+ }
+ }
+ // TODO: Mask high bits with 'and'.
+ }
+
+ // trunc (*shr (trunc A), C) --> trunc(*shr A, C)
+ if (match(Src, m_OneUse(m_Shr(m_Trunc(m_Value(A)), m_Constant(C))))) {
+ unsigned MaxShiftAmt = SrcWidth - DestWidth;
+
+ // If the shift is small enough, all zero/sign bits created by the shift are
+ // removed by the trunc.
+ if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
+ APInt(SrcWidth, MaxShiftAmt)))) {
+ auto *OldShift = cast<Instruction>(Src);
+ bool IsExact = OldShift->isExact();
+ auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true);
+ ShAmt = Constant::mergeUndefsWith(ShAmt, C);
+ Value *Shift =
+ OldShift->getOpcode() == Instruction::AShr
+ ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact)
+ : Builder.CreateLShr(A, ShAmt, OldShift->getName(), IsExact);
+ return CastInst::CreateTruncOrBitCast(Shift, DestTy);
}
}
- if (Instruction *I = narrowBinOp(CI))
+ if (Instruction *I = narrowBinOp(Trunc))
return I;
- if (Instruction *I = shrinkSplatShuffle(CI, Builder))
+ if (Instruction *I = shrinkSplatShuffle(Trunc, Builder))
return I;
- if (Instruction *I = shrinkInsertElt(CI, Builder))
+ if (Instruction *I = shrinkInsertElt(Trunc, Builder))
return I;
- if (Src->hasOneUse() && isa<IntegerType>(SrcTy) &&
- shouldChangeType(SrcTy, DestTy)) {
+ if (Src->hasOneUse() &&
+ (isa<VectorType>(SrcTy) || shouldChangeType(SrcTy, DestTy))) {
// Transform "trunc (shl X, cst)" -> "shl (trunc X), cst" so long as the
// dest type is native and cst < dest size.
- if (match(Src, m_Shl(m_Value(A), m_ConstantInt(Cst))) &&
+ if (match(Src, m_Shl(m_Value(A), m_Constant(C))) &&
!match(A, m_Shr(m_Value(), m_Constant()))) {
// Skip shifts of shift by constants. It undoes a combine in
// FoldShiftByConstant and is the extend in reg pattern.
- const unsigned DestSize = DestTy->getScalarSizeInBits();
- if (Cst->getValue().ult(DestSize)) {
+ APInt Threshold = APInt(C->getType()->getScalarSizeInBits(), DestWidth);
+ if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold))) {
Value *NewTrunc = Builder.CreateTrunc(A, DestTy, A->getName() + ".tr");
-
- return BinaryOperator::Create(
- Instruction::Shl, NewTrunc,
- ConstantInt::get(DestTy, Cst->getValue().trunc(DestSize)));
+ return BinaryOperator::Create(Instruction::Shl, NewTrunc,
+ ConstantExpr::getTrunc(C, DestTy));
}
}
}
- if (Instruction *I = foldVecTruncToExtElt(CI, *this))
+ if (Instruction *I = foldVecTruncToExtElt(Trunc, *this))
return I;
- if (Instruction *NewLoad = narrowLoad(CI, Builder, DL))
- return NewLoad;
+ // Whenever an element is extracted from a vector, and then truncated,
+ // canonicalize by converting it to a bitcast followed by an
+ // extractelement.
+ //
+ // Example (little endian):
+ // trunc (extractelement <4 x i64> %X, 0) to i32
+ // --->
+ // extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0
+ Value *VecOp;
+ ConstantInt *Cst;
+ if (match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst))))) {
+ auto *VecOpTy = cast<VectorType>(VecOp->getType());
+ auto VecElts = VecOpTy->getElementCount();
+
+ // A badly fitting destination size would result in an invalid cast.
+ if (SrcWidth % DestWidth == 0) {
+ uint64_t TruncRatio = SrcWidth / DestWidth;
+ uint64_t BitCastNumElts = VecElts.getKnownMinValue() * TruncRatio;
+ uint64_t VecOpIdx = Cst->getZExtValue();
+ uint64_t NewIdx = DL.isBigEndian() ? (VecOpIdx + 1) * TruncRatio - 1
+ : VecOpIdx * TruncRatio;
+ assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() &&
+ "overflow 32-bits");
+
+ auto *BitCastTo =
+ VectorType::get(DestTy, BitCastNumElts, VecElts.isScalable());
+ Value *BitCast = Builder.CreateBitCast(VecOp, BitCastTo);
+ return ExtractElementInst::Create(BitCast, Builder.getInt32(NewIdx));
+ }
+ }
return nullptr;
}
-Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI,
- bool DoTransform) {
+Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext,
+ bool DoTransform) {
// If we are just checking for a icmp eq of a single bit and zext'ing it
// to an integer, then shift the bit to the appropriate place and then
// cast to integer to avoid the comparison.
const APInt *Op1CV;
- if (match(ICI->getOperand(1), m_APInt(Op1CV))) {
+ if (match(Cmp->getOperand(1), m_APInt(Op1CV))) {
// zext (x <s 0) to i32 --> x>>u31 true if signbit set.
// zext (x >s -1) to i32 --> (x>>u31)^1 true if signbit clear.
- if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isNullValue()) ||
- (ICI->getPredicate() == ICmpInst::ICMP_SGT && Op1CV->isAllOnesValue())) {
- if (!DoTransform) return ICI;
+ if ((Cmp->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isNullValue()) ||
+ (Cmp->getPredicate() == ICmpInst::ICMP_SGT && Op1CV->isAllOnesValue())) {
+ if (!DoTransform) return Cmp;
- Value *In = ICI->getOperand(0);
+ Value *In = Cmp->getOperand(0);
Value *Sh = ConstantInt::get(In->getType(),
In->getType()->getScalarSizeInBits() - 1);
In = Builder.CreateLShr(In, Sh, In->getName() + ".lobit");
- if (In->getType() != CI.getType())
- In = Builder.CreateIntCast(In, CI.getType(), false /*ZExt*/);
+ if (In->getType() != Zext.getType())
+ In = Builder.CreateIntCast(In, Zext.getType(), false /*ZExt*/);
- if (ICI->getPredicate() == ICmpInst::ICMP_SGT) {
+ if (Cmp->getPredicate() == ICmpInst::ICMP_SGT) {
Constant *One = ConstantInt::get(In->getType(), 1);
In = Builder.CreateXor(In, One, In->getName() + ".not");
}
- return replaceInstUsesWith(CI, In);
+ return replaceInstUsesWith(Zext, In);
}
// zext (X == 0) to i32 --> X^1 iff X has only the low bit set.
// zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set.
if ((Op1CV->isNullValue() || Op1CV->isPowerOf2()) &&
// This only works for EQ and NE
- ICI->isEquality()) {
+ Cmp->isEquality()) {
// If Op1C is some other power of two, convert:
- KnownBits Known = computeKnownBits(ICI->getOperand(0), 0, &CI);
+ KnownBits Known = computeKnownBits(Cmp->getOperand(0), 0, &Zext);
APInt KnownZeroMask(~Known.Zero);
if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1?
- if (!DoTransform) return ICI;
+ if (!DoTransform) return Cmp;
- bool isNE = ICI->getPredicate() == ICmpInst::ICMP_NE;
+ bool isNE = Cmp->getPredicate() == ICmpInst::ICMP_NE;
if (!Op1CV->isNullValue() && (*Op1CV != KnownZeroMask)) {
// (X&4) == 2 --> false
// (X&4) != 2 --> true
- Constant *Res = ConstantInt::get(CI.getType(), isNE);
- return replaceInstUsesWith(CI, Res);
+ Constant *Res = ConstantInt::get(Zext.getType(), isNE);
+ return replaceInstUsesWith(Zext, Res);
}
uint32_t ShAmt = KnownZeroMask.logBase2();
- Value *In = ICI->getOperand(0);
+ Value *In = Cmp->getOperand(0);
if (ShAmt) {
// Perform a logical shr by shiftamt.
// Insert the shift to put the result in the low bit.
In = Builder.CreateXor(In, One);
}
- if (CI.getType() == In->getType())
- return replaceInstUsesWith(CI, In);
+ if (Zext.getType() == In->getType())
+ return replaceInstUsesWith(Zext, In);
- Value *IntCast = Builder.CreateIntCast(In, CI.getType(), false);
- return replaceInstUsesWith(CI, IntCast);
+ Value *IntCast = Builder.CreateIntCast(In, Zext.getType(), false);
+ return replaceInstUsesWith(Zext, IntCast);
}
}
}
// icmp ne A, B is equal to xor A, B when A and B only really have one bit.
// It is also profitable to transform icmp eq into not(xor(A, B)) because that
// may lead to additional simplifications.
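// For example (illustrative): if i32 values A and B are known to agree in
// every bit except possibly bit 2, then
//   zext (icmp ne A, B) to i32 --> lshr (xor A, B), 2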
- if (ICI->isEquality() && CI.getType() == ICI->getOperand(0)->getType()) {
- if (IntegerType *ITy = dyn_cast<IntegerType>(CI.getType())) {
- Value *LHS = ICI->getOperand(0);
- Value *RHS = ICI->getOperand(1);
+ if (Cmp->isEquality() && Zext.getType() == Cmp->getOperand(0)->getType()) {
+ if (IntegerType *ITy = dyn_cast<IntegerType>(Zext.getType())) {
+ Value *LHS = Cmp->getOperand(0);
+ Value *RHS = Cmp->getOperand(1);
- KnownBits KnownLHS = computeKnownBits(LHS, 0, &CI);
- KnownBits KnownRHS = computeKnownBits(RHS, 0, &CI);
+ KnownBits KnownLHS = computeKnownBits(LHS, 0, &Zext);
+ KnownBits KnownRHS = computeKnownBits(RHS, 0, &Zext);
if (KnownLHS.Zero == KnownRHS.Zero && KnownLHS.One == KnownRHS.One) {
APInt KnownBits = KnownLHS.Zero | KnownLHS.One;
APInt UnknownBit = ~KnownBits;
if (UnknownBit.countPopulation() == 1) {
- if (!DoTransform) return ICI;
+ if (!DoTransform) return Cmp;
Value *Result = Builder.CreateXor(LHS, RHS);
Result = Builder.CreateLShr(
Result, ConstantInt::get(ITy, UnknownBit.countTrailingZeros()));
- if (ICI->getPredicate() == ICmpInst::ICMP_EQ)
+ if (Cmp->getPredicate() == ICmpInst::ICMP_EQ)
Result = Builder.CreateXor(Result, ConstantInt::get(ITy, 1));
- Result->takeName(ICI);
- return replaceInstUsesWith(CI, Result);
+ Result->takeName(Cmp);
+ return replaceInstUsesWith(Zext, Result);
}
}
}
///
/// This function works on both vectors and scalars.
static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
- InstCombiner &IC, Instruction *CxtI) {
+ InstCombinerImpl &IC, Instruction *CxtI) {
BitsToClear = 0;
if (canAlwaysEvaluateInType(V, Ty))
return true;
}
}
-Instruction *InstCombiner::visitZExt(ZExtInst &CI) {
+Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) {
// If this zero extend is only used by a truncate, let the truncate be
// eliminated before we try to optimize this zext.
if (CI.hasOneUse() && isa<TruncInst>(CI.user_back()))
}
}
- if (ICmpInst *ICI = dyn_cast<ICmpInst>(Src))
- return transformZExtICmp(ICI, CI);
+ if (ICmpInst *Cmp = dyn_cast<ICmpInst>(Src))
+ return transformZExtICmp(Cmp, CI);
BinaryOperator *SrcI = dyn_cast<BinaryOperator>(Src);
if (SrcI && SrcI->getOpcode() == Instruction::Or) {
ICmpInst *LHS = dyn_cast<ICmpInst>(SrcI->getOperand(0));
ICmpInst *RHS = dyn_cast<ICmpInst>(SrcI->getOperand(1));
if (LHS && RHS && LHS->hasOneUse() && RHS->hasOneUse() &&
+ LHS->getOperand(0)->getType() == RHS->getOperand(0)->getType() &&
(transformZExtICmp(LHS, CI, false) ||
transformZExtICmp(RHS, CI, false))) {
// zext (or icmp, icmp) -> or (zext icmp), (zext icmp)
Value *LCast = Builder.CreateZExt(LHS, CI.getType(), LHS->getName());
Value *RCast = Builder.CreateZExt(RHS, CI.getType(), RHS->getName());
- BinaryOperator *Or = BinaryOperator::Create(Instruction::Or, LCast, RCast);
+ Value *Or = Builder.CreateOr(LCast, RCast, CI.getName());
+ if (auto *OrInst = dyn_cast<Instruction>(Or))
+ Builder.SetInsertPoint(OrInst);
// Perform the elimination.
if (auto *LZExt = dyn_cast<ZExtInst>(LCast))
if (auto *RZExt = dyn_cast<ZExtInst>(RCast))
transformZExtICmp(RHS, *RZExt);
- return Or;
+ return replaceInstUsesWith(CI, Or);
}
}
}
/// Transform (sext icmp) to bitwise / integer operations to eliminate the icmp.
-Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) {
+Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *ICI,
+ Instruction &CI) {
Value *Op0 = ICI->getOperand(0), *Op1 = ICI->getOperand(1);
ICmpInst::Predicate Pred = ICI->getPredicate();
return false;
}
-Instruction *InstCombiner::visitSExt(SExtInst &CI) {
+Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) {
// If this sign extend is only used by a truncate, let the truncate be
// eliminated before we try to optimize this sext.
if (CI.hasOneUse() && isa<TruncInst>(CI.user_back()))
// for a truncate. If the source and dest are the same type, eliminate the
// trunc and extend and just do shifts. For example, turn:
// %a = trunc i32 %i to i8
- // %b = shl i8 %a, 6
- // %c = ashr i8 %b, 6
+ // %b = shl i8 %a, C
+ // %c = ashr i8 %b, C
// %d = sext i8 %c to i32
// into:
- // %a = shl i32 %i, 30
- // %d = ashr i32 %a, 30
+ // %a = shl i32 %i, 32-(8-C)
+ // %d = ashr i32 %a, 32-(8-C)
Value *A = nullptr;
// TODO: Eventually this could be subsumed by EvaluateInDifferentType.
- ConstantInt *BA = nullptr, *CA = nullptr;
- if (match(Src, m_AShr(m_Shl(m_Trunc(m_Value(A)), m_ConstantInt(BA)),
- m_ConstantInt(CA))) &&
- BA == CA && A->getType() == CI.getType()) {
- unsigned MidSize = Src->getType()->getScalarSizeInBits();
- unsigned SrcDstSize = CI.getType()->getScalarSizeInBits();
- unsigned ShAmt = CA->getZExtValue()+SrcDstSize-MidSize;
- Constant *ShAmtV = ConstantInt::get(CI.getType(), ShAmt);
- A = Builder.CreateShl(A, ShAmtV, CI.getName());
- return BinaryOperator::CreateAShr(A, ShAmtV);
+ Constant *BA = nullptr, *CA = nullptr;
+ if (match(Src, m_AShr(m_Shl(m_Trunc(m_Value(A)), m_Constant(BA)),
+ m_Constant(CA))) &&
+ BA->isElementWiseEqual(CA) && A->getType() == DestTy) {
+ Constant *WideCurrShAmt = ConstantExpr::getSExt(CA, DestTy);
+ Constant *NumLowbitsLeft = ConstantExpr::getSub(
+ ConstantInt::get(DestTy, SrcTy->getScalarSizeInBits()), WideCurrShAmt);
+ Constant *NewShAmt = ConstantExpr::getSub(
+ ConstantInt::get(DestTy, DestTy->getScalarSizeInBits()),
+ NumLowbitsLeft);
+ NewShAmt =
+ Constant::mergeUndefsWith(Constant::mergeUndefsWith(NewShAmt, BA), CA);
+ A = Builder.CreateShl(A, NewShAmt, CI.getName());
+ return BinaryOperator::CreateAShr(A, NewShAmt);
}
return nullptr;
}
-
/// Return true if the specified floating-point constant fits in the
/// specified FP type without changing its value.
static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
// TODO: Make these support undef elements.
static Type *shrinkFPConstantVector(Value *V) {
auto *CV = dyn_cast<Constant>(V);
- if (!CV || !CV->getType()->isVectorTy())
+ auto *CVVTy = dyn_cast<VectorType>(V->getType());
+ if (!CV || !CVVTy)
return nullptr;
Type *MinType = nullptr;
- unsigned NumElts = CV->getType()->getVectorNumElements();
+ unsigned NumElts = cast<FixedVectorType>(CVVTy)->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i));
if (!CFP)
}
// Make a vector type from the minimal type.
- return VectorType::get(MinType, NumElts);
+ return FixedVectorType::get(MinType, NumElts);
}
/// Find the minimum FP type we can safely truncate to.
return V->getType();
}
-Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) {
+/// Return true if the cast from integer to FP can be proven to be exact for all
+/// possible inputs (the conversion does not lose any precision).
+static bool isKnownExactCastIntToFP(CastInst &I) {
+ CastInst::CastOps Opcode = I.getOpcode();
+ assert((Opcode == CastInst::SIToFP || Opcode == CastInst::UIToFP) &&
+ "Unexpected cast");
+ Value *Src = I.getOperand(0);
+ Type *SrcTy = Src->getType();
+ Type *FPTy = I.getType();
+ bool IsSigned = Opcode == Instruction::SIToFP;
+ int SrcSize = (int)SrcTy->getScalarSizeInBits() - IsSigned;
+
+ // Easy case - if the source integer type has fewer bits than the FP mantissa,
+ // then the cast must be exact.
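+ // For example (illustrative): sitofp i16 -> float is always exact, since
+ // the 15 magnitude bits of an i16 fit in float's 24-bit significand.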
+ int DestNumSigBits = FPTy->getFPMantissaWidth();
+ if (SrcSize <= DestNumSigBits)
+ return true;
+
+ // Cast from FP to integer and back to FP is independent of the intermediate
+ // integer width because of poison on overflow.
+ Value *F;
+ if (match(Src, m_FPToSI(m_Value(F))) || match(Src, m_FPToUI(m_Value(F)))) {
+ // If this is uitofp (fptosi F), the source needs an extra bit to avoid
+ // potential rounding of negative FP input values.
+ int SrcNumSigBits = F->getType()->getFPMantissaWidth();
+ if (!IsSigned && match(Src, m_FPToSI(m_Value())))
+ SrcNumSigBits++;
+
+ // [su]itofp (fpto[su]i F) --> exact if the source type has less or equal
+ // significant bits than the destination (and make sure neither type is
+ // weird -- ppc_fp128).
+ if (SrcNumSigBits > 0 && DestNumSigBits > 0 &&
+ SrcNumSigBits <= DestNumSigBits)
+ return true;
+ }
+
+ // TODO:
+ // Try harder to find if the source integer type has less significant bits.
+ // For example, compute number of sign bits or compute low bit mask.
+ return false;
+}
+
+Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) {
if (Instruction *I = commonCastTransforms(FPT))
return I;
// what we can and cannot do safely varies from operation to operation, and
// is explained below in the various case statements.
Type *Ty = FPT.getType();
- BinaryOperator *OpI = dyn_cast<BinaryOperator>(FPT.getOperand(0));
- if (OpI && OpI->hasOneUse()) {
- Type *LHSMinType = getMinimumFPType(OpI->getOperand(0));
- Type *RHSMinType = getMinimumFPType(OpI->getOperand(1));
- unsigned OpWidth = OpI->getType()->getFPMantissaWidth();
+ auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0));
+ if (BO && BO->hasOneUse()) {
+ Type *LHSMinType = getMinimumFPType(BO->getOperand(0));
+ Type *RHSMinType = getMinimumFPType(BO->getOperand(1));
+ unsigned OpWidth = BO->getType()->getFPMantissaWidth();
unsigned LHSWidth = LHSMinType->getFPMantissaWidth();
unsigned RHSWidth = RHSMinType->getFPMantissaWidth();
unsigned SrcWidth = std::max(LHSWidth, RHSWidth);
unsigned DstWidth = Ty->getFPMantissaWidth();
- switch (OpI->getOpcode()) {
+ switch (BO->getOpcode()) {
default: break;
case Instruction::FAdd:
case Instruction::FSub:
// could be tightened for those cases, but they are rare (the main
// case of interest here is (float)((double)float + float)).
if (OpWidth >= 2*DstWidth+1 && DstWidth >= SrcWidth) {
- Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), Ty);
- Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), Ty);
- Instruction *RI = BinaryOperator::Create(OpI->getOpcode(), LHS, RHS);
- RI->copyFastMathFlags(OpI);
+ Value *LHS = Builder.CreateFPTrunc(BO->getOperand(0), Ty);
+ Value *RHS = Builder.CreateFPTrunc(BO->getOperand(1), Ty);
+ Instruction *RI = BinaryOperator::Create(BO->getOpcode(), LHS, RHS);
+ RI->copyFastMathFlags(BO);
return RI;
}
break;
// rounding can possibly occur; we can safely perform the operation
// in the destination format if it can represent both sources.
if (OpWidth >= LHSWidth + RHSWidth && DstWidth >= SrcWidth) {
- Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), Ty);
- Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), Ty);
- return BinaryOperator::CreateFMulFMF(LHS, RHS, OpI);
+ Value *LHS = Builder.CreateFPTrunc(BO->getOperand(0), Ty);
+ Value *RHS = Builder.CreateFPTrunc(BO->getOperand(1), Ty);
+ return BinaryOperator::CreateFMulFMF(LHS, RHS, BO);
}
break;
case Instruction::FDiv:
// condition used here is a good conservative first pass.
// TODO: Tighten bound via rigorous analysis of the unbalanced case.
if (OpWidth >= 2*DstWidth && DstWidth >= SrcWidth) {
- Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), Ty);
- Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), Ty);
- return BinaryOperator::CreateFDivFMF(LHS, RHS, OpI);
+ Value *LHS = Builder.CreateFPTrunc(BO->getOperand(0), Ty);
+ Value *RHS = Builder.CreateFPTrunc(BO->getOperand(1), Ty);
+ return BinaryOperator::CreateFDivFMF(LHS, RHS, BO);
}
break;
case Instruction::FRem: {
break;
Value *LHS, *RHS;
if (LHSWidth == SrcWidth) {
- LHS = Builder.CreateFPTrunc(OpI->getOperand(0), LHSMinType);
- RHS = Builder.CreateFPTrunc(OpI->getOperand(1), LHSMinType);
+ LHS = Builder.CreateFPTrunc(BO->getOperand(0), LHSMinType);
+ RHS = Builder.CreateFPTrunc(BO->getOperand(1), LHSMinType);
} else {
- LHS = Builder.CreateFPTrunc(OpI->getOperand(0), RHSMinType);
- RHS = Builder.CreateFPTrunc(OpI->getOperand(1), RHSMinType);
+ LHS = Builder.CreateFPTrunc(BO->getOperand(0), RHSMinType);
+ RHS = Builder.CreateFPTrunc(BO->getOperand(1), RHSMinType);
}
- Value *ExactResult = Builder.CreateFRemFMF(LHS, RHS, OpI);
+ Value *ExactResult = Builder.CreateFRemFMF(LHS, RHS, BO);
return CastInst::CreateFPCast(ExactResult, Ty);
}
}
Value *X;
Instruction *Op = dyn_cast<Instruction>(FPT.getOperand(0));
if (Op && Op->hasOneUse()) {
+ // FIXME: The FMF should propagate from the fptrunc, not the source op.
+ IRBuilder<>::FastMathFlagGuard FMFG(Builder);
+ if (isa<FPMathOperator>(Op))
+ Builder.setFastMathFlags(Op->getFastMathFlags());
+
if (match(Op, m_FNeg(m_Value(X)))) {
Value *InnerTrunc = Builder.CreateFPTrunc(X, Ty);
- // FIXME: Once we're sure that unary FNeg optimizations are on par with
- // binary FNeg, this should always return a unary operator.
- if (isa<BinaryOperator>(Op))
- return BinaryOperator::CreateFNegFMF(InnerTrunc, Op);
return UnaryOperator::CreateFNegFMF(InnerTrunc, Op);
}
+
+ // If we are truncating a select that has an extended operand, we can
+ // narrow the other operand and do the select as a narrow op.
+ Value *Cond, *X, *Y;
+ if (match(Op, m_Select(m_Value(Cond), m_FPExt(m_Value(X)), m_Value(Y))) &&
+ X->getType() == Ty) {
+ // fptrunc (select Cond, (fpext X), Y) --> select Cond, X, (fptrunc Y)
+ Value *NarrowY = Builder.CreateFPTrunc(Y, Ty);
+ Value *Sel = Builder.CreateSelect(Cond, X, NarrowY, "narrow.sel", Op);
+ return replaceInstUsesWith(FPT, Sel);
+ }
+ if (match(Op, m_Select(m_Value(Cond), m_Value(Y), m_FPExt(m_Value(X)))) &&
+ X->getType() == Ty) {
+ // fptrunc (select Cond, Y, (fpext X)) --> select Cond, (fptrunc Y), X
+ Value *NarrowY = Builder.CreateFPTrunc(Y, Ty);
+ Value *Sel = Builder.CreateSelect(Cond, NarrowY, X, "narrow.sel", Op);
+ return replaceInstUsesWith(FPT, Sel);
+ }
}
if (auto *II = dyn_cast<IntrinsicInst>(FPT.getOperand(0))) {
case Intrinsic::nearbyint:
case Intrinsic::rint:
case Intrinsic::round:
+ case Intrinsic::roundeven:
case Intrinsic::trunc: {
Value *Src = II->getArgOperand(0);
if (!Src->hasOneUse())
if (Instruction *I = shrinkInsertElt(FPT, Builder))
return I;
+ Value *Src = FPT.getOperand(0);
+ if (isa<SIToFPInst>(Src) || isa<UIToFPInst>(Src)) {
+ auto *FPCast = cast<CastInst>(Src);
+ if (isKnownExactCastIntToFP(*FPCast))
+ return CastInst::Create(FPCast->getOpcode(), FPCast->getOperand(0), Ty);
+ }
+
return nullptr;
}
-Instruction *InstCombiner::visitFPExt(CastInst &CI) {
- return commonCastTransforms(CI);
+Instruction *InstCombinerImpl::visitFPExt(CastInst &FPExt) {
+ // If the source operand is a cast from integer to FP and known exact, then
+ // cast the integer operand directly to the destination type.
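+ // For example (illustrative):
+ //   fpext (sitofp i16 %x to float) to double --> sitofp i16 %x to double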
+ Type *Ty = FPExt.getType();
+ Value *Src = FPExt.getOperand(0);
+ if (isa<SIToFPInst>(Src) || isa<UIToFPInst>(Src)) {
+ auto *FPCast = cast<CastInst>(Src);
+ if (isKnownExactCastIntToFP(*FPCast))
+ return CastInst::Create(FPCast->getOpcode(), FPCast->getOperand(0), Ty);
+ }
+
+ return commonCastTransforms(FPExt);
}
-// fpto{s/u}i({u/s}itofp(X)) --> X or zext(X) or sext(X) or trunc(X)
-// This is safe if the intermediate type has enough bits in its mantissa to
-// accurately represent all values of X. For example, this won't work with
-// i64 -> float -> i64.
-Instruction *InstCombiner::FoldItoFPtoI(Instruction &FI) {
+/// fpto{s/u}i({u/s}itofp(X)) --> X or zext(X) or sext(X) or trunc(X)
+/// This is safe if the intermediate type has enough bits in its mantissa to
+/// accurately represent all values of X. For example, this won't work with
+/// i64 -> float -> i64.
+Instruction *InstCombinerImpl::foldItoFPtoI(CastInst &FI) {
if (!isa<UIToFPInst>(FI.getOperand(0)) && !isa<SIToFPInst>(FI.getOperand(0)))
return nullptr;
- Instruction *OpI = cast<Instruction>(FI.getOperand(0));
- Value *SrcI = OpI->getOperand(0);
- Type *FITy = FI.getType();
- Type *OpITy = OpI->getType();
- Type *SrcTy = SrcI->getType();
- bool IsInputSigned = isa<SIToFPInst>(OpI);
+ auto *OpI = cast<CastInst>(FI.getOperand(0));
+ Value *X = OpI->getOperand(0);
+ Type *XType = X->getType();
+ Type *DestType = FI.getType();
bool IsOutputSigned = isa<FPToSIInst>(FI);
- // We can safely assume the conversion won't overflow the output range,
- // because (for example) (uint8_t)18293.f is undefined behavior.
-
// Since we can assume the conversion won't overflow, our decision as to
// whether the input will fit in the float should depend on the minimum
// of the input range and output range.
// This means this is also safe for a signed input and unsigned output, since
// a negative input would lead to undefined behavior.
- int InputSize = (int)SrcTy->getScalarSizeInBits() - IsInputSigned;
- int OutputSize = (int)FITy->getScalarSizeInBits() - IsOutputSigned;
- int ActualSize = std::min(InputSize, OutputSize);
-
- if (ActualSize <= OpITy->getFPMantissaWidth()) {
- if (FITy->getScalarSizeInBits() > SrcTy->getScalarSizeInBits()) {
- if (IsInputSigned && IsOutputSigned)
- return new SExtInst(SrcI, FITy);
- return new ZExtInst(SrcI, FITy);
- }
- if (FITy->getScalarSizeInBits() < SrcTy->getScalarSizeInBits())
- return new TruncInst(SrcI, FITy);
- if (SrcTy == FITy)
- return replaceInstUsesWith(FI, SrcI);
- return new BitCastInst(SrcI, FITy);
+ if (!isKnownExactCastIntToFP(*OpI)) {
+ // The first cast may not round exactly based on the source integer width
+ // and FP width, but the overflow UB rules can still allow this to fold.
+ // If the destination type is narrow, that means the intermediate FP value
+ // must be large enough to hold the source value exactly.
+ // For example, (uint8_t)((float)(uint32_t)16777217) is undefined behavior.
+ int OutputSize = (int)DestType->getScalarSizeInBits() - IsOutputSigned;
+ if (OutputSize > OpI->getType()->getFPMantissaWidth())
+ return nullptr;
}
- return nullptr;
-}
-Instruction *InstCombiner::visitFPToUI(FPToUIInst &FI) {
- Instruction *OpI = dyn_cast<Instruction>(FI.getOperand(0));
- if (!OpI)
- return commonCastTransforms(FI);
+ if (DestType->getScalarSizeInBits() > XType->getScalarSizeInBits()) {
+ bool IsInputSigned = isa<SIToFPInst>(OpI);
+ if (IsInputSigned && IsOutputSigned)
+ return new SExtInst(X, DestType);
+ return new ZExtInst(X, DestType);
+ }
+ if (DestType->getScalarSizeInBits() < XType->getScalarSizeInBits())
+ return new TruncInst(X, DestType);
+
+ assert(XType == DestType && "Unexpected types for int to FP to int casts");
+ return replaceInstUsesWith(FI, X);
+}
- if (Instruction *I = FoldItoFPtoI(FI))
+Instruction *InstCombinerImpl::visitFPToUI(FPToUIInst &FI) {
+ if (Instruction *I = foldItoFPtoI(FI))
return I;
return commonCastTransforms(FI);
}
-Instruction *InstCombiner::visitFPToSI(FPToSIInst &FI) {
- Instruction *OpI = dyn_cast<Instruction>(FI.getOperand(0));
- if (!OpI)
- return commonCastTransforms(FI);
-
- if (Instruction *I = FoldItoFPtoI(FI))
+Instruction *InstCombinerImpl::visitFPToSI(FPToSIInst &FI) {
+ if (Instruction *I = foldItoFPtoI(FI))
return I;
return commonCastTransforms(FI);
}
-Instruction *InstCombiner::visitUIToFP(CastInst &CI) {
+Instruction *InstCombinerImpl::visitUIToFP(CastInst &CI) {
return commonCastTransforms(CI);
}
-Instruction *InstCombiner::visitSIToFP(CastInst &CI) {
+Instruction *InstCombinerImpl::visitSIToFP(CastInst &CI) {
return commonCastTransforms(CI);
}
-Instruction *InstCombiner::visitIntToPtr(IntToPtrInst &CI) {
+Instruction *InstCombinerImpl::visitIntToPtr(IntToPtrInst &CI) {
// If the source integer type is not the intptr_t type for this target, do a
// trunc or zext to the intptr_t type, then inttoptr of it. This allows the
// cast to be exposed to other transforms.
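// For example (illustrative, assuming 64-bit pointers):
//   inttoptr i32 %x to i8* --> inttoptr (zext i32 %x to i64) to i8*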
if (CI.getOperand(0)->getType()->getScalarSizeInBits() !=
DL.getPointerSizeInBits(AS)) {
Type *Ty = DL.getIntPtrType(CI.getContext(), AS);
- if (CI.getType()->isVectorTy()) // Handle vectors of pointers.
- Ty = VectorType::get(Ty, CI.getType()->getVectorNumElements());
+ // Handle vectors of pointers.
+ if (auto *CIVTy = dyn_cast<VectorType>(CI.getType()))
+ Ty = VectorType::get(Ty, CIVTy->getElementCount());
Value *P = Builder.CreateZExtOrTrunc(CI.getOperand(0), Ty);
return new IntToPtrInst(P, CI.getType());
}
/// Implement the transforms for cast of pointer (bitcast/ptrtoint)
-Instruction *InstCombiner::commonPointerCastTransforms(CastInst &CI) {
+Instruction *InstCombinerImpl::commonPointerCastTransforms(CastInst &CI) {
Value *Src = CI.getOperand(0);
if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Src)) {
// Changing the cast operand is usually not a good idea but it is safe
// here because the pointer operand is being replaced with another
// pointer operand so the opcode doesn't need to change.
- Worklist.Add(GEP);
- CI.setOperand(0, GEP->getOperand(0));
- return &CI;
+ return replaceOperand(CI, 0, GEP->getOperand(0));
}
}
return commonCastTransforms(CI);
}
-Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) {
+Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) {
// If the destination integer type is not the intptr_t type for this target,
// do a ptrtoint to intptr_t then do a trunc or zext. This allows the cast
// to be exposed to other transforms.
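// For example (illustrative, assuming 64-bit pointers):
//   ptrtoint i8* %p to i32 --> trunc (ptrtoint i8* %p to i64) to i32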
-
+ Value *SrcOp = CI.getPointerOperand();
Type *Ty = CI.getType();
unsigned AS = CI.getPointerAddressSpace();
+ unsigned TySize = Ty->getScalarSizeInBits();
+ unsigned PtrSize = DL.getPointerSizeInBits(AS);
+ if (TySize != PtrSize) {
+ Type *IntPtrTy = DL.getIntPtrType(CI.getContext(), AS);
+ // Handle vectors of pointers.
+ if (auto *VecTy = dyn_cast<VectorType>(Ty))
+ IntPtrTy = VectorType::get(IntPtrTy, VecTy->getElementCount());
- if (Ty->getScalarSizeInBits() == DL.getIndexSizeInBits(AS))
- return commonPointerCastTransforms(CI);
+ Value *P = Builder.CreatePtrToInt(SrcOp, IntPtrTy);
+ return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false);
+ }
- Type *PtrTy = DL.getIntPtrType(CI.getContext(), AS);
- if (Ty->isVectorTy()) // Handle vectors of pointers.
- PtrTy = VectorType::get(PtrTy, Ty->getVectorNumElements());
+ Value *Vec, *Scalar, *Index;
+ if (match(SrcOp, m_OneUse(m_InsertElt(m_IntToPtr(m_Value(Vec)),
+ m_Value(Scalar), m_Value(Index)))) &&
+ Vec->getType() == Ty) {
+ assert(Vec->getType()->getScalarSizeInBits() == PtrSize && "Wrong type");
+ // Convert the scalar to int followed by insert to eliminate one cast:
+ // p2i (ins (i2p Vec), Scalar, Index) --> ins Vec, (p2i Scalar), Index
+ Value *NewCast = Builder.CreatePtrToInt(Scalar, Ty->getScalarType());
+ return InsertElementInst::Create(Vec, NewCast, Index);
+ }
- Value *P = Builder.CreatePtrToInt(CI.getOperand(0), PtrTy);
- return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false);
+ return commonPointerCastTransforms(CI);
}
/// This input value (which is known to have vector type) is being zero extended
-/// or truncated to the specified vector type.
+/// or truncated to the specified vector type. Since the zext/trunc is done
+/// using an integer type, we have a (bitcast(cast(bitcast))) pattern, and
+/// endianness determines which end of the vector is extended or truncated.
+///
+/// A vector is always stored with index 0 at the lowest address, which
+/// corresponds to the most significant bits for a big endian stored integer and
+/// the least significant bits for little endian. A trunc/zext of an integer
+/// removes or adds bits at the big (most significant) end of the integer.
+/// Thus, we need to add/remove elements at the front of the vector for big
+/// endian targets, and at the back of the vector for little endian targets.
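+///
+/// For example (illustrative): truncating an i64 built from <4 x i16> down to
+/// i32 and back to <2 x i16> keeps elements {0,1} on a little endian target
+/// but elements {2,3} on a big endian target, since those elements hold the
+/// least significant bits of the i64.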
+///
/// Try to replace it with a shuffle (and vector/vector bitcast) if possible.
///
/// The source and destination vector types may have different element types.
-static Instruction *optimizeVectorResize(Value *InVal, VectorType *DestTy,
- InstCombiner &IC) {
+static Instruction *
+optimizeVectorResizeWithIntegerBitCasts(Value *InVal, VectorType *DestTy,
+ InstCombinerImpl &IC) {
// We can only do this optimization if the output is a multiple of the input
// element size, or the input is a multiple of the output element size.
// Convert the input type to have the same element type as the output.
DestTy->getElementType()->getPrimitiveSizeInBits())
return nullptr;
- SrcTy = VectorType::get(DestTy->getElementType(), SrcTy->getNumElements());
+ SrcTy =
+ FixedVectorType::get(DestTy->getElementType(),
+ cast<FixedVectorType>(SrcTy)->getNumElements());
InVal = IC.Builder.CreateBitCast(InVal, SrcTy);
}
+ bool IsBigEndian = IC.getDataLayout().isBigEndian();
+ unsigned SrcElts = cast<FixedVectorType>(SrcTy)->getNumElements();
+ unsigned DestElts = cast<FixedVectorType>(DestTy)->getNumElements();
+
+ assert(SrcElts != DestElts && "Element counts should be different.");
+
// Now that the element types match, get the shuffle mask and RHS of the
// shuffle to use, which depends on whether we're increasing or decreasing the
// size of the input.
- SmallVector<uint32_t, 16> ShuffleMask;
+ SmallVector<int, 16> ShuffleMaskStorage;
+ ArrayRef<int> ShuffleMask;
Value *V2;
- if (SrcTy->getNumElements() > DestTy->getNumElements()) {
- // If we're shrinking the number of elements, just shuffle in the low
- // elements from the input and use undef as the second shuffle input.
- V2 = UndefValue::get(SrcTy);
- for (unsigned i = 0, e = DestTy->getNumElements(); i != e; ++i)
- ShuffleMask.push_back(i);
+  // Produce an identity shuffle mask for the src vector.
+ ShuffleMaskStorage.resize(SrcElts);
+ std::iota(ShuffleMaskStorage.begin(), ShuffleMaskStorage.end(), 0);
+ if (SrcElts > DestElts) {
+ // If we're shrinking the number of elements (rewriting an integer
+ // truncate), just shuffle in the elements corresponding to the least
+ // significant bits from the input and use undef as the second shuffle
+ // input.
+ V2 = UndefValue::get(SrcTy);
+    // Make sure the shuffle mask selects the "least significant bits" by
+    // keeping elements from the back of the src vector for big endian, and
+    // from the front for little endian.
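+    // For example (illustrative): with SrcElts = 4 and DestElts = 2, the mask
+    // is <2,3> for big endian and <0,1> for little endian.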
+ ShuffleMask = ShuffleMaskStorage;
+ if (IsBigEndian)
+ ShuffleMask = ShuffleMask.take_back(DestElts);
+ else
+ ShuffleMask = ShuffleMask.take_front(DestElts);
} else {
- // If we're increasing the number of elements, shuffle in all of the
- // elements from InVal and fill the rest of the result elements with zeros
- // from a constant zero.
+ // If we're increasing the number of elements (rewriting an integer zext),
+ // shuffle in all of the elements from InVal. Fill the rest of the result
+ // elements with zeros from a constant zero.
V2 = Constant::getNullValue(SrcTy);
- unsigned SrcElts = SrcTy->getNumElements();
- for (unsigned i = 0, e = SrcElts; i != e; ++i)
- ShuffleMask.push_back(i);
-
- // The excess elements reference the first element of the zero input.
- for (unsigned i = 0, e = DestTy->getNumElements()-SrcElts; i != e; ++i)
- ShuffleMask.push_back(SrcElts);
- }
-
- return new ShuffleVectorInst(InVal, V2,
- ConstantDataVector::get(V2->getContext(),
- ShuffleMask));
+    // Use the first element of V2 to indicate zero in the shuffle mask.
+ uint32_t NullElt = SrcElts;
+ // Extend with null values in the "most significant bits" by adding elements
+ // in front of the src vector for big endian, and at the back for little
+ // endian.
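+    // For example (illustrative): with SrcElts = 2 and DestElts = 4, the mask
+    // is <2,2,0,1> for big endian and <0,1,2,2> for little endian.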
+ unsigned DeltaElts = DestElts - SrcElts;
+ if (IsBigEndian)
+ ShuffleMaskStorage.insert(ShuffleMaskStorage.begin(), DeltaElts, NullElt);
+ else
+ ShuffleMaskStorage.append(DeltaElts, NullElt);
+ ShuffleMask = ShuffleMaskStorage;
+ }
+
+ return new ShuffleVectorInst(InVal, V2, ShuffleMask);
}
static bool isMultipleOfTypeSize(unsigned Value, Type *Ty) {
///
/// Into two insertelements that do "buildvector{%inc, %inc5}".
static Value *optimizeIntegerToVectorInsertions(BitCastInst &CI,
- InstCombiner &IC) {
- VectorType *DestVecTy = cast<VectorType>(CI.getType());
+ InstCombinerImpl &IC) {
+ auto *DestVecTy = cast<FixedVectorType>(CI.getType());
Value *IntInput = CI.getOperand(0);
SmallVector<Value*, 8> Elements(DestVecTy->getNumElements());
/// vectors better than bitcasts of scalars because vector registers are
/// usually not type-specific like scalar integer or scalar floating-point.
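+/// For example (illustrative):
+///   bitcast (extractelement <2 x i64> %X, i32 0) to double
+///     --> extractelement (bitcast <2 x i64> %X to <2 x double>), i32 0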
static Instruction *canonicalizeBitCastExtElt(BitCastInst &BitCast,
- InstCombiner &IC) {
+ InstCombinerImpl &IC) {
// TODO: Create and use a pattern matcher for ExtractElementInst.
auto *ExtElt = dyn_cast<ExtractElementInst>(BitCast.getOperand(0));
if (!ExtElt || !ExtElt->hasOneUse())
if (!VectorType::isValidElementType(DestType))
return nullptr;
- unsigned NumElts = ExtElt->getVectorOperandType()->getNumElements();
- auto *NewVecType = VectorType::get(DestType, NumElts);
+ auto *NewVecType = VectorType::get(DestType, ExtElt->getVectorOperandType());
auto *NewBC = IC.Builder.CreateBitCast(ExtElt->getVectorOperand(),
NewVecType, "bc");
return ExtractElementInst::Create(NewBC, ExtElt->getIndexOperand());
if (match(BO->getOperand(1), m_Constant(C))) {
// bitcast (logic X, C) --> logic (bitcast X, C')
Value *CastedOp0 = Builder.CreateBitCast(BO->getOperand(0), DestTy);
- Value *CastedC = ConstantExpr::getBitCast(C, DestTy);
+ Value *CastedC = Builder.CreateBitCast(C, DestTy);
return BinaryOperator::Create(BO->getOpcode(), CastedOp0, CastedC);
}
// A vector select must maintain the same number of elements in its operands.
Type *CondTy = Cond->getType();
Type *DestTy = BitCast.getType();
- if (CondTy->isVectorTy()) {
- if (!DestTy->isVectorTy())
- return nullptr;
- if (DestTy->getVectorNumElements() != CondTy->getVectorNumElements())
+ if (auto *CondVTy = dyn_cast<VectorType>(CondTy))
+ if (!DestTy->isVectorTy() ||
+ CondVTy->getElementCount() !=
+ cast<VectorType>(DestTy)->getElementCount())
return nullptr;
- }
// FIXME: This transform is restricted from changing the select between
// scalars and vectors to avoid backend problems caused by creating
///
/// All the related PHI nodes can be replaced by new PHI nodes with type A.
/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
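+///
+/// For example (illustrative), with A = double and B = i64:
+///   %b = phi i64 [ %x, %BB1 ], [ %y, %BB2 ]   ; %x, %y bitcast from double
+///   %a = bitcast i64 %b to double
+/// can be rewritten as a phi of double, making all of the bitcasts dead.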
-Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) {
+Instruction *InstCombinerImpl::optimizeBitCastFromPhi(CastInst &CI,
+ PHINode *PN) {
// BitCast used by Store can be handled in InstCombineLoadStoreAlloca.cpp.
if (hasStoreUsersOnly(CI))
return nullptr;
}
}
+ // Check that each user of each old PHI node is something that we can
+ // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
+ for (auto *OldPN : OldPhiNodes) {
+ for (User *V : OldPN->users()) {
+ if (auto *SI = dyn_cast<StoreInst>(V)) {
+ if (!SI->isSimple() || SI->getOperand(0) != OldPN)
+ return nullptr;
+ } else if (auto *BCI = dyn_cast<BitCastInst>(V)) {
+ // Verify it's a B->A cast.
+ Type *TyB = BCI->getOperand(0)->getType();
+ Type *TyA = BCI->getType();
+ if (TyA != DestTy || TyB != SrcTy)
+ return nullptr;
+ } else if (auto *PHI = dyn_cast<PHINode>(V)) {
+ // As long as the user is another old PHI node, then even if we don't
+ // rewrite it, the PHI web we're considering won't have any users
+ // outside itself, so it'll be dead.
+ if (OldPhiNodes.count(PHI) == 0)
+ return nullptr;
+ } else {
+ return nullptr;
+ }
+ }
+ }
+
// For each old PHI node, create a corresponding new PHI node with a type A.
SmallDenseMap<PHINode *, PHINode *> NewPNodes;
for (auto *OldPN : OldPhiNodes) {
if (auto *C = dyn_cast<Constant>(V)) {
NewV = ConstantExpr::getBitCast(C, DestTy);
} else if (auto *LI = dyn_cast<LoadInst>(V)) {
- Builder.SetInsertPoint(LI->getNextNode());
- NewV = Builder.CreateBitCast(LI, DestTy);
- Worklist.Add(LI);
+ // Explicitly perform load combine to make sure no opposing transform
+ // can remove the bitcast in the meantime and trigger an infinite loop.
+ Builder.SetInsertPoint(LI);
+ NewV = combineLoadToNewType(*LI, DestTy);
+ // Remove the old load and its use in the old phi, which itself becomes
+ // dead once the whole transform finishes.
+ replaceInstUsesWith(*LI, UndefValue::get(LI->getType()));
+ eraseInstFromFunction(*LI);
} else if (auto *BCI = dyn_cast<BitCastInst>(V)) {
NewV = BCI->getOperand(0);
} else if (auto *PrevPN = dyn_cast<PHINode>(V)) {
Instruction *RetVal = nullptr;
for (auto *OldPN : OldPhiNodes) {
PHINode *NewPN = NewPNodes[OldPN];
- for (User *V : OldPN->users()) {
+ for (User *V : make_early_inc_range(OldPN->users())) {
if (auto *SI = dyn_cast<StoreInst>(V)) {
- if (SI->isSimple() && SI->getOperand(0) == OldPN) {
- Builder.SetInsertPoint(SI);
- auto *NewBC =
- cast<BitCastInst>(Builder.CreateBitCast(NewPN, SrcTy));
- SI->setOperand(0, NewBC);
- Worklist.Add(SI);
- assert(hasStoreUsersOnly(*NewBC));
- }
+ assert(SI->isSimple() && SI->getOperand(0) == OldPN);
+ Builder.SetInsertPoint(SI);
+ auto *NewBC =
+ cast<BitCastInst>(Builder.CreateBitCast(NewPN, SrcTy));
+ SI->setOperand(0, NewBC);
+ Worklist.push(SI);
+ assert(hasStoreUsersOnly(*NewBC));
}
else if (auto *BCI = dyn_cast<BitCastInst>(V)) {
- // Verify it's a B->A cast.
Type *TyB = BCI->getOperand(0)->getType();
Type *TyA = BCI->getType();
- if (TyA == DestTy && TyB == SrcTy) {
- Instruction *I = replaceInstUsesWith(*BCI, NewPN);
- if (BCI == &CI)
- RetVal = I;
- }
+ assert(TyA == DestTy && TyB == SrcTy);
+ (void) TyA;
+ (void) TyB;
+ Instruction *I = replaceInstUsesWith(*BCI, NewPN);
+ if (BCI == &CI)
+ RetVal = I;
+ } else if (auto *PHI = dyn_cast<PHINode>(V)) {
+ assert(OldPhiNodes.contains(PHI));
+ (void) PHI;
+ } else {
+ llvm_unreachable("all uses should be handled");
}
}
}
return RetVal;
}
-Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
+Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) {
// If the operands are integer typed then apply the integer transforms,
// otherwise just apply the common ones.
Value *Src = CI.getOperand(0);
if (DestTy == Src->getType())
return replaceInstUsesWith(CI, Src);
- if (PointerType *DstPTy = dyn_cast<PointerType>(DestTy)) {
+ if (isa<PointerType>(SrcTy) && isa<PointerType>(DestTy)) {
PointerType *SrcPTy = cast<PointerType>(SrcTy);
+ PointerType *DstPTy = cast<PointerType>(DestTy);
Type *DstElTy = DstPTy->getElementType();
Type *SrcElTy = SrcPTy->getElementType();
// to a getelementptr X, 0, 0, 0... turn it into the appropriate gep.
// This can enhance SROA and other transforms that want type-safe pointers.
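+    // For example (illustrative):
+    //   bitcast [4 x i32]* %p to i32*
+    //     --> getelementptr [4 x i32], [4 x i32]* %p, i32 0, i32 0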
unsigned NumZeros = 0;
- while (SrcElTy != DstElTy &&
- isa<CompositeType>(SrcElTy) && !SrcElTy->isPointerTy() &&
- SrcElTy->getNumContainedTypes() /* not "{}" */) {
- SrcElTy = cast<CompositeType>(SrcElTy)->getTypeAtIndex(0U);
+ while (SrcElTy && SrcElTy != DstElTy) {
+ SrcElTy = GetElementPtrInst::getTypeAtIndex(SrcElTy, (uint64_t)0);
++NumZeros;
}
// If we found a path from the src to dest, create the getelementptr now.
if (SrcElTy == DstElTy) {
SmallVector<Value *, 8> Idxs(NumZeros + 1, Builder.getInt32(0));
- return GetElementPtrInst::CreateInBounds(SrcPTy->getElementType(), Src,
- Idxs);
+ GetElementPtrInst *GEP =
+ GetElementPtrInst::Create(SrcPTy->getElementType(), Src, Idxs);
+
+ // If the source pointer is dereferenceable, then assume it points to an
+ // allocated object and apply "inbounds" to the GEP.
+ bool CanBeNull;
+ if (Src->getPointerDereferenceableBytes(DL, CanBeNull)) {
+      // In a non-default address space (not 0), a null pointer cannot be
+      // assumed inbounds, so ignore that case (dereferenceable_or_null).
+      // The reason is that 'null' is not treated specially in these address
+      // spaces, and we consequently ignore the 'gep inbounds' special case
+      // for 'null', which allows 'inbounds' on 'null' if the indices are
+      // zeros.
+ if (SrcPTy->getAddressSpace() == 0 || !CanBeNull)
+ GEP->setIsInBounds();
+ }
+ return GEP;
}
}
- if (VectorType *DestVTy = dyn_cast<VectorType>(DestTy)) {
- if (DestVTy->getNumElements() == 1 && !SrcTy->isVectorTy()) {
+ if (FixedVectorType *DestVTy = dyn_cast<FixedVectorType>(DestTy)) {
+ // Beware: messing with this target-specific oddity may cause trouble.
+ if (DestVTy->getNumElements() == 1 && SrcTy->isX86_MMXTy()) {
Value *Elem = Builder.CreateBitCast(Src, DestVTy->getElementType());
return InsertElementInst::Create(UndefValue::get(DestTy), Elem,
Constant::getNullValue(Type::getInt32Ty(CI.getContext())));
- // FIXME: Canonicalize bitcast(insertelement) -> insertelement(bitcast)
}
if (isa<IntegerType>(SrcTy)) {
CastInst *SrcCast = cast<CastInst>(Src);
if (BitCastInst *BCIn = dyn_cast<BitCastInst>(SrcCast->getOperand(0)))
if (isa<VectorType>(BCIn->getOperand(0)->getType()))
- if (Instruction *I = optimizeVectorResize(BCIn->getOperand(0),
- cast<VectorType>(DestTy), *this))
+ if (Instruction *I = optimizeVectorResizeWithIntegerBitCasts(
+ BCIn->getOperand(0), cast<VectorType>(DestTy), *this))
return I;
}
}
}
- if (VectorType *SrcVTy = dyn_cast<VectorType>(SrcTy)) {
+ if (FixedVectorType *SrcVTy = dyn_cast<FixedVectorType>(SrcTy)) {
if (SrcVTy->getNumElements() == 1) {
// If our destination is not a vector, then make this a straight
// scalar-scalar cast.
}
}
- if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(Src)) {
+ if (auto *Shuf = dyn_cast<ShuffleVectorInst>(Src)) {
// Okay, we have (bitcast (shuffle ..)). Check to see if this is
// a bitcast to a vector with the same # elts.
- if (SVI->hasOneUse() && DestTy->isVectorTy() &&
- DestTy->getVectorNumElements() == SVI->getType()->getNumElements() &&
- SVI->getType()->getNumElements() ==
- SVI->getOperand(0)->getType()->getVectorNumElements()) {
+ Value *ShufOp0 = Shuf->getOperand(0);
+ Value *ShufOp1 = Shuf->getOperand(1);
+ auto ShufElts = cast<VectorType>(Shuf->getType())->getElementCount();
+ auto SrcVecElts = cast<VectorType>(ShufOp0->getType())->getElementCount();
+ if (Shuf->hasOneUse() && DestTy->isVectorTy() &&
+ cast<VectorType>(DestTy)->getElementCount() == ShufElts &&
+ ShufElts == SrcVecElts) {
BitCastInst *Tmp;
// If either of the operands is a cast from CI.getType(), then
// evaluating the shuffle in the casted destination's type will allow
// us to eliminate at least one cast.
- if (((Tmp = dyn_cast<BitCastInst>(SVI->getOperand(0))) &&
+ if (((Tmp = dyn_cast<BitCastInst>(ShufOp0)) &&
Tmp->getOperand(0)->getType() == DestTy) ||
- ((Tmp = dyn_cast<BitCastInst>(SVI->getOperand(1))) &&
+ ((Tmp = dyn_cast<BitCastInst>(ShufOp1)) &&
Tmp->getOperand(0)->getType() == DestTy)) {
- Value *LHS = Builder.CreateBitCast(SVI->getOperand(0), DestTy);
- Value *RHS = Builder.CreateBitCast(SVI->getOperand(1), DestTy);
+ Value *LHS = Builder.CreateBitCast(ShufOp0, DestTy);
+ Value *RHS = Builder.CreateBitCast(ShufOp1, DestTy);
        // Return a new shuffle vector. Use the same element IDs, as we
        // know the vector types match #elts.
- return new ShuffleVectorInst(LHS, RHS, SVI->getOperand(2));
+ return new ShuffleVectorInst(LHS, RHS, Shuf->getShuffleMask());
}
}
+
+ // A bitcasted-to-scalar and byte-reversing shuffle is better recognized as
+ // a byte-swap:
+ // bitcast <N x i8> (shuf X, undef, <N, N-1,...0>) --> bswap (bitcast X)
+ // TODO: We should match the related pattern for bitreverse.
+ if (DestTy->isIntegerTy() &&
+ DL.isLegalInteger(DestTy->getScalarSizeInBits()) &&
+ SrcTy->getScalarSizeInBits() == 8 &&
+ ShufElts.getKnownMinValue() % 2 == 0 && Shuf->hasOneUse() &&
+ Shuf->isReverse()) {
+ assert(ShufOp0->getType() == SrcTy && "Unexpected shuffle mask");
+ assert(isa<UndefValue>(ShufOp1) && "Unexpected shuffle op");
+ Function *Bswap =
+ Intrinsic::getDeclaration(CI.getModule(), Intrinsic::bswap, DestTy);
+ Value *ScalarX = Builder.CreateBitCast(ShufOp0, DestTy);
+ return IntrinsicInst::Create(Bswap, { ScalarX });
+ }
}
// Handle the A->B->A cast, and there is an intervening PHI node.
return commonCastTransforms(CI);
}
-Instruction *InstCombiner::visitAddrSpaceCast(AddrSpaceCastInst &CI) {
+Instruction *InstCombinerImpl::visitAddrSpaceCast(AddrSpaceCastInst &CI) {
  // If the destination pointer element type is not the same as the source's,
// first do a bitcast to the destination type, and then the addrspacecast.
// This allows the cast to be exposed to other transforms.
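+  // For example (illustrative):
+  //   addrspacecast i8 addrspace(1)* %p to i32*
+  //     --> %b = bitcast i8 addrspace(1)* %p to i32 addrspace(1)*
+  //         addrspacecast i32 addrspace(1)* %b to i32*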
Type *DestElemTy = DestTy->getElementType();
if (SrcTy->getElementType() != DestElemTy) {
Type *MidTy = PointerType::get(DestElemTy, SrcTy->getAddressSpace());
- if (VectorType *VT = dyn_cast<VectorType>(CI.getType())) {
- // Handle vectors of pointers.
- MidTy = VectorType::get(MidTy, VT->getNumElements());
- }
+ // Handle vectors of pointers.
+ if (VectorType *VT = dyn_cast<VectorType>(CI.getType()))
+ MidTy = VectorType::get(MidTy, VT->getElementCount());
Value *NewBitCast = Builder.CreateBitCast(Src, MidTy);
return new AddrSpaceCastInst(NewBitCast, CI.getType());