OSDN Git Service

CodeExtractor : Add ability to preserve profile data.
authorSean Silva <chisophugis@gmail.com>
Mon, 1 Aug 2016 02:59:26 +0000 (02:59 +0000)
committerSean Silva <chisophugis@gmail.com>
Mon, 1 Aug 2016 02:59:26 +0000 (02:59 +0000)
Added ability to estimate the entry count of the extracted function and
the branch probabilities of the exit branches.

Patch by River Riddle!

Differential Revision: https://reviews.llvm.org/D22744

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@277313 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/BlockFrequencyInfo.h
include/llvm/Analysis/BlockFrequencyInfoImpl.h
include/llvm/CodeGen/MachineBlockFrequencyInfo.h
include/llvm/Transforms/Utils/CodeExtractor.h
lib/Analysis/BlockFrequencyInfo.cpp
lib/Analysis/BlockFrequencyInfoImpl.cpp
lib/CodeGen/MachineBlockFrequencyInfo.cpp
lib/Transforms/IPO/PartialInlining.cpp
lib/Transforms/Utils/CodeExtractor.cpp

index 7d48dfc..d7e76a1 100644 (file)
@@ -61,6 +61,11 @@ public:
   /// the enclosing function's count (if available) and returns the value.
   Optional<uint64_t> getBlockProfileCount(const BasicBlock *BB) const;
 
+  /// \brief Returns the estimated profile count of \p Freq.
+  /// This uses the frequency \p Freq and multiplies it by
+  /// the enclosing function's count (if available) and returns the value.
+  Optional<uint64_t> getProfileCountFromFreq(uint64_t Freq) const;
+
   // Set the frequency of the given basic block.
   void setBlockFreq(const BasicBlock *BB, uint64_t Freq);
 
index 7ed06b1..8a0ced9 100644 (file)
@@ -482,6 +482,8 @@ public:
   BlockFrequency getBlockFreq(const BlockNode &Node) const;
   Optional<uint64_t> getBlockProfileCount(const Function &F,
                                           const BlockNode &Node) const;
+  Optional<uint64_t> getProfileCountFromFreq(const Function &F,
+                                             uint64_t Freq) const;
 
   void setBlockFreq(const BlockNode &Node, uint64_t Freq);
 
@@ -925,6 +927,10 @@ public:
                                           const BlockT *BB) const {
     return BlockFrequencyInfoImplBase::getBlockProfileCount(F, getNode(BB));
   }
+  Optional<uint64_t> getProfileCountFromFreq(const Function &F,
+                                             uint64_t Freq) const {
+    return BlockFrequencyInfoImplBase::getProfileCountFromFreq(F, Freq);
+  }
   void setBlockFreq(const BlockT *BB, uint64_t Freq);
   Scaled64 getFloatingBlockFreq(const BlockT *BB) const {
     return BlockFrequencyInfoImplBase::getFloatingBlockFreq(getNode(BB));
index 7a23608..bfa5bf6 100644 (file)
@@ -52,6 +52,7 @@ public:
   BlockFrequency getBlockFreq(const MachineBasicBlock *MBB) const;
 
   Optional<uint64_t> getBlockProfileCount(const MachineBasicBlock *MBB) const;
+  Optional<uint64_t> getProfileCountFromFreq(uint64_t Freq) const;
 
   const MachineFunction *getFunction() const;
   const MachineBranchProbabilityInfo *getMBPI() const;
index 970b3a1..a297866 100644 (file)
@@ -20,6 +20,9 @@
 namespace llvm {
 template <typename T> class ArrayRef;
   class BasicBlock;
+  class BlockFrequency;
+  class BlockFrequencyInfo;
+  class BranchProbabilityInfo;
   class DominatorTree;
   class Function;
   class Loop;
@@ -47,6 +50,8 @@ template <typename T> class ArrayRef;
     // Various bits of state computed on construction.
     DominatorTree *const DT;
     const bool AggregateArgs;
+    BlockFrequencyInfo *BFI;
+    BranchProbabilityInfo *BPI;
 
     // Bits of intermediate state computed at various phases of extraction.
     SetVector<BasicBlock *> Blocks;
@@ -64,7 +69,9 @@ template <typename T> class ArrayRef;
     ///
     /// In this formation, we don't require a dominator tree. The given basic
     /// block is set up for extraction.
-    CodeExtractor(BasicBlock *BB, bool AggregateArgs = false);
+    CodeExtractor(BasicBlock *BB, bool AggregateArgs = false,
+                  BlockFrequencyInfo *BFI = nullptr,
+                  BranchProbabilityInfo *BPI = nullptr);
 
     /// \brief Create a code extractor for a sequence of blocks.
     ///
@@ -73,20 +80,24 @@ template <typename T> class ArrayRef;
     /// sequence out into its new function. When a DominatorTree is also given,
     /// extra checking and transformations are enabled.
     CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
-                  bool AggregateArgs = false);
+                  bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
+                  BranchProbabilityInfo *BPI = nullptr);
 
     /// \brief Create a code extractor for a loop body.
     ///
     /// Behaves just like the generic code sequence constructor, but uses the
     /// block sequence of the loop.
-    CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false);
+    CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs = false,
+                  BlockFrequencyInfo *BFI = nullptr,
+                  BranchProbabilityInfo *BPI = nullptr);
 
     /// \brief Create a code extractor for a region node.
     ///
     /// Behaves just like the generic code sequence constructor, but uses the
     /// block sequence of the region node passed in.
     CodeExtractor(DominatorTree &DT, const RegionNode &RN,
-                  bool AggregateArgs = false);
+                  bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
+                  BranchProbabilityInfo *BPI = nullptr);
 
     /// \brief Perform the extraction, returning the new function.
     ///
@@ -122,6 +133,11 @@ template <typename T> class ArrayRef;
 
     void moveCodeToFunction(Function *newFunction);
 
+    void calculateNewCallTerminatorWeights(
+        BasicBlock *CodeReplacer,
+        DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
+        BranchProbabilityInfo *BPI);
+
     void emitCallAndSwitchStatement(Function *newFunction,
                                     BasicBlock *newHeader,
                                     ValueSet &inputs,
index 1dd8f4f..5f7060a 100644 (file)
@@ -162,6 +162,13 @@ BlockFrequencyInfo::getBlockProfileCount(const BasicBlock *BB) const {
   return BFI->getBlockProfileCount(*getFunction(), BB);
 }
 
+Optional<uint64_t>
+BlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const {
+  if (!BFI)
+    return None;
+  return BFI->getProfileCountFromFreq(*getFunction(), Freq);
+}
+
 void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, uint64_t Freq) {
   assert(BFI && "Expected analysis to be available");
   BFI->setBlockFreq(BB, Freq);
index 90bc249..77fe72d 100644 (file)
@@ -533,12 +533,18 @@ BlockFrequencyInfoImplBase::getBlockFreq(const BlockNode &Node) const {
 Optional<uint64_t>
 BlockFrequencyInfoImplBase::getBlockProfileCount(const Function &F,
                                                  const BlockNode &Node) const {
+  return getProfileCountFromFreq(F, getBlockFreq(Node).getFrequency());
+}
+
+Optional<uint64_t>
+BlockFrequencyInfoImplBase::getProfileCountFromFreq(const Function &F,
+                                                    uint64_t Freq) const {
   auto EntryCount = F.getEntryCount();
   if (!EntryCount)
     return None;
   // Use 128 bit APInt to do the arithmetic to avoid overflow.
   APInt BlockCount(128, EntryCount.getValue());
-  APInt BlockFreq(128, getBlockFreq(Node).getFrequency());
+  APInt BlockFreq(128, Freq);
   APInt EntryFreq(128, getEntryFreq());
   BlockCount *= BlockFreq;
   BlockCount = BlockCount.udiv(EntryFreq);
index 6c0f99f..faf9ecc 100644 (file)
@@ -175,6 +175,12 @@ Optional<uint64_t> MachineBlockFrequencyInfo::getBlockProfileCount(
   return MBFI ? MBFI->getBlockProfileCount(*F, MBB) : None;
 }
 
+Optional<uint64_t>
+MachineBlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const {
+  const Function *F = MBFI->getFunction()->getFunction();
+  return MBFI ? MBFI->getProfileCountFromFreq(*F, Freq) : None;
+}
+
 const MachineFunction *MachineBlockFrequencyInfo::getFunction() const {
   return MBFI ? MBFI->getFunction() : nullptr;
 }
index 6c762e4..59c83b4 100644 (file)
@@ -14,6 +14,8 @@
 
 #include "llvm/Transforms/IPO/PartialInlining.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Instructions.h"
@@ -29,13 +31,18 @@ using namespace llvm;
 STATISTIC(NumPartialInlined, "Number of functions partially inlined");
 
 namespace {
+typedef std::function<std::pair<BlockFrequencyInfo *, BranchProbabilityInfo *>(
+    Function &)>
+    GetProfileDataFn;
 struct PartialInlinerImpl {
-  PartialInlinerImpl(InlineFunctionInfo IFI) : IFI(IFI) {}
+  PartialInlinerImpl(InlineFunctionInfo IFI, GetProfileDataFn GetProfileInfo)
+      : IFI(IFI), GetProfileInfo(GetProfileInfo) {}
   bool run(Module &M);
   Function *unswitchFunction(Function *F);
 
 private:
   InlineFunctionInfo IFI;
+  GetProfileDataFn GetProfileInfo;
 };
 struct PartialInlinerLegacyPass : public ModulePass {
   static char ID; // Pass identification, replacement for typeid
@@ -45,6 +52,8 @@ struct PartialInlinerLegacyPass : public ModulePass {
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired<AssumptionCacheTracker>();
+    AU.addRequired<BlockFrequencyInfoWrapperPass>();
+    AU.addRequired<BranchProbabilityInfoWrapperPass>();
   }
   bool runOnModule(Module &M) override {
     if (skipModule(M))
@@ -55,8 +64,14 @@ struct PartialInlinerLegacyPass : public ModulePass {
         [&ACT](Function &F) -> AssumptionCache & {
       return ACT->getAssumptionCache(F);
     };
+    GetProfileDataFn GetProfileData = [this](Function &F)
+        -> std::pair<BlockFrequencyInfo *, BranchProbabilityInfo *> {
+      auto *BFI = &getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
+      auto *BPI = &getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI();
+      return std::make_pair(BFI, BPI);
+    };
     InlineFunctionInfo IFI(nullptr, &GetAssumptionCache);
-    return PartialInlinerImpl(IFI).run(M);
+    return PartialInlinerImpl(IFI, GetProfileData).run(M);
   }
 };
 }
@@ -133,9 +148,13 @@ Function *PartialInlinerImpl::unswitchFunction(Function *F) {
   DominatorTree DT;
   DT.recalculate(*DuplicateFunction);
 
+  auto ProfileInfo = GetProfileInfo(*DuplicateFunction);
+
   // Extract the body of the if.
   Function *ExtractedFunction =
-      CodeExtractor(ToExtract, &DT).extractCodeRegion();
+      CodeExtractor(ToExtract, &DT, /*AggregateArgs*/false, ProfileInfo.first,
+                    ProfileInfo.second)
+          .extractCodeRegion();
 
   // Inline the top-level if test into all callers.
   std::vector<User *> Users(DuplicateFunction->user_begin(),
@@ -181,8 +200,8 @@ bool PartialInlinerImpl::run(Module &M) {
     if (Recursive)
       continue;
 
-    if (Function *newFunc = unswitchFunction(CurrFunc)) {
-      Worklist.push_back(newFunc);
+    if (Function *NewFunc = unswitchFunction(CurrFunc)) {
+      Worklist.push_back(NewFunc);
       Changed = true;
     }
   }
@@ -194,6 +213,8 @@ char PartialInlinerLegacyPass::ID = 0;
 INITIALIZE_PASS_BEGIN(PartialInlinerLegacyPass, "partial-inliner",
                       "Partial Inliner", false, false)
 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
 INITIALIZE_PASS_END(PartialInlinerLegacyPass, "partial-inliner",
                     "Partial Inliner", false, false)
 
@@ -208,8 +229,14 @@ PreservedAnalyses PartialInlinerPass::run(Module &M,
       [&FAM](Function &F) -> AssumptionCache & {
     return FAM.getResult<AssumptionAnalysis>(F);
   };
+  GetProfileDataFn GetProfileData = [&FAM](
+      Function &F) -> std::pair<BlockFrequencyInfo *, BranchProbabilityInfo *> {
+    auto *BFI = &FAM.getResult<BlockFrequencyAnalysis>(F);
+    auto *BPI = &FAM.getResult<BranchProbabilityAnalysis>(F);
+    return std::make_pair(BFI, BPI);
+  };
   InlineFunctionInfo IFI(nullptr, &GetAssumptionCache);
-  if (PartialInlinerImpl(IFI).run(M))
+  if (PartialInlinerImpl(IFI, GetProfileData).run(M))
     return PreservedAnalyses::none();
   return PreservedAnalyses::all();
 }
index 772403e..6816eb4 100644 (file)
@@ -17,6 +17,9 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/RegionInfo.h"
 #include "llvm/Analysis/RegionIterator.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/BlockFrequency.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -119,23 +124,30 @@ buildExtractionBlockSet(const RegionNode &RN) {
   return buildExtractionBlockSet(R.block_begin(), R.block_end());
 }
 
-CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs)
-  : DT(nullptr), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
+CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs,
+                             BlockFrequencyInfo *BFI,
+                             BranchProbabilityInfo *BPI)
+    : DT(nullptr), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+      BPI(BPI), Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
 
 CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
-                             bool AggregateArgs)
-  : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
-
-CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs)
-  : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {}
+                             bool AggregateArgs, BlockFrequencyInfo *BFI,
+                             BranchProbabilityInfo *BPI)
+    : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+      BPI(BPI), Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
+
+CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
+                             BlockFrequencyInfo *BFI,
+                             BranchProbabilityInfo *BPI)
+    : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+      BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks())),
+      NumExitBlocks(~0U) {}
 
 CodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN,
-                             bool AggregateArgs)
-  : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
+                             bool AggregateArgs, BlockFrequencyInfo *BFI,
+                             BranchProbabilityInfo *BPI)
+    : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+      BPI(BPI), Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
 
 /// definedInRegion - Return true if the specified value is defined in the
 /// extracted region.
@@ -672,6 +684,51 @@ void CodeExtractor::moveCodeToFunction(Function *newFunction) {
   }
 }
 
+void CodeExtractor::calculateNewCallTerminatorWeights(
+    BasicBlock *CodeReplacer,
+    DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
+    BranchProbabilityInfo *BPI) {
+  typedef BlockFrequencyInfoImplBase::Distribution Distribution;
+  typedef BlockFrequencyInfoImplBase::BlockNode BlockNode;
+
+  // Update the branch weights for the exit block.
+  TerminatorInst *TI = CodeReplacer->getTerminator();
+  SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0);
+
+  // Block Frequency distribution with dummy node.
+  Distribution BranchDist;
+
+  // Add each of the frequencies of the successors.
+  for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
+    BlockNode ExitNode(i);
+    uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency();
+    if (ExitFreq != 0)
+      BranchDist.addExit(ExitNode, ExitFreq);
+    else
+      BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero());
+  }
+
+  // Check for no total weight.
+  if (BranchDist.Total == 0)
+    return;
+
+  // Normalize the distribution so that they can fit in unsigned.
+  BranchDist.normalize();
+
+  // Create normalized branch weights and set the metadata.
+  for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) {
+    const auto &Weight = BranchDist.Weights[I];
+
+    // Get the weight and update the current BFI.
+    BranchWeights[Weight.TargetNode.Index] = Weight.Amount;
+    BranchProbability BP(Weight.Amount, BranchDist.Total);
+    BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP);
+  }
+  TI->setMetadata(
+      LLVMContext::MD_prof,
+      MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
+}
+
 Function *CodeExtractor::extractCodeRegion() {
   if (!isEligible())
     return nullptr;
@@ -682,6 +739,19 @@ Function *CodeExtractor::extractCodeRegion() {
   // block in the region.
   BasicBlock *header = *Blocks.begin();
 
+  // Calculate the entry frequency of the new function before we change the root
+  //   block.
+  BlockFrequency EntryFreq;
+  if (BFI) {
+    assert(BPI && "Both BPI and BFI are required to preserve profile info");
+    for (BasicBlock *Pred : predecessors(header)) {
+      if (Blocks.count(Pred))
+        continue;
+      EntryFreq +=
+          BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
+    }
+  }
+
   // If we have to split PHI nodes or the entry block, do so now.
   severSplitPHINodes(header);
 
@@ -705,12 +775,23 @@ Function *CodeExtractor::extractCodeRegion() {
   // Find inputs to, outputs from the code region.
   findInputsOutputs(inputs, outputs);
 
+  // Calculate the exit blocks for the extracted region and the total exit
+  //  weights for each of those blocks.
+  DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
   SmallPtrSet<BasicBlock *, 1> ExitBlocks;
-  for (BasicBlock *Block : Blocks)
+  for (BasicBlock *Block : Blocks) {
     for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
-         ++SI)
-      if (!Blocks.count(*SI))
+         ++SI) {
+      if (!Blocks.count(*SI)) {
+        // Update the branch weight for this successor.
+        if (BFI) {
+          BlockFrequency &BF = ExitWeights[*SI];
+          BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI);
+        }
         ExitBlocks.insert(*SI);
+      }
+    }
+  }
   NumExitBlocks = ExitBlocks.size();
 
   // Construct new function based on inputs/outputs & add allocas for all defs.
@@ -719,10 +800,23 @@ Function *CodeExtractor::extractCodeRegion() {
                                             codeReplacer, oldFunction,
                                             oldFunction->getParent());
 
+  // Update the entry count of the function.
+  if (BFI) {
+    Optional<uint64_t> EntryCount =
+        BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
+    if (EntryCount.hasValue())
+      newFunction->setEntryCount(EntryCount.getValue());
+    BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
+  }
+
   emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
 
   moveCodeToFunction(newFunction);
 
+  // Update the branch weights for the exit block.
+  if (BFI && NumExitBlocks > 1)
+    calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
+
   // Loop over all of the PHI nodes in the header block, and change any
   // references to the old incoming edge to be the new incoming edge.
   for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) {