--- /dev/null
+//===- SwitchLoweringUtils.h - Switch Lowering ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
+#define LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/SelectionDAGNodes.h"
+#include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/Support/BranchProbability.h"
+
+namespace llvm {
+
+class FunctionLoweringInfo;
+class MachineBasicBlock;
+
+namespace SwitchCG {
+
+/// The lowering strategy chosen for a CaseCluster. The kind also identifies
+/// which member of the CaseCluster union is in use: MBB for CC_Range,
+/// JTCasesIndex for CC_JumpTable and BTCasesIndex for CC_BitTests.
+enum CaseClusterKind {
+ /// A cluster of adjacent case labels with the same destination, or just one
+ /// case.
+ CC_Range,
+ /// A cluster of cases suitable for jump table lowering.
+ CC_JumpTable,
+ /// A cluster of cases suitable for bit test lowering.
+ CC_BitTests
+};
+
+/// A group of switch case labels that is lowered as a single unit.
+///
+/// Low and High bound the case values covered by the cluster. Kind records
+/// how the cluster is lowered and therefore which union member is valid:
+/// MBB for CC_Range, JTCasesIndex for CC_JumpTable, BTCasesIndex for
+/// CC_BitTests. Use the static factories below so that Kind and the union
+/// member are always set together.
+struct CaseCluster {
+  CaseClusterKind Kind;
+  const ConstantInt *Low, *High;
+  union {
+    MachineBasicBlock *MBB;
+    unsigned JTCasesIndex;
+    unsigned BTCasesIndex;
+  };
+  BranchProbability Prob;
+
+  /// Create a CC_Range cluster covering Low..High whose destination is MBB.
+  static CaseCluster range(const ConstantInt *Low, const ConstantInt *High,
+                           MachineBasicBlock *MBB, BranchProbability Prob) {
+    CaseCluster Cluster;
+    Cluster.Low = Low;
+    Cluster.High = High;
+    Cluster.Prob = Prob;
+    Cluster.Kind = CC_Range;
+    Cluster.MBB = MBB;
+    return Cluster;
+  }
+
+  /// Create a CC_JumpTable cluster covering Low..High; JTCasesIndex is the
+  /// index of the corresponding jump table record.
+  static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High,
+                               unsigned JTCasesIndex, BranchProbability Prob) {
+    CaseCluster Cluster;
+    Cluster.Low = Low;
+    Cluster.High = High;
+    Cluster.Prob = Prob;
+    Cluster.Kind = CC_JumpTable;
+    Cluster.JTCasesIndex = JTCasesIndex;
+    return Cluster;
+  }
+
+  /// Create a CC_BitTests cluster covering Low..High; BTCasesIndex is the
+  /// index of the corresponding bit test record.
+  static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High,
+                              unsigned BTCasesIndex, BranchProbability Prob) {
+    CaseCluster Cluster;
+    Cluster.Low = Low;
+    Cluster.High = High;
+    Cluster.Prob = Prob;
+    Cluster.Kind = CC_BitTests;
+    Cluster.BTCasesIndex = BTCasesIndex;
+    return Cluster;
+  }
+};
+
+using CaseClusterVector = std::vector<CaseCluster>;
+using CaseClusterIt = CaseClusterVector::iterator;
+
+/// Sort Clusters and merge adjacent cases.
+///
+/// Clusters is modified in place and may shrink. NOTE(review): the
+/// SelectionDAGBuilder implementation this declaration replaces asserted that
+/// every input cluster is single-case (Low == High), sorted by signed case
+/// value, and merged only neighbouring values with the same destination
+/// block - confirm the moved definition keeps that contract.
+void sortAndRangeify(CaseClusterVector &Clusters);
+
+/// Per-destination accumulator used while building a bit test cluster: all
+/// case values that branch to the same block are folded into one entry.
+struct CaseBits {
+ /// Bit mask of the case values that go to BB.
+ uint64_t Mask = 0;
+ /// The destination block shared by the cases recorded in Mask.
+ MachineBasicBlock *BB = nullptr;
+ /// Count of case values accumulated for BB.
+ unsigned Bits = 0;
+ /// Sum of the probabilities of the clusters folded into this entry.
+ BranchProbability ExtraProb;
+
+ CaseBits() = default;
+ CaseBits(uint64_t mask, MachineBasicBlock *bb, unsigned bits,
+ BranchProbability Prob)
+ : Mask(mask), BB(bb), Bits(bits), ExtraProb(Prob) {}
+};
+
+using CaseBitsVector = std::vector<CaseBits>;
+
+/// This structure is used to communicate between SelectionDAGBuilder and
+/// SDISel for the code generation of additional basic blocks needed by
+/// multi-case switch statements.
+///
+/// Each CaseBlock describes one comparison plus the two-way branch taken on
+/// its result.
+struct CaseBlock {
+ // The condition code to use for the case block's setcc node.
+ // Besides the integer condition codes, this can also be SETTRUE, in which
+ // case no comparison gets emitted.
+ ISD::CondCode CC;
+
+ // The LHS/MHS/RHS of the comparison to emit.
+ // Emit by default LHS op RHS. MHS is used for range comparisons:
+ // If MHS is not null: (LHS <= MHS) and (MHS <= RHS).
+ const Value *CmpLHS, *CmpMHS, *CmpRHS;
+
+ // The block to branch to if the setcc is true/false.
+ MachineBasicBlock *TrueBB, *FalseBB;
+
+ // The block into which to emit the code for the setcc and branches.
+ MachineBasicBlock *ThisBB;
+
+ /// The debug location of the instruction this CaseBlock was
+ /// produced from.
+ SDLoc DL;
+
+ // Branch weights.
+ BranchProbability TrueProb, FalseProb;
+
+ // Note the parameter order: the right-hand operand (cmprhs) comes BEFORE
+ // the middle operand (cmpmiddle), unlike the LHS/MHS/RHS member order.
+ // Both probabilities default to BranchProbability::getUnknown().
+ CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
+ const Value *cmpmiddle, MachineBasicBlock *truebb,
+ MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl,
+ BranchProbability trueprob = BranchProbability::getUnknown(),
+ BranchProbability falseprob = BranchProbability::getUnknown())
+ : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs),
+ TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl),
+ TrueProb(trueprob), FalseProb(falseprob) {}
+};
+
+/// Describes the indirect-jump part of a switch lowered through a jump table.
+struct JumpTable {
+ /// The virtual register containing the index of the jump table entry
+ /// to jump to. Callers in SelectionDAGBuilder use -1U as a "not yet
+ /// assigned" sentinel until the jump table header has been lowered.
+ unsigned Reg;
+ /// The JumpTableIndex for this jump table in the function.
+ unsigned JTI;
+ /// The MBB into which to emit the code for the indirect jump.
+ MachineBasicBlock *MBB;
+ /// The MBB of the default bb, which is a successor of the range
+ /// check MBB. This is used when updating PHI nodes in successors.
+ MachineBasicBlock *Default;
+
+ JumpTable(unsigned R, unsigned J, MachineBasicBlock *M, MachineBasicBlock *D)
+ : Reg(R), JTI(J), MBB(M), Default(D) {}
+};
+/// Header data paired with a JumpTable (see JumpTableBlock): the bounds of
+/// the case values covered by the table and the value being switched on.
+struct JumpTableHeader {
+ /// Lowest case value covered by the table.
+ APInt First;
+ /// Highest case value covered by the table.
+ APInt Last;
+ /// The switch condition value.
+ const Value *SValue;
+ /// The block holding the header code; may be null until it is assigned.
+ MachineBasicBlock *HeaderBB;
+ /// NOTE(review): presumably set once the header code has been emitted -
+ /// confirm against the users of this flag.
+ bool Emitted;
+ /// Always false after construction; callers set it afterwards if the range
+ /// check can be skipped (TODO confirm exact conditions at the use sites).
+ bool OmitRangeCheck;
+
+ JumpTableHeader(APInt F, APInt L, const Value *SV, MachineBasicBlock *H,
+ bool E = false)
+ : First(std::move(F)), Last(std::move(L)), SValue(SV), HeaderBB(H),
+ Emitted(E), OmitRangeCheck(false) {}
+};
+using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>;
+
+/// One test of a bit test cluster: if the switch value's bit is in Mask,
+/// control goes to TargetBB.
+struct BitTestCase {
+ /// Bit mask of the case values that branch to TargetBB.
+ uint64_t Mask;
+ /// The block created to hold this test.
+ MachineBasicBlock *ThisBB;
+ /// The destination block for values matched by Mask.
+ MachineBasicBlock *TargetBB;
+ /// Accumulated probability of the cases covered by Mask.
+ BranchProbability ExtraProb;
+
+ BitTestCase(uint64_t M, MachineBasicBlock *T, MachineBasicBlock *Tr,
+ BranchProbability Prob)
+ : Mask(M), ThisBB(T), TargetBB(Tr), ExtraProb(Prob) {}
+};
+
+using BitTestInfo = SmallVector<BitTestCase, 3>;
+
+/// Everything needed to emit one bit test cluster: the value range being
+/// tested, the register holding the adjusted switch value, and the list of
+/// individual bit tests (Cases).
+struct BitTestBlock {
+ /// Lower bound subtracted from the switch value before testing.
+ APInt First;
+ /// Size of the compared range, relative to First.
+ APInt Range;
+ /// The switch condition value.
+ const Value *SValue;
+ /// Register for the value being tested; SelectionDAGBuilder creates the
+ /// block with -1U here and fills it in later.
+ unsigned Reg;
+ /// Type of Reg; created as MVT::Other until the real type is known.
+ MVT RegVT;
+ /// NOTE(review): presumably set once the block's code has been emitted -
+ /// confirm against the users of this flag.
+ bool Emitted;
+ /// True when the clusters cover a contiguous range, so that no value inside
+ /// the range jumps to the default destination.
+ bool ContiguousRange;
+ /// Block containing the range check / header; may be null until assigned.
+ MachineBasicBlock *Parent;
+ /// Default destination; may be null until assigned.
+ MachineBasicBlock *Default;
+ /// The individual bit tests, in the order they should be emitted.
+ BitTestInfo Cases;
+ /// Total probability of all cases in this block.
+ BranchProbability Prob;
+ /// Not set by the constructor (left default-constructed); assigned later by
+ /// users of this struct.
+ BranchProbability DefaultProb;
+
+ BitTestBlock(APInt F, APInt R, const Value *SV, unsigned Rg, MVT RgVT, bool E,
+ bool CR, MachineBasicBlock *P, MachineBasicBlock *D,
+ BitTestInfo C, BranchProbability Pr)
+ : First(std::move(F)), Range(std::move(R)), SValue(SV), Reg(Rg),
+ RegVT(RgVT), Emitted(E), ContiguousRange(CR), Parent(P), Default(D),
+ Cases(std::move(C)), Prob(Pr) {}
+};
+
+/// Return the range of values spanned by Clusters[First..Last], i.e. from the
+/// Low of Clusters[First] to the High of Clusters[Last]. First and Last are
+/// inclusive indices into Clusters, with First <= Last.
+uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First,
+ unsigned Last);
+
+/// Return the number of cases in Clusters[First..Last] (inclusive indices).
+/// TotalCases[i] is expected to hold the running total of cases in
+/// Clusters[0..i], as built by findJumpTables.
+uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
+ unsigned First, unsigned Last);
+
+/// One pending unit of switch lowering work: a sub-range of the cluster
+/// vector together with the block it is to be lowered into.
+struct SwitchWorkListItem {
+ /// The block in which this part of the switch is lowered.
+ MachineBasicBlock *MBB;
+ /// Iterators delimiting the clusters handled by this item.
+ /// NOTE(review): looks like an inclusive [FirstCluster, LastCluster] pair -
+ /// confirm against lowerWorkItem.
+ CaseClusterIt FirstCluster;
+ CaseClusterIt LastCluster;
+ /// NOTE(review): presumably known bounds on the switch value at this point
+ /// (value >= GE, value < LT), either of which may be null - confirm against
+ /// the code that creates work items.
+ const ConstantInt *GE;
+ const ConstantInt *LT;
+ /// Probability associated with the default destination for this item.
+ BranchProbability DefaultProb;
+};
+using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>;
+
+/// State and helpers for lowering switch instructions.
+///
+/// The class owns the CaseBlock, jump table and bit test records produced
+/// while a switch is being lowered, and offers the cluster-finding routines
+/// that turn range clusters into jump table or bit test clusters. It is
+/// abstract: a client derives from it and implements addSuccessorWithProb().
+///
+/// init() must be called to supply the target information before the
+/// find*/build* methods are used; until then TLI, TM and DL are null.
+class SwitchLowering {
+public:
+  SwitchLowering(FunctionLoweringInfo &funcinfo) : FuncInfo(funcinfo) {}
+
+  /// Record the target lowering info, target machine and data layout to use.
+  /// The referenced objects are stored by pointer and must outlive their use
+  /// by this object.
+  void init(const TargetLowering &tli, const TargetMachine &tm,
+            const DataLayout &dl) {
+    TLI = &tli;
+    TM = &tm;
+    DL = &dl;
+  }
+
+  /// Vector of CaseBlock structures used to communicate SwitchInst code
+  /// generation information.
+  std::vector<CaseBlock> SwitchCases;
+
+  /// Vector of JumpTable structures used to communicate SwitchInst code
+  /// generation information.
+  std::vector<JumpTableBlock> JTCases;
+
+  /// Vector of BitTestBlock structures used to communicate SwitchInst code
+  /// generation information.
+  std::vector<BitTestBlock> BitTestCases;
+
+  /// Look for groups of clusters in Clusters that can be lowered as jump
+  /// tables and replace them, in place, with CC_JumpTable clusters.
+  void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
+                      MachineBasicBlock *DefaultMBB);
+
+  /// Build a jump table cluster from Clusters[First..Last] and store it in
+  /// JTCluster. Returns false if a jump table is not built for this range.
+  bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First,
+                      unsigned Last, const SwitchInst *SI,
+                      MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster);
+
+  /// Look for groups of clusters in Clusters that can be lowered as bit
+  /// tests and replace them, in place, with CC_BitTests clusters.
+  void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI);
+
+  /// Build a bit test cluster from Clusters[First..Last]. Returns false if it
+  /// decides it's not a good idea.
+  bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last,
+                     const SwitchInst *SI, CaseCluster &BTCluster);
+
+  /// Add Dst as a successor of Src with the given probability. Implemented by
+  /// the client so that it can apply its own successor bookkeeping.
+  virtual void addSuccessorWithProb(
+      MachineBasicBlock *Src, MachineBasicBlock *Dst,
+      BranchProbability Prob = BranchProbability::getUnknown()) = 0;
+
+  virtual ~SwitchLowering() = default;
+
+private:
+  // Set by init(). Initialized to null so that using the object before init()
+  // fails deterministically instead of reading indeterminate pointers.
+  const TargetLowering *TLI = nullptr;
+  const TargetMachine *TM = nullptr;
+  const DataLayout *DL = nullptr;
+  FunctionLoweringInfo &FuncInfo;
+};
+
+} // namespace SwitchCG
+} // namespace llvm
+
+#endif // LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
+
StackProtector.cpp
StackSlotColoring.cpp
SwiftErrorValueTracking.cpp
+ SwitchLoweringUtils.cpp
TailDuplication.cpp
TailDuplicator.cpp
TargetFrameLoweringImpl.cpp
using namespace llvm;
using namespace PatternMatch;
+using namespace SwitchCG;
#define DEBUG_TYPE "isel"
DL = &DAG.getDataLayout();
Context = DAG.getContext();
LPadToCallSiteMap.clear();
+ SL->init(DAG.getTargetLoweringInfo(), TM, DAG.getDataLayout());
}
void SelectionDAGBuilder::clear() {
CaseBlock CB(Condition, BOp->getOperand(0), BOp->getOperand(1), nullptr,
TBB, FBB, CurBB, getCurSDLoc(), TProb, FProb);
- SwitchCases.push_back(CB);
+ SL->SwitchCases.push_back(CB);
return;
}
}
ISD::CondCode Opc = InvertCond ? ISD::SETNE : ISD::SETEQ;
CaseBlock CB(Opc, Cond, ConstantInt::getTrue(*DAG.getContext()),
nullptr, TBB, FBB, CurBB, getCurSDLoc(), TProb, FProb);
- SwitchCases.push_back(CB);
+ SL->SwitchCases.push_back(CB);
}
void SelectionDAGBuilder::FindMergedConditions(const Value *Cond,
// If the compares in later blocks need to use values not currently
// exported from this block, export them now. This block should always
// be the first entry.
- assert(SwitchCases[0].ThisBB == BrMBB && "Unexpected lowering!");
+ assert(SL->SwitchCases[0].ThisBB == BrMBB && "Unexpected lowering!");
// Allow some cases to be rejected.
- if (ShouldEmitAsBranches(SwitchCases)) {
- for (unsigned i = 1, e = SwitchCases.size(); i != e; ++i) {
- ExportFromCurrentBlock(SwitchCases[i].CmpLHS);
- ExportFromCurrentBlock(SwitchCases[i].CmpRHS);
+ if (ShouldEmitAsBranches(SL->SwitchCases)) {
+ for (unsigned i = 1, e = SL->SwitchCases.size(); i != e; ++i) {
+ ExportFromCurrentBlock(SL->SwitchCases[i].CmpLHS);
+ ExportFromCurrentBlock(SL->SwitchCases[i].CmpRHS);
}
// Emit the branch for this block.
- visitSwitchCase(SwitchCases[0], BrMBB);
- SwitchCases.erase(SwitchCases.begin());
+ visitSwitchCase(SL->SwitchCases[0], BrMBB);
+ SL->SwitchCases.erase(SL->SwitchCases.begin());
return;
}
// Okay, we decided not to do this, remove any inserted MBB's and clear
// SwitchCases.
- for (unsigned i = 1, e = SwitchCases.size(); i != e; ++i)
- FuncInfo.MF->erase(SwitchCases[i].ThisBB);
+ for (unsigned i = 1, e = SL->SwitchCases.size(); i != e; ++i)
+ FuncInfo.MF->erase(SL->SwitchCases[i].ThisBB);
- SwitchCases.clear();
+ SL->SwitchCases.clear();
}
}
}
/// visitJumpTable - Emit JumpTable node in the current MBB
-void SelectionDAGBuilder::visitJumpTable(JumpTable &JT) {
+void SelectionDAGBuilder::visitJumpTable(SwitchCG::JumpTable &JT) {
// Emit the code for the jump table
assert(JT.Reg != -1U && "Should lower JT Header first!");
EVT PTy = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout());
/// visitJumpTableHeader - This function emits necessary code to produce index
/// in the JumpTable from switch case.
-void SelectionDAGBuilder::visitJumpTableHeader(JumpTable &JT,
+void SelectionDAGBuilder::visitJumpTableHeader(SwitchCG::JumpTable &JT,
JumpTableHeader &JTH,
MachineBasicBlock *SwitchBB) {
SDLoc dl = getCurSDLoc();
setValue(&LP, Res);
}
-void SelectionDAGBuilder::sortAndRangeify(CaseClusterVector &Clusters) {
-#ifndef NDEBUG
- for (const CaseCluster &CC : Clusters)
- assert(CC.Low == CC.High && "Input clusters must be single-case");
-#endif
-
- llvm::sort(Clusters, [](const CaseCluster &a, const CaseCluster &b) {
- return a.Low->getValue().slt(b.Low->getValue());
- });
-
- // Merge adjacent clusters with the same destination.
- const unsigned N = Clusters.size();
- unsigned DstIndex = 0;
- for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
- CaseCluster &CC = Clusters[SrcIndex];
- const ConstantInt *CaseVal = CC.Low;
- MachineBasicBlock *Succ = CC.MBB;
-
- if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
- (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
- // If this case has the same successor and is a neighbour, merge it into
- // the previous cluster.
- Clusters[DstIndex - 1].High = CaseVal;
- Clusters[DstIndex - 1].Prob += CC.Prob;
- } else {
- std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
- sizeof(Clusters[SrcIndex]));
- }
- }
- Clusters.resize(DstIndex);
-}
-
void SelectionDAGBuilder::UpdateSplitBlock(MachineBasicBlock *First,
MachineBasicBlock *Last) {
// Update JTCases.
- for (unsigned i = 0, e = JTCases.size(); i != e; ++i)
- if (JTCases[i].first.HeaderBB == First)
- JTCases[i].first.HeaderBB = Last;
+ for (unsigned i = 0, e = SL->JTCases.size(); i != e; ++i)
+ if (SL->JTCases[i].first.HeaderBB == First)
+ SL->JTCases[i].first.HeaderBB = Last;
// Update BitTestCases.
- for (unsigned i = 0, e = BitTestCases.size(); i != e; ++i)
- if (BitTestCases[i].Parent == First)
- BitTestCases[i].Parent = Last;
+ for (unsigned i = 0, e = SL->BitTestCases.size(); i != e; ++i)
+ if (SL->BitTestCases[i].Parent == First)
+ SL->BitTestCases[i].Parent = Last;
}
void SelectionDAGBuilder::visitIndirectBr(const IndirectBrInst &I) {
HasTailCall = true;
}
-uint64_t
-SelectionDAGBuilder::getJumpTableRange(const CaseClusterVector &Clusters,
- unsigned First, unsigned Last) const {
- assert(Last >= First);
- const APInt &LowCase = Clusters[First].Low->getValue();
- const APInt &HighCase = Clusters[Last].High->getValue();
- assert(LowCase.getBitWidth() == HighCase.getBitWidth());
-
- // FIXME: A range of consecutive cases has 100% density, but only requires one
- // comparison to lower. We should discriminate against such consecutive ranges
- // in jump tables.
-
- return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
-}
-
-uint64_t SelectionDAGBuilder::getJumpTableNumCases(
- const SmallVectorImpl<unsigned> &TotalCases, unsigned First,
- unsigned Last) const {
- assert(Last >= First);
- assert(TotalCases[Last] >= TotalCases[First]);
- uint64_t NumCases =
- TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
- return NumCases;
-}
-
-bool SelectionDAGBuilder::buildJumpTable(const CaseClusterVector &Clusters,
- unsigned First, unsigned Last,
- const SwitchInst *SI,
- MachineBasicBlock *DefaultMBB,
- CaseCluster &JTCluster) {
- assert(First <= Last);
-
- auto Prob = BranchProbability::getZero();
- unsigned NumCmps = 0;
- std::vector<MachineBasicBlock*> Table;
- DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;
-
- // Initialize probabilities in JTProbs.
- for (unsigned I = First; I <= Last; ++I)
- JTProbs[Clusters[I].MBB] = BranchProbability::getZero();
-
- for (unsigned I = First; I <= Last; ++I) {
- assert(Clusters[I].Kind == CC_Range);
- Prob += Clusters[I].Prob;
- const APInt &Low = Clusters[I].Low->getValue();
- const APInt &High = Clusters[I].High->getValue();
- NumCmps += (Low == High) ? 1 : 2;
- if (I != First) {
- // Fill the gap between this and the previous cluster.
- const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
- assert(PreviousHigh.slt(Low));
- uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
- for (uint64_t J = 0; J < Gap; J++)
- Table.push_back(DefaultMBB);
- }
- uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
- for (uint64_t J = 0; J < ClusterSize; ++J)
- Table.push_back(Clusters[I].MBB);
- JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
- }
-
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- unsigned NumDests = JTProbs.size();
- if (TLI.isSuitableForBitTests(
- NumDests, NumCmps, Clusters[First].Low->getValue(),
- Clusters[Last].High->getValue(), DAG.getDataLayout())) {
- // Clusters[First..Last] should be lowered as bit tests instead.
- return false;
- }
-
- // Create the MBB that will load from and jump through the table.
- // Note: We create it here, but it's not inserted into the function yet.
- MachineFunction *CurMF = FuncInfo.MF;
- MachineBasicBlock *JumpTableMBB =
- CurMF->CreateMachineBasicBlock(SI->getParent());
-
- // Add successors. Note: use table order for determinism.
- SmallPtrSet<MachineBasicBlock *, 8> Done;
- for (MachineBasicBlock *Succ : Table) {
- if (Done.count(Succ))
- continue;
- addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
- Done.insert(Succ);
- }
- JumpTableMBB->normalizeSuccProbs();
-
- unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI.getJumpTableEncoding())
- ->createJumpTableIndex(Table);
-
- // Set up the jump table info.
- JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
- JumpTableHeader JTH(Clusters[First].Low->getValue(),
- Clusters[Last].High->getValue(), SI->getCondition(),
- nullptr, false);
- JTCases.emplace_back(std::move(JTH), std::move(JT));
-
- JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
- JTCases.size() - 1, Prob);
- return true;
-}
-
-void SelectionDAGBuilder::findJumpTables(CaseClusterVector &Clusters,
- const SwitchInst *SI,
- MachineBasicBlock *DefaultMBB) {
-#ifndef NDEBUG
- // Clusters must be non-empty, sorted, and only contain Range clusters.
- assert(!Clusters.empty());
- for (CaseCluster &C : Clusters)
- assert(C.Kind == CC_Range);
- for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
- assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
-#endif
-
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- if (!TLI.areJTsAllowed(SI->getParent()->getParent()))
- return;
-
- const int64_t N = Clusters.size();
- const unsigned MinJumpTableEntries = TLI.getMinimumJumpTableEntries();
- const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;
-
- if (N < 2 || N < MinJumpTableEntries)
- return;
-
- // TotalCases[i]: Total nbr of cases in Clusters[0..i].
- SmallVector<unsigned, 8> TotalCases(N);
- for (unsigned i = 0; i < N; ++i) {
- const APInt &Hi = Clusters[i].High->getValue();
- const APInt &Lo = Clusters[i].Low->getValue();
- TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
- if (i != 0)
- TotalCases[i] += TotalCases[i - 1];
- }
-
- // Cheap case: the whole range may be suitable for jump table.
- uint64_t Range = getJumpTableRange(Clusters,0, N - 1);
- uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
- assert(NumCases < UINT64_MAX / 100);
- assert(Range >= NumCases);
- if (TLI.isSuitableForJumpTable(SI, NumCases, Range)) {
- CaseCluster JTCluster;
- if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
- Clusters[0] = JTCluster;
- Clusters.resize(1);
- return;
- }
- }
-
- // The algorithm below is not suitable for -O0.
- if (TM.getOptLevel() == CodeGenOpt::None)
- return;
-
- // Split Clusters into minimum number of dense partitions. The algorithm uses
- // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
- // for the Case Statement'" (1994), but builds the MinPartitions array in
- // reverse order to make it easier to reconstruct the partitions in ascending
- // order. In the choice between two optimal partitionings, it picks the one
- // which yields more jump tables.
-
- // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
- SmallVector<unsigned, 8> MinPartitions(N);
- // LastElement[i] is the last element of the partition starting at i.
- SmallVector<unsigned, 8> LastElement(N);
- // PartitionsScore[i] is used to break ties when choosing between two
- // partitionings resulting in the same number of partitions.
- SmallVector<unsigned, 8> PartitionsScore(N);
- // For PartitionsScore, a small number of comparisons is considered as good as
- // a jump table and a single comparison is considered better than a jump
- // table.
- enum PartitionScores : unsigned {
- NoTable = 0,
- Table = 1,
- FewCases = 1,
- SingleCase = 2
- };
-
- // Base case: There is only one way to partition Clusters[N-1].
- MinPartitions[N - 1] = 1;
- LastElement[N - 1] = N - 1;
- PartitionsScore[N - 1] = PartitionScores::SingleCase;
-
- // Note: loop indexes are signed to avoid underflow.
- for (int64_t i = N - 2; i >= 0; i--) {
- // Find optimal partitioning of Clusters[i..N-1].
- // Baseline: Put Clusters[i] into a partition on its own.
- MinPartitions[i] = MinPartitions[i + 1] + 1;
- LastElement[i] = i;
- PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;
-
- // Search for a solution that results in fewer partitions.
- for (int64_t j = N - 1; j > i; j--) {
- // Try building a partition from Clusters[i..j].
- uint64_t Range = getJumpTableRange(Clusters, i, j);
- uint64_t NumCases = getJumpTableNumCases(TotalCases, i, j);
- assert(NumCases < UINT64_MAX / 100);
- assert(Range >= NumCases);
- if (TLI.isSuitableForJumpTable(SI, NumCases, Range)) {
- unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
- unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
- int64_t NumEntries = j - i + 1;
-
- if (NumEntries == 1)
- Score += PartitionScores::SingleCase;
- else if (NumEntries <= SmallNumberOfEntries)
- Score += PartitionScores::FewCases;
- else if (NumEntries >= MinJumpTableEntries)
- Score += PartitionScores::Table;
-
- // If this leads to fewer partitions, or to the same number of
- // partitions with better score, it is a better partitioning.
- if (NumPartitions < MinPartitions[i] ||
- (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
- MinPartitions[i] = NumPartitions;
- LastElement[i] = j;
- PartitionsScore[i] = Score;
- }
- }
- }
- }
-
- // Iterate over the partitions, replacing some with jump tables in-place.
- unsigned DstIndex = 0;
- for (unsigned First = 0, Last; First < N; First = Last + 1) {
- Last = LastElement[First];
- assert(Last >= First);
- assert(DstIndex <= First);
- unsigned NumClusters = Last - First + 1;
-
- CaseCluster JTCluster;
- if (NumClusters >= MinJumpTableEntries &&
- buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
- Clusters[DstIndex++] = JTCluster;
- } else {
- for (unsigned I = First; I <= Last; ++I)
- std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
- }
- }
- Clusters.resize(DstIndex);
-}
-
-bool SelectionDAGBuilder::buildBitTests(CaseClusterVector &Clusters,
- unsigned First, unsigned Last,
- const SwitchInst *SI,
- CaseCluster &BTCluster) {
- assert(First <= Last);
- if (First == Last)
- return false;
-
- BitVector Dests(FuncInfo.MF->getNumBlockIDs());
- unsigned NumCmps = 0;
- for (int64_t I = First; I <= Last; ++I) {
- assert(Clusters[I].Kind == CC_Range);
- Dests.set(Clusters[I].MBB->getNumber());
- NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
- }
- unsigned NumDests = Dests.count();
-
- APInt Low = Clusters[First].Low->getValue();
- APInt High = Clusters[Last].High->getValue();
- assert(Low.slt(High));
-
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- const DataLayout &DL = DAG.getDataLayout();
- if (!TLI.isSuitableForBitTests(NumDests, NumCmps, Low, High, DL))
- return false;
-
- APInt LowBound;
- APInt CmpRange;
-
- const int BitWidth = TLI.getPointerTy(DL).getSizeInBits();
- assert(TLI.rangeFitsInWord(Low, High, DL) &&
- "Case range must fit in bit mask!");
-
- // Check if the clusters cover a contiguous range such that no value in the
- // range will jump to the default statement.
- bool ContiguousRange = true;
- for (int64_t I = First + 1; I <= Last; ++I) {
- if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
- ContiguousRange = false;
- break;
- }
- }
-
- if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
- // Optimize the case where all the case values fit in a word without having
- // to subtract minValue. In this case, we can optimize away the subtraction.
- LowBound = APInt::getNullValue(Low.getBitWidth());
- CmpRange = High;
- ContiguousRange = false;
- } else {
- LowBound = Low;
- CmpRange = High - Low;
- }
-
- CaseBitsVector CBV;
- auto TotalProb = BranchProbability::getZero();
- for (unsigned i = First; i <= Last; ++i) {
- // Find the CaseBits for this destination.
- unsigned j;
- for (j = 0; j < CBV.size(); ++j)
- if (CBV[j].BB == Clusters[i].MBB)
- break;
- if (j == CBV.size())
- CBV.push_back(
- CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
- CaseBits *CB = &CBV[j];
-
- // Update Mask, Bits and ExtraProb.
- uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
- uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
- assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
- CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
- CB->Bits += Hi - Lo + 1;
- CB->ExtraProb += Clusters[i].Prob;
- TotalProb += Clusters[i].Prob;
- }
-
- BitTestInfo BTI;
- llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
- // Sort by probability first, number of bits second, bit mask third.
- if (a.ExtraProb != b.ExtraProb)
- return a.ExtraProb > b.ExtraProb;
- if (a.Bits != b.Bits)
- return a.Bits > b.Bits;
- return a.Mask < b.Mask;
- });
-
- for (auto &CB : CBV) {
- MachineBasicBlock *BitTestBB =
- FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
- BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
- }
- BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
- SI->getCondition(), -1U, MVT::Other, false,
- ContiguousRange, nullptr, nullptr, std::move(BTI),
- TotalProb);
-
- BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
- BitTestCases.size() - 1, TotalProb);
- return true;
-}
-
-void SelectionDAGBuilder::findBitTestClusters(CaseClusterVector &Clusters,
- const SwitchInst *SI) {
-// Partition Clusters into as few subsets as possible, where each subset has a
-// range that fits in a machine word and has <= 3 unique destinations.
-
-#ifndef NDEBUG
- // Clusters must be sorted and contain Range or JumpTable clusters.
- assert(!Clusters.empty());
- assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
- for (const CaseCluster &C : Clusters)
- assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
- for (unsigned i = 1; i < Clusters.size(); ++i)
- assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
-#endif
-
- // The algorithm below is not suitable for -O0.
- if (TM.getOptLevel() == CodeGenOpt::None)
- return;
-
- // If target does not have legal shift left, do not emit bit tests at all.
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- const DataLayout &DL = DAG.getDataLayout();
-
- EVT PTy = TLI.getPointerTy(DL);
- if (!TLI.isOperationLegal(ISD::SHL, PTy))
- return;
-
- int BitWidth = PTy.getSizeInBits();
- const int64_t N = Clusters.size();
-
- // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
- SmallVector<unsigned, 8> MinPartitions(N);
- // LastElement[i] is the last element of the partition starting at i.
- SmallVector<unsigned, 8> LastElement(N);
-
- // FIXME: This might not be the best algorithm for finding bit test clusters.
-
- // Base case: There is only one way to partition Clusters[N-1].
- MinPartitions[N - 1] = 1;
- LastElement[N - 1] = N - 1;
-
- // Note: loop indexes are signed to avoid underflow.
- for (int64_t i = N - 2; i >= 0; --i) {
- // Find optimal partitioning of Clusters[i..N-1].
- // Baseline: Put Clusters[i] into a partition on its own.
- MinPartitions[i] = MinPartitions[i + 1] + 1;
- LastElement[i] = i;
-
- // Search for a solution that results in fewer partitions.
- // Note: the search is limited by BitWidth, reducing time complexity.
- for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
- // Try building a partition from Clusters[i..j].
-
- // Check the range.
- if (!TLI.rangeFitsInWord(Clusters[i].Low->getValue(),
- Clusters[j].High->getValue(), DL))
- continue;
-
- // Check nbr of destinations and cluster types.
- // FIXME: This works, but doesn't seem very efficient.
- bool RangesOnly = true;
- BitVector Dests(FuncInfo.MF->getNumBlockIDs());
- for (int64_t k = i; k <= j; k++) {
- if (Clusters[k].Kind != CC_Range) {
- RangesOnly = false;
- break;
- }
- Dests.set(Clusters[k].MBB->getNumber());
- }
- if (!RangesOnly || Dests.count() > 3)
- break;
-
- // Check if it's a better partition.
- unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
- if (NumPartitions < MinPartitions[i]) {
- // Found a better partition.
- MinPartitions[i] = NumPartitions;
- LastElement[i] = j;
- }
- }
- }
-
- // Iterate over the partitions, replacing with bit-test clusters in-place.
- unsigned DstIndex = 0;
- for (unsigned First = 0, Last; First < N; First = Last + 1) {
- Last = LastElement[First];
- assert(First <= Last);
- assert(DstIndex <= First);
-
- CaseCluster BitTestCluster;
- if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
- Clusters[DstIndex++] = BitTestCluster;
- } else {
- size_t NumClusters = Last - First + 1;
- std::memmove(&Clusters[DstIndex], &Clusters[First],
- sizeof(Clusters[0]) * NumClusters);
- DstIndex += NumClusters;
- }
- }
- Clusters.resize(DstIndex);
-}
-
void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond,
MachineBasicBlock *SwitchMBB,
MachineBasicBlock *DefaultMBB) {
switch (I->Kind) {
case CC_JumpTable: {
// FIXME: Optimize away range check based on pivot comparisons.
- JumpTableHeader *JTH = &JTCases[I->JTCasesIndex].first;
- JumpTable *JT = &JTCases[I->JTCasesIndex].second;
+ JumpTableHeader *JTH = &SL->JTCases[I->JTCasesIndex].first;
+ SwitchCG::JumpTable *JT = &SL->JTCases[I->JTCasesIndex].second;
// The jump block hasn't been inserted yet; insert it here.
MachineBasicBlock *JumpMBB = JT->MBB;
// FIXME: If Fallthrough is unreachable, skip the range check.
// FIXME: Optimize away range check based on pivot comparisons.
- BitTestBlock *BTB = &BitTestCases[I->BTCasesIndex];
+ BitTestBlock *BTB = &SL->BitTestCases[I->BTCasesIndex];
// The bit test blocks haven't been inserted yet; insert them here.
for (BitTestCase &BTC : BTB->Cases)
if (CurMBB == SwitchMBB)
visitSwitchCase(CB, SwitchMBB);
else
- SwitchCases.push_back(CB);
+ SL->SwitchCases.push_back(CB);
break;
}
if (W.MBB == SwitchMBB)
visitSwitchCase(CB, SwitchMBB);
else
- SwitchCases.push_back(CB);
+ SL->SwitchCases.push_back(CB);
}
// Scale CaseProb after peeling a case with the probablity of PeeledCaseProb
return;
}
- findJumpTables(Clusters, &SI, DefaultMBB);
- findBitTestClusters(Clusters, &SI);
+ SL->findJumpTables(Clusters, &SI, DefaultMBB);
+ SL->findBitTestClusters(Clusters, &SI);
LLVM_DEBUG({
dbgs() << "Case clusters: ";
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
+#include "llvm/CodeGen/SwitchLoweringUtils.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/CallSite.h"
/// create.
unsigned SDNodeOrder;
- enum CaseClusterKind {
- /// A cluster of adjacent case labels with the same destination, or just one
- /// case.
- CC_Range,
- /// A cluster of cases suitable for jump table lowering.
- CC_JumpTable,
- /// A cluster of cases suitable for bit test lowering.
- CC_BitTests
- };
-
- /// A cluster of case labels.
- struct CaseCluster {
- CaseClusterKind Kind;
- const ConstantInt *Low, *High;
- union {
- MachineBasicBlock *MBB;
- unsigned JTCasesIndex;
- unsigned BTCasesIndex;
- };
- BranchProbability Prob;
-
- static CaseCluster range(const ConstantInt *Low, const ConstantInt *High,
- MachineBasicBlock *MBB, BranchProbability Prob) {
- CaseCluster C;
- C.Kind = CC_Range;
- C.Low = Low;
- C.High = High;
- C.MBB = MBB;
- C.Prob = Prob;
- return C;
- }
-
- static CaseCluster jumpTable(const ConstantInt *Low,
- const ConstantInt *High, unsigned JTCasesIndex,
- BranchProbability Prob) {
- CaseCluster C;
- C.Kind = CC_JumpTable;
- C.Low = Low;
- C.High = High;
- C.JTCasesIndex = JTCasesIndex;
- C.Prob = Prob;
- return C;
- }
-
- static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High,
- unsigned BTCasesIndex, BranchProbability Prob) {
- CaseCluster C;
- C.Kind = CC_BitTests;
- C.Low = Low;
- C.High = High;
- C.BTCasesIndex = BTCasesIndex;
- C.Prob = Prob;
- return C;
- }
- };
-
- using CaseClusterVector = std::vector<CaseCluster>;
- using CaseClusterIt = CaseClusterVector::iterator;
-
- struct CaseBits {
- uint64_t Mask = 0;
- MachineBasicBlock* BB = nullptr;
- unsigned Bits = 0;
- BranchProbability ExtraProb;
-
- CaseBits() = default;
- CaseBits(uint64_t mask, MachineBasicBlock* bb, unsigned bits,
- BranchProbability Prob):
- Mask(mask), BB(bb), Bits(bits), ExtraProb(Prob) {}
- };
-
- using CaseBitsVector = std::vector<CaseBits>;
-
- /// Sort Clusters and merge adjacent cases.
- void sortAndRangeify(CaseClusterVector &Clusters);
-
- /// This structure is used to communicate between SelectionDAGBuilder and
- /// SDISel for the code generation of additional basic blocks needed by
- /// multi-case switch statements.
- struct CaseBlock {
- // The condition code to use for the case block's setcc node.
- // Besides the integer condition codes, this can also be SETTRUE, in which
- // case no comparison gets emitted.
- ISD::CondCode CC;
-
- // The LHS/MHS/RHS of the comparison to emit.
- // Emit by default LHS op RHS. MHS is used for range comparisons:
- // If MHS is not null: (LHS <= MHS) and (MHS <= RHS).
- const Value *CmpLHS, *CmpMHS, *CmpRHS;
-
- // The block to branch to if the setcc is true/false.
- MachineBasicBlock *TrueBB, *FalseBB;
-
- // The block into which to emit the code for the setcc and branches.
- MachineBasicBlock *ThisBB;
-
- /// The debug location of the instruction this CaseBlock was
- /// produced from.
- SDLoc DL;
-
- // Branch weights.
- BranchProbability TrueProb, FalseProb;
-
- CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
- const Value *cmpmiddle, MachineBasicBlock *truebb,
- MachineBasicBlock *falsebb, MachineBasicBlock *me,
- SDLoc dl,
- BranchProbability trueprob = BranchProbability::getUnknown(),
- BranchProbability falseprob = BranchProbability::getUnknown())
- : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs),
- TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl),
- TrueProb(trueprob), FalseProb(falseprob) {}
- };
-
- struct JumpTable {
- /// The virtual register containing the index of the jump table entry
- /// to jump to.
- unsigned Reg;
- /// The JumpTableIndex for this jump table in the function.
- unsigned JTI;
- /// The MBB into which to emit the code for the indirect jump.
- MachineBasicBlock *MBB;
- /// The MBB of the default bb, which is a successor of the range
- /// check MBB. This is when updating PHI nodes in successors.
- MachineBasicBlock *Default;
-
- JumpTable(unsigned R, unsigned J, MachineBasicBlock *M,
- MachineBasicBlock *D): Reg(R), JTI(J), MBB(M), Default(D) {}
- };
- struct JumpTableHeader {
- APInt First;
- APInt Last;
- const Value *SValue;
- MachineBasicBlock *HeaderBB;
- bool Emitted;
- bool OmitRangeCheck;
-
- JumpTableHeader(APInt F, APInt L, const Value *SV, MachineBasicBlock *H,
- bool E = false)
- : First(std::move(F)), Last(std::move(L)), SValue(SV), HeaderBB(H),
- Emitted(E), OmitRangeCheck(false) {}
- };
- using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>;
-
- struct BitTestCase {
- uint64_t Mask;
- MachineBasicBlock *ThisBB;
- MachineBasicBlock *TargetBB;
- BranchProbability ExtraProb;
-
- BitTestCase(uint64_t M, MachineBasicBlock* T, MachineBasicBlock* Tr,
- BranchProbability Prob):
- Mask(M), ThisBB(T), TargetBB(Tr), ExtraProb(Prob) {}
- };
-
- using BitTestInfo = SmallVector<BitTestCase, 3>;
-
- struct BitTestBlock {
- APInt First;
- APInt Range;
- const Value *SValue;
- unsigned Reg;
- MVT RegVT;
- bool Emitted;
- bool ContiguousRange;
- MachineBasicBlock *Parent;
- MachineBasicBlock *Default;
- BitTestInfo Cases;
- BranchProbability Prob;
- BranchProbability DefaultProb;
-
- BitTestBlock(APInt F, APInt R, const Value *SV, unsigned Rg, MVT RgVT,
- bool E, bool CR, MachineBasicBlock *P, MachineBasicBlock *D,
- BitTestInfo C, BranchProbability Pr)
- : First(std::move(F)), Range(std::move(R)), SValue(SV), Reg(Rg),
- RegVT(RgVT), Emitted(E), ContiguousRange(CR), Parent(P), Default(D),
- Cases(std::move(C)), Prob(Pr) {}
- };
-
- /// Return the range of value in [First..Last].
- uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First,
- unsigned Last) const;
-
- /// Return the number of cases in [First..Last].
- uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
- unsigned First, unsigned Last) const;
-
- /// Build a jump table cluster from Clusters[First..Last]. Returns false if it
- /// decides it's not a good idea.
- bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First,
- unsigned Last, const SwitchInst *SI,
- MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster);
-
- /// Find clusters of cases suitable for jump table lowering.
- void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
- MachineBasicBlock *DefaultMBB);
-
- /// Build a bit test cluster from Clusters[First..Last]. Returns false if it
- /// decides it's not a good idea.
- bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last,
- const SwitchInst *SI, CaseCluster &BTCluster);
-
- /// Find clusters of cases suitable for bit test lowering.
- void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI);
-
- struct SwitchWorkListItem {
- MachineBasicBlock *MBB;
- CaseClusterIt FirstCluster;
- CaseClusterIt LastCluster;
- const ConstantInt *GE;
- const ConstantInt *LT;
- BranchProbability DefaultProb;
- };
- using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>;
-
/// Determine the rank by weight of CC in [First,Last]. If CC has more weight
/// than each cluster in the range, its rank is 0.
- static unsigned caseClusterRank(const CaseCluster &CC, CaseClusterIt First,
- CaseClusterIt Last);
+ unsigned caseClusterRank(const SwitchCG::CaseCluster &CC,
+ SwitchCG::CaseClusterIt First,
+ SwitchCG::CaseClusterIt Last);
/// Emit comparison and split W into two subtrees.
- void splitWorkItem(SwitchWorkList &WorkList, const SwitchWorkListItem &W,
- Value *Cond, MachineBasicBlock *SwitchMBB);
+ void splitWorkItem(SwitchCG::SwitchWorkList &WorkList,
+ const SwitchCG::SwitchWorkListItem &W, Value *Cond,
+ MachineBasicBlock *SwitchMBB);
/// Lower W.
- void lowerWorkItem(SwitchWorkListItem W, Value *Cond,
+ void lowerWorkItem(SwitchCG::SwitchWorkListItem W, Value *Cond,
MachineBasicBlock *SwitchMBB,
MachineBasicBlock *DefaultMBB);
/// Peel the top probability case if it exceeds the threshold
- MachineBasicBlock *peelDominantCaseCluster(const SwitchInst &SI,
- CaseClusterVector &Clusters,
- BranchProbability &PeeledCaseProb);
+ MachineBasicBlock *
+ peelDominantCaseCluster(const SwitchInst &SI,
+ SwitchCG::CaseClusterVector &Clusters,
+ BranchProbability &PeeledCaseProb);
/// A class which encapsulates all of the information needed to generate a
/// stack protector check and signals to isel via its state being initialized
AliasAnalysis *AA = nullptr;
const TargetLibraryInfo *LibInfo;
- /// Vector of CaseBlock structures used to communicate SwitchInst code
- /// generation information.
- std::vector<CaseBlock> SwitchCases;
+ class SDAGSwitchLowering : public SwitchCG::SwitchLowering {
+ public:
+ SDAGSwitchLowering(SelectionDAGBuilder *sdb, FunctionLoweringInfo &funcinfo)
+ : SwitchCG::SwitchLowering(funcinfo), SDB(sdb) {}
- /// Vector of JumpTable structures used to communicate SwitchInst code
- /// generation information.
- std::vector<JumpTableBlock> JTCases;
+    /// Forward the request to the owning SelectionDAGBuilder, which records
+    /// Dst as a successor of Src with branch probability Prob.
+    /// `override` already implies the function is virtual, so the redundant
+    /// `virtual` keyword is omitted.
+    void addSuccessorWithProb(
+        MachineBasicBlock *Src, MachineBasicBlock *Dst,
+        BranchProbability Prob = BranchProbability::getUnknown()) override {
+      SDB->addSuccessorWithProb(Src, Dst, Prob);
+    }
+
+ private:
+ SelectionDAGBuilder *SDB;
+ };
- /// Vector of BitTestBlock structures used to communicate SwitchInst code
- /// generation information.
- std::vector<BitTestBlock> BitTestCases;
+ std::unique_ptr<SDAGSwitchLowering> SL;
/// A StackProtectorDescriptor structure used to communicate stack protector
/// information in between SelectBasicBlock and FinishBasicBlock.
SelectionDAGBuilder(SelectionDAG &dag, FunctionLoweringInfo &funcinfo,
SwiftErrorValueTracking &swifterror, CodeGenOpt::Level ol)
: SDNodeOrder(LowestSDNodeOrder), TM(dag.getTarget()), DAG(dag),
- FuncInfo(funcinfo), SwiftError(swifterror) {}
+ SL(make_unique<SDAGSwitchLowering>(this, funcinfo)), FuncInfo(funcinfo),
+ SwiftError(swifterror) {}
void init(GCFunctionInfo *gfi, AliasAnalysis *AA,
const TargetLibraryInfo *li);
MachineBasicBlock *SwitchBB,
BranchProbability TProb, BranchProbability FProb,
bool InvertCond);
- bool ShouldEmitAsBranches(const std::vector<CaseBlock> &Cases);
+ bool ShouldEmitAsBranches(const std::vector<SwitchCG::CaseBlock> &Cases);
bool isExportableFromCurrentBlock(const Value *V, const BasicBlock *FromBB);
void CopyToExportRegsIfNeeded(const Value *V);
void ExportFromCurrentBlock(const Value *V);
BranchProbability Prob = BranchProbability::getUnknown());
public:
- void visitSwitchCase(CaseBlock &CB,
- MachineBasicBlock *SwitchBB);
+ void visitSwitchCase(SwitchCG::CaseBlock &CB, MachineBasicBlock *SwitchBB);
void visitSPDescriptorParent(StackProtectorDescriptor &SPD,
MachineBasicBlock *ParentBB);
void visitSPDescriptorFailure(StackProtectorDescriptor &SPD);
- void visitBitTestHeader(BitTestBlock &B, MachineBasicBlock *SwitchBB);
- void visitBitTestCase(BitTestBlock &BB,
- MachineBasicBlock* NextMBB,
- BranchProbability BranchProbToNext,
- unsigned Reg,
- BitTestCase &B,
- MachineBasicBlock *SwitchBB);
- void visitJumpTable(JumpTable &JT);
- void visitJumpTableHeader(JumpTable &JT, JumpTableHeader &JTH,
+ void visitBitTestHeader(SwitchCG::BitTestBlock &B,
+ MachineBasicBlock *SwitchBB);
+ void visitBitTestCase(SwitchCG::BitTestBlock &BB, MachineBasicBlock *NextMBB,
+ BranchProbability BranchProbToNext, unsigned Reg,
+ SwitchCG::BitTestCase &B, MachineBasicBlock *SwitchBB);
+ void visitJumpTable(SwitchCG::JumpTable &JT);
+ void visitJumpTableHeader(SwitchCG::JumpTable &JT,
+ SwitchCG::JumpTableHeader &JTH,
MachineBasicBlock *SwitchBB);
private:
}
// Lower each BitTestBlock.
- for (auto &BTB : SDB->BitTestCases) {
+ for (auto &BTB : SDB->SL->BitTestCases) {
// Lower header first, if it wasn't already lowered
if (!BTB.Emitted) {
// Set the current basic block to the mbb we wish to insert the code into
}
}
}
- SDB->BitTestCases.clear();
+ SDB->SL->BitTestCases.clear();
// If the JumpTable record is filled in, then we need to emit a jump table.
// Updating the PHI nodes is tricky in this case, since we need to determine
// whether the PHI is a successor of the range check MBB or the jump table MBB
- for (unsigned i = 0, e = SDB->JTCases.size(); i != e; ++i) {
+ for (unsigned i = 0, e = SDB->SL->JTCases.size(); i != e; ++i) {
// Lower header first, if it wasn't already lowered
- if (!SDB->JTCases[i].first.Emitted) {
+ if (!SDB->SL->JTCases[i].first.Emitted) {
// Set the current basic block to the mbb we wish to insert the code into
- FuncInfo->MBB = SDB->JTCases[i].first.HeaderBB;
+ FuncInfo->MBB = SDB->SL->JTCases[i].first.HeaderBB;
FuncInfo->InsertPt = FuncInfo->MBB->end();
// Emit the code
- SDB->visitJumpTableHeader(SDB->JTCases[i].second, SDB->JTCases[i].first,
- FuncInfo->MBB);
+ SDB->visitJumpTableHeader(SDB->SL->JTCases[i].second,
+ SDB->SL->JTCases[i].first, FuncInfo->MBB);
CurDAG->setRoot(SDB->getRoot());
SDB->clear();
CodeGenAndEmitDAG();
}
// Set the current basic block to the mbb we wish to insert the code into
- FuncInfo->MBB = SDB->JTCases[i].second.MBB;
+ FuncInfo->MBB = SDB->SL->JTCases[i].second.MBB;
FuncInfo->InsertPt = FuncInfo->MBB->end();
// Emit the code
- SDB->visitJumpTable(SDB->JTCases[i].second);
+ SDB->visitJumpTable(SDB->SL->JTCases[i].second);
CurDAG->setRoot(SDB->getRoot());
SDB->clear();
CodeGenAndEmitDAG();
assert(PHI->isPHI() &&
"This is not a machine PHI node that we are updating!");
// "default" BB. We can go there only from header BB.
- if (PHIBB == SDB->JTCases[i].second.Default)
+ if (PHIBB == SDB->SL->JTCases[i].second.Default)
PHI.addReg(FuncInfo->PHINodesToUpdate[pi].second)
- .addMBB(SDB->JTCases[i].first.HeaderBB);
+ .addMBB(SDB->SL->JTCases[i].first.HeaderBB);
// JT BB. Just iterate over successors here
if (FuncInfo->MBB->isSuccessor(PHIBB))
PHI.addReg(FuncInfo->PHINodesToUpdate[pi].second).addMBB(FuncInfo->MBB);
}
}
- SDB->JTCases.clear();
+ SDB->SL->JTCases.clear();
// If we generated any switch lowering information, build and codegen any
// additional DAGs necessary.
- for (unsigned i = 0, e = SDB->SwitchCases.size(); i != e; ++i) {
+ for (unsigned i = 0, e = SDB->SL->SwitchCases.size(); i != e; ++i) {
// Set the current basic block to the mbb we wish to insert the code into
- FuncInfo->MBB = SDB->SwitchCases[i].ThisBB;
+ FuncInfo->MBB = SDB->SL->SwitchCases[i].ThisBB;
FuncInfo->InsertPt = FuncInfo->MBB->end();
// Determine the unique successors.
SmallVector<MachineBasicBlock *, 2> Succs;
- Succs.push_back(SDB->SwitchCases[i].TrueBB);
- if (SDB->SwitchCases[i].TrueBB != SDB->SwitchCases[i].FalseBB)
- Succs.push_back(SDB->SwitchCases[i].FalseBB);
+ Succs.push_back(SDB->SL->SwitchCases[i].TrueBB);
+ if (SDB->SL->SwitchCases[i].TrueBB != SDB->SL->SwitchCases[i].FalseBB)
+ Succs.push_back(SDB->SL->SwitchCases[i].FalseBB);
// Emit the code. Note that this could result in FuncInfo->MBB being split.
- SDB->visitSwitchCase(SDB->SwitchCases[i], FuncInfo->MBB);
+ SDB->visitSwitchCase(SDB->SL->SwitchCases[i], FuncInfo->MBB);
CurDAG->setRoot(SDB->getRoot());
SDB->clear();
CodeGenAndEmitDAG();
}
}
}
- SDB->SwitchCases.clear();
+ SDB->SL->SwitchCases.clear();
}
/// Create the scheduler. If a specific scheduler was specified
--- /dev/null
+//===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains switch inst lowering optimizations and utilities for
+// codegen, so that it can be used for both SelectionDAG and GlobalISel.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/MachineJumpTableInfo.h"
+#include "llvm/CodeGen/SwitchLoweringUtils.h"
+
+using namespace llvm;
+using namespace SwitchCG;
+
+/// Return the number of values spanned by Clusters[First..Last]: the distance
+/// from the low bound of the first cluster to the high bound of the last one,
+/// saturated at (UINT64_MAX - 1) / 100, plus one.
+uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters,
+                                     unsigned First, unsigned Last) {
+  assert(Last >= First);
+  const APInt &RangeBegin = Clusters[First].Low->getValue();
+  const APInt &RangeEnd = Clusters[Last].High->getValue();
+  assert(RangeBegin.getBitWidth() == RangeEnd.getBitWidth());
+
+  // FIXME: A range of consecutive cases has 100% density, but only requires one
+  // comparison to lower. We should discriminate against such consecutive ranges
+  // in jump tables.
+
+  // Clamp the distance first so that adding one below cannot wrap around.
+  const uint64_t MaxDistance = (UINT64_MAX - 1) / 100;
+  const APInt Distance = RangeEnd - RangeBegin;
+  return Distance.getLimitedValue(MaxDistance) + 1;
+}
+
+/// Return the number of cases covered by clusters [First..Last], given the
+/// running totals in TotalCases (TotalCases[i] counts the cases in clusters
+/// [0..i]).
+uint64_t
+SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
+                               unsigned First, unsigned Last) {
+  assert(Last >= First);
+  assert(TotalCases[Last] >= TotalCases[First]);
+  // Cases that belong to clusters before First must not be counted.
+  const unsigned CasesBeforeFirst = First == 0 ? 0 : TotalCases[First - 1];
+  return TotalCases[Last] - CasesBeforeFirst;
+}
+
+/// Find runs of adjacent clusters that can be lowered as jump tables and
+/// replace each such run in Clusters, in place, with the single jump-table
+/// cluster produced by buildJumpTable().
+///
+/// Clusters is left untouched when the target does not allow jump tables for
+/// the function containing SI, or when there are fewer than two clusters or
+/// fewer than TLI->getMinimumJumpTableEntries() of them. When the opt level is
+/// CodeGenOpt::None only the "whole range as one table" attempt is made.
+///
+/// \param Clusters   Sorted, non-empty list of CC_Range clusters; updated in
+///                   place and possibly shrunk.
+/// \param SI         The switch instruction being lowered.
+/// \param DefaultMBB Block used by buildJumpTable() for table slots that fall
+///                   in gaps between clusters.
+void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
+ const SwitchInst *SI,
+ MachineBasicBlock *DefaultMBB) {
+#ifndef NDEBUG
+ // Clusters must be non-empty, sorted, and only contain Range clusters.
+ assert(!Clusters.empty());
+ for (CaseCluster &C : Clusters)
+ assert(C.Kind == CC_Range);
+ for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
+ assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
+#endif
+
+ if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
+ return;
+
+ const int64_t N = Clusters.size();
+ const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
+ // Partitions with at most this many clusters earn the FewCases score below.
+ const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;
+
+ if (N < 2 || N < MinJumpTableEntries)
+ return;
+
+ // TotalCases[i]: Total nbr of cases in Clusters[0..i].
+ SmallVector<unsigned, 8> TotalCases(N);
+ for (unsigned i = 0; i < N; ++i) {
+ const APInt &Hi = Clusters[i].High->getValue();
+ const APInt &Lo = Clusters[i].Low->getValue();
+ TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
+ if (i != 0)
+ TotalCases[i] += TotalCases[i - 1];
+ }
+
+ // Cheap case: the whole range may be suitable for jump table.
+ uint64_t Range = getJumpTableRange(Clusters,0, N - 1);
+ uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
+ assert(NumCases < UINT64_MAX / 100);
+ assert(Range >= NumCases);
+ if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) {
+ CaseCluster JTCluster;
+ if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
+ // The whole switch became one jump table; it is the only cluster left.
+ Clusters[0] = JTCluster;
+ Clusters.resize(1);
+ return;
+ }
+ }
+
+ // The algorithm below is not suitable for -O0.
+ if (TM->getOptLevel() == CodeGenOpt::None)
+ return;
+
+ // Split Clusters into minimum number of dense partitions. The algorithm uses
+ // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
+ // for the Case Statement'" (1994), but builds the MinPartitions array in
+ // reverse order to make it easier to reconstruct the partitions in ascending
+ // order. In the choice between two optimal partitionings, it picks the one
+ // which yields more jump tables.
+
+ // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
+ SmallVector<unsigned, 8> MinPartitions(N);
+ // LastElement[i] is the last element of the partition starting at i.
+ SmallVector<unsigned, 8> LastElement(N);
+ // PartitionsScore[i] is used to break ties when choosing between two
+ // partitionings resulting in the same number of partitions.
+ SmallVector<unsigned, 8> PartitionsScore(N);
+ // For PartitionsScore, a small number of comparisons is considered as good as
+ // a jump table and a single comparison is considered better than a jump
+ // table.
+ enum PartitionScores : unsigned {
+ NoTable = 0,
+ Table = 1,
+ FewCases = 1,
+ SingleCase = 2
+ };
+
+ // Base case: There is only one way to partition Clusters[N-1].
+ MinPartitions[N - 1] = 1;
+ LastElement[N - 1] = N - 1;
+ PartitionsScore[N - 1] = PartitionScores::SingleCase;
+
+ // Note: loop indexes are signed to avoid underflow.
+ for (int64_t i = N - 2; i >= 0; i--) {
+ // Find optimal partitioning of Clusters[i..N-1].
+ // Baseline: Put Clusters[i] into a partition on its own.
+ MinPartitions[i] = MinPartitions[i + 1] + 1;
+ LastElement[i] = i;
+ PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;
+
+ // Search for a solution that results in fewer partitions.
+ for (int64_t j = N - 1; j > i; j--) {
+ // Try building a partition from Clusters[i..j].
+ uint64_t Range = getJumpTableRange(Clusters, i, j);
+ uint64_t NumCases = getJumpTableNumCases(TotalCases, i, j);
+ assert(NumCases < UINT64_MAX / 100);
+ assert(Range >= NumCases);
+ if (TLI->isSuitableForJumpTable(SI, NumCases, Range)) {
+ // Cost of this candidate: one partition for [i..j] plus the best known
+ // partitioning of the remainder [j+1..N-1], if any.
+ unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
+ unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
+ int64_t NumEntries = j - i + 1;
+
+ if (NumEntries == 1)
+ Score += PartitionScores::SingleCase;
+ else if (NumEntries <= SmallNumberOfEntries)
+ Score += PartitionScores::FewCases;
+ else if (NumEntries >= MinJumpTableEntries)
+ Score += PartitionScores::Table;
+
+ // If this leads to fewer partitions, or to the same number of
+ // partitions with better score, it is a better partitioning.
+ if (NumPartitions < MinPartitions[i] ||
+ (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
+ MinPartitions[i] = NumPartitions;
+ LastElement[i] = j;
+ PartitionsScore[i] = Score;
+ }
+ }
+ }
+ }
+
+ // Iterate over the partitions, replacing some with jump tables in-place.
+ unsigned DstIndex = 0;
+ for (unsigned First = 0, Last; First < N; First = Last + 1) {
+ Last = LastElement[First];
+ assert(Last >= First);
+ assert(DstIndex <= First);
+ unsigned NumClusters = Last - First + 1;
+
+ // Partitions with fewer than MinJumpTableEntries clusters, or ones that
+ // buildJumpTable() rejects, keep their original clusters; those are only
+ // compacted towards the front of the vector.
+ CaseCluster JTCluster;
+ if (NumClusters >= MinJumpTableEntries &&
+ buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
+ Clusters[DstIndex++] = JTCluster;
+ } else {
+ for (unsigned I = First; I <= Last; ++I)
+ std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
+ }
+ }
+ // Drop the unused tail left behind by the in-place compaction.
+ Clusters.resize(DstIndex);
+}
+
+/// Try to build a jump table covering Clusters[First..Last].
+///
+/// On success a new (header, table) entry is appended to JTCases, JTCluster is
+/// set to a CC_JumpTable cluster referring to that entry, and true is
+/// returned. Returns false, without creating a block or a JTCases entry, when
+/// TLI->isSuitableForBitTests() reports that the same clusters should be
+/// lowered as bit tests.
+///
+/// \param Clusters   Sorted CC_Range clusters; not modified.
+/// \param First,Last Inclusive index range of the clusters to cover.
+/// \param SI         The switch instruction being lowered.
+/// \param DefaultMBB Destination stored in table slots that lie in gaps
+///                   between consecutive clusters.
+/// \param JTCluster  Output: the resulting jump-table cluster.
+bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters,
+ unsigned First, unsigned Last,
+ const SwitchInst *SI,
+ MachineBasicBlock *DefaultMBB,
+ CaseCluster &JTCluster) {
+ assert(First <= Last);
+
+ auto Prob = BranchProbability::getZero();
+ unsigned NumCmps = 0;
+ std::vector<MachineBasicBlock*> Table;
+ DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;
+
+ // Initialize probabilities in JTProbs.
+ for (unsigned I = First; I <= Last; ++I)
+ JTProbs[Clusters[I].MBB] = BranchProbability::getZero();
+
+ for (unsigned I = First; I <= Last; ++I) {
+ assert(Clusters[I].Kind == CC_Range);
+ Prob += Clusters[I].Prob;
+ const APInt &Low = Clusters[I].Low->getValue();
+ const APInt &High = Clusters[I].High->getValue();
+ // A single-value cluster counts as one comparison, a real range as two.
+ NumCmps += (Low == High) ? 1 : 2;
+ if (I != First) {
+ // Fill the gap between this and the previous cluster.
+ const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
+ assert(PreviousHigh.slt(Low));
+ uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
+ for (uint64_t J = 0; J < Gap; J++)
+ Table.push_back(DefaultMBB);
+ }
+ // One table slot per case value in the cluster, all pointing at its block.
+ uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
+ for (uint64_t J = 0; J < ClusterSize; ++J)
+ Table.push_back(Clusters[I].MBB);
+ // Several clusters may share a destination; accumulate its probability.
+ JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
+ }
+
+ unsigned NumDests = JTProbs.size();
+ if (TLI->isSuitableForBitTests(NumDests, NumCmps,
+ Clusters[First].Low->getValue(),
+ Clusters[Last].High->getValue(), *DL)) {
+ // Clusters[First..Last] should be lowered as bit tests instead.
+ return false;
+ }
+
+ // Create the MBB that will load from and jump through the table.
+ // Note: We create it here, but it's not inserted into the function yet.
+ MachineFunction *CurMF = FuncInfo.MF;
+ MachineBasicBlock *JumpTableMBB =
+ CurMF->CreateMachineBasicBlock(SI->getParent());
+
+ // Add successors. Note: use table order for determinism.
+ SmallPtrSet<MachineBasicBlock *, 8> Done;
+ for (MachineBasicBlock *Succ : Table) {
+ if (Done.count(Succ))
+ continue;
+ addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
+ Done.insert(Succ);
+ }
+ JumpTableMBB->normalizeSuccProbs();
+
+ unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding())
+ ->createJumpTableIndex(Table);
+
+ // Set up the jump table info.
+ // The register (-1U), the default block and the header block are only
+ // placeholders at this point.
+ // NOTE(review): presumably they are filled in by the caller when the cluster
+ // is actually lowered -- confirm.
+ JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
+ JumpTableHeader JTH(Clusters[First].Low->getValue(),
+ Clusters[Last].High->getValue(), SI->getCondition(),
+ nullptr, false);
+ JTCases.emplace_back(std::move(JTH), std::move(JT));
+
+ // The cluster refers to the entry just appended to JTCases by its index.
+ JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
+ JTCases.size() - 1, Prob);
+ return true;
+}
+
+/// Find runs of adjacent Range clusters that can be lowered as bit tests and
+/// replace each such run in Clusters, in place, with the single bit-test
+/// cluster produced by buildBitTests().
+///
+/// Clusters is left untouched when the opt level is CodeGenOpt::None or when
+/// ISD::SHL is not legal for the target's pointer type.
+///
+/// \param Clusters Sorted, non-empty list of CC_Range / CC_JumpTable clusters;
+///                 updated in place and possibly shrunk.
+/// \param SI       The switch instruction being lowered.
+void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
+ const SwitchInst *SI) {
+ // Partition Clusters into as few subsets as possible, where each subset has a
+ // range that fits in a machine word and has <= 3 unique destinations.
+
+#ifndef NDEBUG
+ // Clusters must be sorted and contain Range or JumpTable clusters.
+ assert(!Clusters.empty());
+ assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
+ for (const CaseCluster &C : Clusters)
+ assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
+ for (unsigned i = 1; i < Clusters.size(); ++i)
+ assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
+#endif
+
+ // The algorithm below is not suitable for -O0.
+ if (TM->getOptLevel() == CodeGenOpt::None)
+ return;
+
+ // If target does not have legal shift left, do not emit bit tests at all.
+ EVT PTy = TLI->getPointerTy(*DL);
+ if (!TLI->isOperationLegal(ISD::SHL, PTy))
+ return;
+
+ int BitWidth = PTy.getSizeInBits();
+ const int64_t N = Clusters.size();
+
+ // MinPartitions[i] is the minimum nbr of partitions of Clusters[i..N-1].
+ SmallVector<unsigned, 8> MinPartitions(N);
+ // LastElement[i] is the last element of the partition starting at i.
+ SmallVector<unsigned, 8> LastElement(N);
+
+ // FIXME: This might not be the best algorithm for finding bit test clusters.
+
+ // Base case: There is only one way to partition Clusters[N-1].
+ MinPartitions[N - 1] = 1;
+ LastElement[N - 1] = N - 1;
+
+ // Note: loop indexes are signed to avoid underflow.
+ for (int64_t i = N - 2; i >= 0; --i) {
+ // Find optimal partitioning of Clusters[i..N-1].
+ // Baseline: Put Clusters[i] into a partition on its own.
+ MinPartitions[i] = MinPartitions[i + 1] + 1;
+ LastElement[i] = i;
+
+ // Search for a solution that results in fewer partitions.
+ // Note: the search is limited by BitWidth, reducing time complexity.
+ for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
+ // Try building a partition from Clusters[i..j].
+
+ // Check the range.
+ if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
+ Clusters[j].High->getValue(), *DL))
+ continue;
+
+ // Check nbr of destinations and cluster types.
+ // FIXME: This works, but doesn't seem very efficient.
+ bool RangesOnly = true;
+ BitVector Dests(FuncInfo.MF->getNumBlockIDs());
+ for (int64_t k = i; k <= j; k++) {
+ if (Clusters[k].Kind != CC_Range) {
+ RangesOnly = false;
+ break;
+ }
+ Dests.set(Clusters[k].MBB->getNumber());
+ }
+ // NOTE(review): j is descending, so this break also skips every shorter
+ // candidate [i..j'] with j' < j, which could have fewer destinations --
+ // confirm that giving up here (rather than continuing) is intended.
+ if (!RangesOnly || Dests.count() > 3)
+ break;
+
+ // Check if it's a better partition.
+ unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
+ if (NumPartitions < MinPartitions[i]) {
+ // Found a better partition.
+ MinPartitions[i] = NumPartitions;
+ LastElement[i] = j;
+ }
+ }
+ }
+
+ // Iterate over the partitions, replacing with bit-test clusters in-place.
+ unsigned DstIndex = 0;
+ for (unsigned First = 0, Last; First < N; First = Last + 1) {
+ Last = LastElement[First];
+ assert(First <= Last);
+ assert(DstIndex <= First);
+
+ CaseCluster BitTestCluster;
+ if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
+ Clusters[DstIndex++] = BitTestCluster;
+ } else {
+ // Not lowered as bit tests: keep the original clusters, compacted towards
+ // the front of the vector.
+ size_t NumClusters = Last - First + 1;
+ std::memmove(&Clusters[DstIndex], &Clusters[First],
+ sizeof(Clusters[0]) * NumClusters);
+ DstIndex += NumClusters;
+ }
+ }
+ // Drop the unused tail left behind by the in-place compaction.
+ Clusters.resize(DstIndex);
+}
+
+/// Try to build a bit-test cluster covering Clusters[First..Last].
+///
+/// On success a new BitTestBlock is appended to BitTestCases (with one
+/// BitTestCase, and one freshly created machine basic block, per distinct
+/// destination), BTCluster is set to a CC_BitTests cluster referring to that
+/// entry, and true is returned. Returns false when First == Last or when
+/// TLI->isSuitableForBitTests() rejects the clusters.
+///
+/// \param Clusters   Sorted clusters; [First..Last] must all be CC_Range.
+/// \param First,Last Inclusive index range of the clusters to cover.
+/// \param SI         The switch instruction being lowered.
+/// \param BTCluster  Output: the resulting bit-test cluster.
+bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
+ unsigned First, unsigned Last,
+ const SwitchInst *SI,
+ CaseCluster &BTCluster) {
+ assert(First <= Last);
+ // A single cluster is never turned into a bit test.
+ if (First == Last)
+ return false;
+
+ // Count distinct destination blocks and the comparisons a plain lowering
+ // would need (one per single-value cluster, two per range).
+ // NOTE(review): Low == High compares the ConstantInt pointers, which
+ // presumably relies on ConstantInt values being uniqued -- confirm.
+ BitVector Dests(FuncInfo.MF->getNumBlockIDs());
+ unsigned NumCmps = 0;
+ for (int64_t I = First; I <= Last; ++I) {
+ assert(Clusters[I].Kind == CC_Range);
+ Dests.set(Clusters[I].MBB->getNumber());
+ NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
+ }
+ unsigned NumDests = Dests.count();
+
+ APInt Low = Clusters[First].Low->getValue();
+ APInt High = Clusters[Last].High->getValue();
+ assert(Low.slt(High));
+
+ if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
+ return false;
+
+ APInt LowBound;
+ APInt CmpRange;
+
+ const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
+ assert(TLI->rangeFitsInWord(Low, High, *DL) &&
+ "Case range must fit in bit mask!");
+
+ // Check if the clusters cover a contiguous range such that no value in the
+ // range will jump to the default statement.
+ bool ContiguousRange = true;
+ for (int64_t I = First + 1; I <= Last; ++I) {
+ if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
+ ContiguousRange = false;
+ break;
+ }
+ }
+
+ if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
+ // Optimize the case where all the case values fit in a word without having
+ // to subtract minValue. In this case, we can optimize away the subtraction.
+ LowBound = APInt::getNullValue(Low.getBitWidth());
+ CmpRange = High;
+ // NOTE(review): presumably cleared because values in [0, Low) now lie
+ // inside the tested range without matching any case -- confirm.
+ ContiguousRange = false;
+ } else {
+ LowBound = Low;
+ CmpRange = High - Low;
+ }
+
+ CaseBitsVector CBV;
+ auto TotalProb = BranchProbability::getZero();
+ for (unsigned i = First; i <= Last; ++i) {
+ // Find the CaseBits for this destination.
+ unsigned j;
+ for (j = 0; j < CBV.size(); ++j)
+ if (CBV[j].BB == Clusters[i].MBB)
+ break;
+ if (j == CBV.size())
+ CBV.push_back(
+ CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
+ CaseBits *CB = &CBV[j];
+
+ // Update Mask, Bits and ExtraProb.
+ uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
+ uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
+ assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
+ // Set bits Lo..Hi (inclusive): build Hi - Lo + 1 low one-bits, then shift
+ // them up to start at bit Lo.
+ CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
+ CB->Bits += Hi - Lo + 1;
+ CB->ExtraProb += Clusters[i].Prob;
+ TotalProb += Clusters[i].Prob;
+ }
+
+ BitTestInfo BTI;
+ llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
+ // Sort by probability first, number of bits second, bit mask third.
+ // Probability and bit count are ordered descending, the mask ascending.
+ if (a.ExtraProb != b.ExtraProb)
+ return a.ExtraProb > b.ExtraProb;
+ if (a.Bits != b.Bits)
+ return a.Bits > b.Bits;
+ return a.Mask < b.Mask;
+ });
+
+ // Create one not-yet-inserted block per destination, in sorted order.
+ for (auto &CB : CBV) {
+ MachineBasicBlock *BitTestBB =
+ FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
+ BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
+ }
+ // The register (-1U), register type (MVT::Other) and the parent/default
+ // blocks (nullptr) are only placeholders at this point.
+ BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
+ SI->getCondition(), -1U, MVT::Other, false,
+ ContiguousRange, nullptr, nullptr, std::move(BTI),
+ TotalProb);
+
+ // The cluster refers to the entry just appended to BitTestCases by index.
+ BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
+ BitTestCases.size() - 1, TotalProb);
+ return true;
+}
+
+/// Sort the single-case clusters by case value (signed order) and fold each
+/// run of consecutive values that share a destination block into one range
+/// cluster. Clusters is compacted in place and shrunk to the merged size.
+void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
+#ifndef NDEBUG
+  for (const CaseCluster &Single : Clusters)
+    assert(Single.Low == Single.High && "Input clusters must be single-case");
+#endif
+
+  llvm::sort(Clusters, [](const CaseCluster &LHS, const CaseCluster &RHS) {
+    return LHS.Low->getValue().slt(RHS.Low->getValue());
+  });
+
+  // NumKept counts the clusters already written to the front of the vector;
+  // it never runs ahead of the read position.
+  unsigned NumKept = 0;
+  const unsigned NumInputs = Clusters.size();
+  for (unsigned Src = 0; Src < NumInputs; ++Src) {
+    CaseCluster &Cur = Clusters[Src];
+    const ConstantInt *Val = Cur.Low;
+    MachineBasicBlock *Dest = Cur.MBB;
+
+    if (NumKept != 0) {
+      CaseCluster &Prev = Clusters[NumKept - 1];
+      // Same destination and directly adjacent value: extend the previous
+      // cluster instead of emitting a new one.
+      if (Prev.MBB == Dest && (Val->getValue() - Prev.High->getValue()) == 1) {
+        Prev.High = Val;
+        Prev.Prob += Cur.Prob;
+        continue;
+      }
+    }
+
+    // Otherwise keep this cluster, moving it down into the next free slot.
+    std::memmove(&Clusters[NumKept++], &Clusters[Src], sizeof(Clusters[Src]));
+  }
+  Clusters.resize(NumKept);
+}