OSDN Git Service

Extend CFGPrinter and CallPrinter with Heat Colors
authorSean Fertile <sfertile@ca.ibm.com>
Fri, 29 Jun 2018 17:13:58 +0000 (17:13 +0000)
committerSean Fertile <sfertile@ca.ibm.com>
Fri, 29 Jun 2018 17:13:58 +0000 (17:13 +0000)
Extends the CFGPrinter and CallPrinter with heat colors based on heuristics or
profiling information. The colors are enabled by default and can be toggled
on/off for CFGPrinter by using the option -cfg-heat-colors for both
-dot-cfg[-only] and -view-cfg[-only].  Similarly, the colors can be toggled
on/off for CallPrinter by using the option -callgraph-heat-colors for both
-dot-callgraph and -view-callgraph.

Patch by Rodrigo Caetano Rocha!

Differential Revision: https://reviews.llvm.org/D40425

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@335996 91177308-0d34-0410-b5e6-96231b3b80d8

12 files changed:
include/llvm/Analysis/CFGPrinter.h
include/llvm/Analysis/HeatUtils.h [new file with mode: 0644]
lib/Analysis/CFGPrinter.cpp
lib/Analysis/CMakeLists.txt
lib/Analysis/CallPrinter.cpp
lib/Analysis/DomPrinter.cpp
lib/Analysis/HeatUtils.cpp [new file with mode: 0644]
lib/Analysis/RegionPrinter.cpp
lib/Passes/PassRegistry.def
lib/Transforms/Scalar/NewGVN.cpp
llvm/Analysis/HeatUtils.h [new file with mode: 0644]
test/Other/2007-06-05-PassID.ll

index 5786769..b9b8994 100644 (file)
 
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/Module.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/HeatUtils.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/GraphWriter.h"
 
 namespace llvm {
 class CFGViewerPass
     : public PassInfoMixin<CFGViewerPass> {
 public:
-  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
 };
 
 class CFGOnlyViewerPass
     : public PassInfoMixin<CFGOnlyViewerPass> {
 public:
-  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
 };
 
 class CFGPrinterPass
     : public PassInfoMixin<CFGPrinterPass> {
 public:
-  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
 };
 
 class CFGOnlyPrinterPass
     : public PassInfoMixin<CFGOnlyPrinterPass> {
 public:
-  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
 };
 
-template<>
-struct DOTGraphTraits<const Function*> : public DefaultDOTGraphTraits {
+class CFGDOTInfo {
+private:
+  const Function *F;
+  const BlockFrequencyInfo *BFI;
+  const BranchProbabilityInfo *BPI;
+  uint64_t MaxFreq;
+  bool ShowHeat;
+  bool Heuristic;
+  bool EdgeWeights;
+  bool RawWeights;
 
-  DOTGraphTraits (bool isSimple=false) : DefaultDOTGraphTraits(isSimple) {}
+public:
+  CFGDOTInfo(const Function *F) : CFGDOTInfo(F, nullptr, nullptr, 0) { }
+
+  CFGDOTInfo(const Function *F, const BlockFrequencyInfo *BFI,
+             BranchProbabilityInfo *BPI, uint64_t MaxFreq) 
+      : F(F), BFI(BFI), BPI(BPI), MaxFreq(MaxFreq) {
+    ShowHeat = false;
+    Heuristic = true;
+    EdgeWeights = true;
+    RawWeights = true;
+  }
+
+  const BlockFrequencyInfo *getBFI() { return BFI; }
+
+  const BranchProbabilityInfo *getBPI() { return BPI; }
+
+  const Function *getFunction() { return this->F; }
+
+  uint64_t getMaxFreq() { return MaxFreq; }
 
-  static std::string getGraphName(const Function *F) {
-    return "CFG for '" + F->getName().str() + "' function";
+  uint64_t getFreq(const BasicBlock *BB) {
+    return getBlockFreq(BB, BFI, Heuristic);
   }
 
-  static std::string getSimpleNodeLabel(const BasicBlock *Node,
-                                        const Function *) {
+  void setHeatColors(bool ShowHeat) { this->ShowHeat = ShowHeat; }
+
+  bool showHeatColors() { return ShowHeat; }
+
+  void setHeuristic(bool Heuristic) { this->Heuristic = Heuristic; }
+
+  bool useHeuristic() { return Heuristic; }
+
+  void setRawEdgeWeights(bool RawWeights) { this->RawWeights = RawWeights; }
+
+  bool useRawEdgeWeights() { return RawWeights; }
+
+  void setEdgeWeights(bool EdgeWeights) { this->EdgeWeights = EdgeWeights; }
+
+  bool showEdgeWeights() { return EdgeWeights; }
+};
+
+template <>
+struct GraphTraits<CFGDOTInfo *> : public GraphTraits<const BasicBlock *> {
+  static NodeRef getEntryNode(CFGDOTInfo *CFGInfo) {
+    return &(CFGInfo->getFunction()->getEntryBlock());
+  }
+
+  // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
+  using nodes_iterator = pointer_iterator<Function::const_iterator>;
+
+  static nodes_iterator nodes_begin(CFGDOTInfo *CFGInfo) {
+    return nodes_iterator(CFGInfo->getFunction()->begin());
+  }
+
+  static nodes_iterator nodes_end(CFGDOTInfo *CFGInfo) {
+    return nodes_iterator(CFGInfo->getFunction()->end());
+  }
+
+  static size_t size(CFGDOTInfo *CFGInfo) { return CFGInfo->getFunction()->size(); }
+};
+
+template <> struct DOTGraphTraits<CFGDOTInfo *> : public DefaultDOTGraphTraits {
+
+  DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
+
+  static std::string getGraphName(CFGDOTInfo *CFGInfo) {
+    return "CFG for '" + CFGInfo->getFunction()->getName().str() + "' function";
+  }
+
+  static std::string getSimpleNodeLabel(const BasicBlock *Node, CFGDOTInfo *) {
     if (!Node->getName().empty())
       return Node->getName().str();
 
@@ -73,7 +148,7 @@ struct DOTGraphTraits<const Function*> : public DefaultDOTGraphTraits {
   }
 
   static std::string getCompleteNodeLabel(const BasicBlock *Node,
-                                          const Function *) {
+                                          CFGDOTInfo *) {
     enum { MaxColumns = 80 };
     std::string Str;
     raw_string_ostream OS(Str);
@@ -117,12 +192,11 @@ struct DOTGraphTraits<const Function*> : public DefaultDOTGraphTraits {
     return OutStr;
   }
 
-  std::string getNodeLabel(const BasicBlock *Node,
-                           const Function *Graph) {
+  std::string getNodeLabel(const BasicBlock *Node, CFGDOTInfo *CFGInfo) {
     if (isSimple())
-      return getSimpleNodeLabel(Node, Graph);
+      return getSimpleNodeLabel(Node, CFGInfo);
     else
-      return getCompleteNodeLabel(Node, Graph);
+      return getCompleteNodeLabel(Node, CFGInfo);
   }
 
   static std::string getEdgeSourceLabel(const BasicBlock *Node,
@@ -149,39 +223,86 @@ struct DOTGraphTraits<const Function*> : public DefaultDOTGraphTraits {
 
   /// Display the raw branch weights from PGO.
   std::string getEdgeAttributes(const BasicBlock *Node, succ_const_iterator I,
-                                const Function *F) {
+                                CFGDOTInfo *CFGInfo) {
+
+    if (!CFGInfo->showEdgeWeights())
+      return "";
+
+    const unsigned MaxEdgeWidth = 2;
+
     const TerminatorInst *TI = Node->getTerminator();
     if (TI->getNumSuccessors() == 1)
-      return "";
+      return "penwidth="+std::to_string(MaxEdgeWidth);
 
-    MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
-    if (!WeightsNode)
-      return "";
+    unsigned OpNo = I.getSuccessorIndex();
 
-    MDString *MDName = cast<MDString>(WeightsNode->getOperand(0));
-    if (MDName->getString() != "branch_weights")
+    if (OpNo >= TI->getNumSuccessors())
       return "";
 
-    unsigned OpNo = I.getSuccessorIndex() + 1;
-    if (OpNo >= WeightsNode->getNumOperands())
-      return "";
-    ConstantInt *Weight =
-        mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(OpNo));
-    if (!Weight)
+    std::string Attrs = "";
+
+    BasicBlock *SuccBB = TI->getSuccessor(OpNo);
+    auto BranchProb = CFGInfo->getBPI()->getEdgeProbability(Node,SuccBB);
+    double WeightPercent = ((double)BranchProb.getNumerator()) /
+                           ((double)BranchProb.getDenominator());
+    double Width = 1+(MaxEdgeWidth-1)*WeightPercent;
+
+    if (CFGInfo->useRawEdgeWeights()) {
+      // Prepend a 'W' to indicate that this is a weight rather than the actual
+      // profile count (due to scaling).
+
+      uint64_t Freq = CFGInfo->getFreq(Node);
+      Attrs = formatv("label=\"W:{0}\" penwidth={1}", (uint64_t)(Freq*WeightPercent), Width);
+      if (Attrs.size())
+        return Attrs;
+
+      MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
+      if (!WeightsNode)
+        return Attrs;
+
+      MDString *MDName = cast<MDString>(WeightsNode->getOperand(0));
+      if (MDName->getString() != "branch_weights")
+        return Attrs;
+
+      unsigned OpNo = I.getSuccessorIndex() + 1;
+      if (OpNo >= WeightsNode->getNumOperands())
+        return Attrs;
+      ConstantInt *Weight =
+          mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(OpNo));
+      if (!Weight)
+        return Attrs;
+
+      Attrs = "label=\"W:" + std::to_string(Weight->getZExtValue()) + "\" penwidth=" + std::to_string(Width);
+    } else {
+      //formatting value to percentage
+      Attrs = formatv("label=\"{0:P}\" penwidth={1}", WeightPercent, Width);
+    }
+    return Attrs;
+  }
+
+  std::string getNodeAttributes(const BasicBlock *Node, CFGDOTInfo *CFGInfo) {
+
+    if (!CFGInfo->showHeatColors())
       return "";
 
-    // Prepend a 'W' to indicate that this is a weight rather than the actual
-    // profile count (due to scaling).
-    Twine Attrs = "label=\"W:" + Twine(Weight->getZExtValue()) + "\"";
-    return Attrs.str();
+    uint64_t Freq = CFGInfo->getFreq(Node);
+    std::string Color = getHeatColor(Freq, CFGInfo->getMaxFreq());
+    std::string EdgeColor = (Freq <= (CFGInfo->getMaxFreq() / 2))
+                             ? (getHeatColor(0))
+                             : (getHeatColor(1));
+
+    std::string Attrs = "color=\"" + EdgeColor + "ff\", style=filled,"+
+                        "fillcolor=\"" + Color + "70\"";
+    return Attrs;
   }
 };
-} // End llvm namespace
+
+} // namespace llvm
 
 namespace llvm {
-  class FunctionPass;
-  FunctionPass *createCFGPrinterLegacyPassPass ();
-  FunctionPass *createCFGOnlyPrinterLegacyPassPass ();
-} // End llvm namespace
+class ModulePass;
+ModulePass *createCFGPrinterLegacyPassPass();
+ModulePass *createCFGOnlyPrinterLegacyPassPass();
+} // namespace llvm
 
 #endif
diff --git a/include/llvm/Analysis/HeatUtils.h b/include/llvm/Analysis/HeatUtils.h
new file mode 100644 (file)
index 0000000..8cb03b9
--- /dev/null
@@ -0,0 +1,54 @@
+//===-- HeatUtils.h - Utility for printing heat colors ----------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// Utility for printing heat colors based on heuristics or profiling
+// information.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_HEATUTILS_H
+#define LLVM_ANALYSIS_HEATUTILS_H
+
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/CallSite.h"
+
+#include <string>
+
+namespace llvm {
+
+bool hasProfiling(const Module &M);
+
+uint64_t getBlockFreq(const BasicBlock *BB, const BlockFrequencyInfo *BFI,
+                      bool useHeuristic = true);
+
+uint64_t getNumOfCalls(Function &callerFunction, Function &calledFunction,
+                       function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
+                       bool useHeuristic = true);
+
+uint64_t getNumOfCalls(CallSite &callsite,
+                       function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
+                       bool useHeuristic = true);
+
+uint64_t getMaxFreq(const Function &F, const BlockFrequencyInfo *BFI,
+                    bool useHeuristic = true);
+
+uint64_t getMaxFreq(Module &M,
+                    function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
+                    bool useHeuristic = true);
+
+std::string getHeatColor(uint64_t freq, uint64_t maxFreq);
+
+std::string getHeatColor(double percent);
+
+} // namespace llvm
+
+#endif
index fb26175..10be4e4 100644 (file)
 #include "llvm/Support/FileSystem.h"
 using namespace llvm;
 
+static cl::opt<bool> CFGHeatPerFunction("cfg-heat-per-function",
+                                        cl::init(false), cl::Hidden,
+                                        cl::desc("Heat CFG per function"));
+
+static cl::opt<bool> ShowHeatColors("cfg-heat-colors", cl::init(true),
+                                    cl::Hidden,
+                                    cl::desc("Show heat colors in CFG"));
+
+static cl::opt<bool> UseRawEdgeWeight("cfg-raw-weights", cl::init(false),
+                                      cl::Hidden,
+                                      cl::desc("Use raw weights for labels. "
+                                               "Use percentages as default."));
+
+static cl::opt<bool> ShowEdgeWeight("cfg-weights", cl::init(true), cl::Hidden,
+                                    cl::desc("Show edges labeled with weights"));
+
+static void writeHeatCFGToDotFile(Function &F, BlockFrequencyInfo *BFI,
+                                 BranchProbabilityInfo *BPI, uint64_t MaxFreq,
+                                 bool UseHeuristic, bool isSimple) {
+  std::string Filename = ("cfg." + F.getName() + ".dot").str();
+  errs() << "Writing '" << Filename << "'...";
+
+  std::error_code EC;
+  raw_fd_ostream File(Filename, EC, sys::fs::F_Text);
+
+  CFGDOTInfo CFGInfo(&F, BFI, BPI, MaxFreq);
+  CFGInfo.setHeuristic(UseHeuristic);
+  CFGInfo.setHeatColors(ShowHeatColors);
+  CFGInfo.setEdgeWeights(ShowEdgeWeight);
+  CFGInfo.setRawEdgeWeights(UseRawEdgeWeight);
+
+  if (!EC)
+    WriteGraph(File, &CFGInfo, isSimple);
+  else
+    errs() << "  error opening file for writing!";
+  errs() << "\n";
+}
+
+static void writeAllCFGsToDotFile(Module &M,
+                       function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
+                       function_ref<BranchProbabilityInfo *(Function &)> LookupBPI,
+                       bool isSimple) {
+  bool UseHeuristic = true;
+  uint64_t MaxFreq = 0;
+  if (!CFGHeatPerFunction)
+    MaxFreq = getMaxFreq(M, LookupBFI, UseHeuristic);
+
+  for (auto &F : M) {
+    if (F.isDeclaration()) continue;
+    auto *BFI = LookupBFI(F);
+    auto *BPI = LookupBPI(F);
+    if (CFGHeatPerFunction)
+      MaxFreq = getMaxFreq(F, BFI, UseHeuristic);
+    writeHeatCFGToDotFile(F, BFI, BPI, MaxFreq, UseHeuristic, isSimple);
+  }
+
+}
+
+static void viewHeatCFG(Function &F, BlockFrequencyInfo *BFI,
+                                 BranchProbabilityInfo *BPI, uint64_t MaxFreq,
+                                 bool UseHeuristic, bool isSimple) {
+  CFGDOTInfo CFGInfo(&F, BFI, BPI, MaxFreq);
+  CFGInfo.setHeuristic(UseHeuristic);
+  CFGInfo.setHeatColors(ShowHeatColors);
+  CFGInfo.setEdgeWeights(ShowEdgeWeight);
+  CFGInfo.setRawEdgeWeights(UseRawEdgeWeight);
+
+  ViewGraph(&CFGInfo, "cfg." + F.getName(), isSimple);
+}
+
+static void viewAllCFGs(Module &M,
+                       function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
+                       function_ref<BranchProbabilityInfo *(Function &)> LookupBPI,
+                       bool isSimple) {
+  bool UseHeuristic = true;
+  uint64_t MaxFreq = 0;
+  if (!CFGHeatPerFunction)
+    MaxFreq = getMaxFreq(M, LookupBFI, UseHeuristic);
+
+  for (auto &F : M) {
+    if (F.isDeclaration()) continue;
+    auto *BFI = LookupBFI(F);
+    auto *BPI = LookupBPI(F);
+    if (CFGHeatPerFunction)
+      MaxFreq = getMaxFreq(F, BFI, UseHeuristic);
+    viewHeatCFG(F, BFI, BPI, MaxFreq, UseHeuristic, isSimple);
+  }
+
+}
+
 namespace {
-  struct CFGViewerLegacyPass : public FunctionPass {
+  struct CFGViewerLegacyPass : public ModulePass {
     static char ID; // Pass identifcation, replacement for typeid
-    CFGViewerLegacyPass() : FunctionPass(ID) {
+    CFGViewerLegacyPass() : ModulePass(ID) {
       initializeCFGViewerLegacyPassPass(*PassRegistry::getPassRegistry());
     }
 
-    bool runOnFunction(Function &F) override {
-      F.viewCFG();
+    bool runOnModule(Module &M) override {
+      auto LookupBFI = [this](Function &F) {
+        return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
+      };
+      auto LookupBPI = [this](Function &F) {
+        return &this->getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI();
+      };
+      viewAllCFGs(M, LookupBFI, LookupBPI, /*isSimple=*/false);
       return false;
     }
 
-    void print(raw_ostream &OS, const Module* = nullptr) const override {}
+    void print(raw_ostream &OS, const Module * = nullptr) const override {}
 
     void getAnalysisUsage(AnalysisUsage &AU) const override {
+      ModulePass::getAnalysisUsage(AU);
+      AU.addRequired<BlockFrequencyInfoWrapperPass>();
+      AU.addRequired<BranchProbabilityInfoWrapperPass>();
       AU.setPreservesAll();
     }
+
   };
 }
 
 char CFGViewerLegacyPass::ID = 0;
 INITIALIZE_PASS(CFGViewerLegacyPass, "view-cfg", "View CFG of function", false, true)
 
-PreservedAnalyses CFGViewerPass::run(Function &F,
-                                     FunctionAnalysisManager &AM) {
-  F.viewCFG();
+PreservedAnalyses CFGViewerPass::run(Module &M,
+                                     ModuleAnalysisManager &AM) {
+  auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+  auto LookupBFI = [&FAM](Function &F) {
+    return &FAM.getResult<BlockFrequencyAnalysis>(F);
+  };
+  auto LookupBPI = [&FAM](Function &F) {
+    return &FAM.getResult<BranchProbabilityAnalysis>(F);
+  };
+  viewAllCFGs(M, LookupBFI, LookupBPI, /*isSimple=*/false);
   return PreservedAnalyses::all();
 }
 
 
 namespace {
-  struct CFGOnlyViewerLegacyPass : public FunctionPass {
+  struct CFGOnlyViewerLegacyPass : public ModulePass {
     static char ID; // Pass identifcation, replacement for typeid
-    CFGOnlyViewerLegacyPass() : FunctionPass(ID) {
+    CFGOnlyViewerLegacyPass() : ModulePass(ID) {
       initializeCFGOnlyViewerLegacyPassPass(*PassRegistry::getPassRegistry());
     }
 
-    bool runOnFunction(Function &F) override {
-      F.viewCFGOnly();
+    bool runOnModule(Module &M) override {
+      auto LookupBFI = [this](Function &F) {
+        return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
+      };
+      auto LookupBPI = [this](Function &F) {
+        return &this->getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI();
+      };
+      viewAllCFGs(M, LookupBFI, LookupBPI, /*isSimple=*/true);
       return false;
     }
 
-    void print(raw_ostream &OS, const Module* = nullptr) const override {}
+    void print(raw_ostream &OS, const Module * = nullptr) const override {}
 
     void getAnalysisUsage(AnalysisUsage &AU) const override {
+      ModulePass::getAnalysisUsage(AU);
+      AU.addRequired<BlockFrequencyInfoWrapperPass>();
+      AU.addRequired<BranchProbabilityInfoWrapperPass>();
       AU.setPreservesAll();
     }
+
   };
 }
 
@@ -76,43 +193,46 @@ char CFGOnlyViewerLegacyPass::ID = 0;
 INITIALIZE_PASS(CFGOnlyViewerLegacyPass, "view-cfg-only",
                 "View CFG of function (with no function bodies)", false, true)
 
-PreservedAnalyses CFGOnlyViewerPass::run(Function &F,
-                                         FunctionAnalysisManager &AM) {
-  F.viewCFGOnly();
+PreservedAnalyses CFGOnlyViewerPass::run(Module &M,
+                                         ModuleAnalysisManager &AM) {
+  auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+  auto LookupBFI = [&FAM](Function &F) {
+    return &FAM.getResult<BlockFrequencyAnalysis>(F);
+  };
+  auto LookupBPI = [&FAM](Function &F) {
+    return &FAM.getResult<BranchProbabilityAnalysis>(F);
+  };
+  viewAllCFGs(M, LookupBFI, LookupBPI, /*isSimple=*/true);
   return PreservedAnalyses::all();
 }
 
-static void writeCFGToDotFile(Function &F, bool CFGOnly = false) {
-  std::string Filename = ("cfg." + F.getName() + ".dot").str();
-  errs() << "Writing '" << Filename << "'...";
-
-  std::error_code EC;
-  raw_fd_ostream File(Filename, EC, sys::fs::F_Text);
-
-  if (!EC)
-    WriteGraph(File, (const Function*)&F, CFGOnly);
-  else
-    errs() << "  error opening file for writing!";
-  errs() << "\n";
-}
-
 namespace {
-  struct CFGPrinterLegacyPass : public FunctionPass {
+  struct CFGPrinterLegacyPass : public ModulePass {
     static char ID; // Pass identification, replacement for typeid
-    CFGPrinterLegacyPass() : FunctionPass(ID) {
+    CFGPrinterLegacyPass() : ModulePass(ID) {
       initializeCFGPrinterLegacyPassPass(*PassRegistry::getPassRegistry());
     }
 
-    bool runOnFunction(Function &F) override {
-      writeCFGToDotFile(F);
+    bool runOnModule(Module &M) override {
+      auto LookupBFI = [this](Function &F) {
+        return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
+      };
+      auto LookupBPI = [this](Function &F) {
+        return &this->getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI();
+      };
+      writeAllCFGsToDotFile(M, LookupBFI, LookupBPI, /*isSimple=*/false);
       return false;
     }
 
-    void print(raw_ostream &OS, const Module* = nullptr) const override {}
+    void print(raw_ostream &OS, const Module * = nullptr) const override {}
 
     void getAnalysisUsage(AnalysisUsage &AU) const override {
+      ModulePass::getAnalysisUsage(AU);
+      AU.addRequired<BlockFrequencyInfoWrapperPass>();
+      AU.addRequired<BranchProbabilityInfoWrapperPass>();
       AU.setPreservesAll();
     }
+
   };
 }
 
@@ -120,28 +240,46 @@ char CFGPrinterLegacyPass::ID = 0;
 INITIALIZE_PASS(CFGPrinterLegacyPass, "dot-cfg", "Print CFG of function to 'dot' file", 
                 false, true)
 
-PreservedAnalyses CFGPrinterPass::run(Function &F,
-                                      FunctionAnalysisManager &AM) {
-  writeCFGToDotFile(F);
+PreservedAnalyses CFGPrinterPass::run(Module &M,
+                                      ModuleAnalysisManager &AM) {
+  auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+  auto LookupBFI = [&FAM](Function &F) {
+    return &FAM.getResult<BlockFrequencyAnalysis>(F);
+  };
+  auto LookupBPI = [&FAM](Function &F) {
+    return &FAM.getResult<BranchProbabilityAnalysis>(F);
+  };
+  writeAllCFGsToDotFile(M, LookupBFI, LookupBPI, /*isSimple=*/false);
   return PreservedAnalyses::all();
 }
 
 namespace {
-  struct CFGOnlyPrinterLegacyPass : public FunctionPass {
+  struct CFGOnlyPrinterLegacyPass : public ModulePass {
     static char ID; // Pass identification, replacement for typeid
-    CFGOnlyPrinterLegacyPass() : FunctionPass(ID) {
+    CFGOnlyPrinterLegacyPass() : ModulePass(ID) {
       initializeCFGOnlyPrinterLegacyPassPass(*PassRegistry::getPassRegistry());
     }
 
-    bool runOnFunction(Function &F) override {
-      writeCFGToDotFile(F, /*CFGOnly=*/true);
+    bool runOnModule(Module &M) override {
+      auto LookupBFI = [this](Function &F) {
+        return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
+      };
+      auto LookupBPI = [this](Function &F) {
+        return &this->getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI();
+      };
+      writeAllCFGsToDotFile(M, LookupBFI, LookupBPI, /*isSimple=*/true);
       return false;
     }
-    void print(raw_ostream &OS, const Module* = nullptr) const override {}
+
+    void print(raw_ostream &OS, const Module * = nullptr) const override {}
 
     void getAnalysisUsage(AnalysisUsage &AU) const override {
+      ModulePass::getAnalysisUsage(AU);
+      AU.addRequired<BlockFrequencyInfoWrapperPass>();
+      AU.addRequired<BranchProbabilityInfoWrapperPass>();
       AU.setPreservesAll();
     }
+
   };
 }
 
@@ -150,9 +288,16 @@ INITIALIZE_PASS(CFGOnlyPrinterLegacyPass, "dot-cfg-only",
    "Print CFG of function to 'dot' file (with no function bodies)",
    false, true)
 
-PreservedAnalyses CFGOnlyPrinterPass::run(Function &F,
-                                          FunctionAnalysisManager &AM) {
-  writeCFGToDotFile(F, /*CFGOnly=*/true);
+PreservedAnalyses CFGOnlyPrinterPass::run(Module &M,
+                                          ModuleAnalysisManager &AM) {
+  auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+  auto LookupBFI = [&FAM](Function &F) {
+    return &FAM.getResult<BlockFrequencyAnalysis>(F);
+  };
+  auto LookupBPI = [&FAM](Function &F) {
+    return &FAM.getResult<BranchProbabilityAnalysis>(F);
+  };
+  writeAllCFGsToDotFile(M, LookupBFI, LookupBPI, /*isSimple=*/true);
   return PreservedAnalyses::all();
 }
 
@@ -162,7 +307,9 @@ PreservedAnalyses CFGOnlyPrinterPass::run(Function &F,
 /// being a 'dot' and 'gv' program in your path.
 ///
 void Function::viewCFG() const {
-  ViewGraph(this, "cfg" + getName());
+
+  CFGDOTInfo CFGInfo(this);
+  ViewGraph(&CFGInfo, "cfg" + getName());
 }
 
 /// viewCFGOnly - This function is meant for use from the debugger.  It works
@@ -171,14 +318,15 @@ void Function::viewCFG() const {
 /// this can make the graph smaller.
 ///
 void Function::viewCFGOnly() const {
-  ViewGraph(this, "cfg" + getName(), true);
+
+  CFGDOTInfo CFGInfo(this);
+  ViewGraph(&CFGInfo, "cfg" + getName(), true);
 }
 
-FunctionPass *llvm::createCFGPrinterLegacyPassPass () {
+ModulePass *llvm::createCFGPrinterLegacyPassPass() {
   return new CFGPrinterLegacyPass();
 }
 
-FunctionPass *llvm::createCFGOnlyPrinterLegacyPassPass () {
+ModulePass *llvm::createCFGOnlyPrinterLegacyPassPass() {
   return new CFGOnlyPrinterLegacyPass();
 }
-
index 8e8535a..6c0f408 100644 (file)
@@ -10,6 +10,7 @@ add_llvm_library(LLVMAnalysis
   BlockFrequencyInfoImpl.cpp
   BranchProbabilityInfo.cpp
   CFG.cpp
+  HeatUtils.cpp
   CFGPrinter.cpp
   CFLAndersAliasAnalysis.cpp
   CFLSteensAliasAnalysis.cpp
index e7017e7..a2e42bb 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Analysis/CallPrinter.h"
+
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/Analysis/CallGraph.h"
 #include "llvm/Analysis/DOTGraphTraitsPass.h"
+#include "llvm/Analysis/HeatUtils.h"
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallSet.h"
 
 using namespace llvm;
 
+static cl::opt<bool> ShowHeatColors("callgraph-heat-colors", cl::init(true),
+                                    cl::Hidden,
+                                    cl::desc("Show heat colors in call-graph"));
+
+static cl::opt<bool>
+    EstimateEdgeWeight("callgraph-weights", cl::init(false), cl::Hidden,
+                       cl::desc("Show edges labeled with weights"));
+
+static cl::opt<bool>
+    FullCallGraph("callgraph-full", cl::init(false), cl::Hidden,
+                  cl::desc("Show full call-graph (including external nodes)"));
+
+static cl::opt<bool> UseCallCounter(
+    "callgraph-call-count", cl::init(false), cl::Hidden,
+    cl::desc("Use function's call counter as a heat metric. "
+             "The default is the function's maximum block frequency."));
+
 namespace llvm {
 
-template <> struct DOTGraphTraits<CallGraph *> : public DefaultDOTGraphTraits {
+class CallGraphDOTInfo {
+private:
+  Module *M;
+  CallGraph *CG;
+  DenseMap<const Function *, uint64_t> Freq;
+  uint64_t MaxFreq;
+  uint64_t MaxEdgeCount;
+
+public:
+  std::function<BlockFrequencyInfo *(Function &)> LookupBFI;
+
+  CallGraphDOTInfo(Module *M, CallGraph *CG,
+                   function_ref<BlockFrequencyInfo *(Function &)> LookupBFI)
+      : M(M), CG(CG), LookupBFI(LookupBFI) {
+    MaxFreq = 0;
+    MaxEdgeCount = 0;
+
+    for (Function &F : *M) {
+      Freq[&F] = 0;
+
+      if (FullCallGraph) {
+        for (User *U : F.users()) {
+          auto CS = CallSite(U);
+          if (!CS.getCaller()->isDeclaration()) {
+            uint64_t Counter = getNumOfCalls(CS, LookupBFI);
+            if (Counter > MaxEdgeCount) {
+              MaxEdgeCount = Counter;
+            }
+          }
+        }
+      }
+
+      if (F.isDeclaration())
+        continue;
+      uint64_t localMaxFreq = 0;
+      if (UseCallCounter) {
+        Function::ProfileCount EntryCount = F.getEntryCount();
+        if (EntryCount.hasValue())
+          localMaxFreq = EntryCount.getCount();
+      } else {
+        localMaxFreq = llvm::getMaxFreq(F, LookupBFI(F));
+      }
+      if (localMaxFreq >= MaxFreq)
+        MaxFreq = localMaxFreq;
+      Freq[&F] = localMaxFreq;
+
+      if (!FullCallGraph) {
+        for (Function &Callee : *M) {
+          uint64_t Counter = getNumOfCalls(F, Callee, LookupBFI);
+          if (Counter > MaxEdgeCount) {
+            MaxEdgeCount = Counter;
+          }
+        }
+      }
+    }
+    if (!FullCallGraph)
+      removeParallelEdges();
+  }
+
+  Module *getModule() const { return M; }
+  CallGraph *getCallGraph() const { return CG; }
+
+  uint64_t getFreq(const Function *F) { return Freq[F]; }
+
+  uint64_t getMaxFreq() { return MaxFreq; }
+
+  uint64_t getMaxEdgeCount() { return MaxEdgeCount; }
+
+private:
+  void removeParallelEdges() {
+    for (auto &I : (*CG)) {
+      CallGraphNode *Node = I.second.get();
+
+      bool FoundParallelEdge = true;
+      while (FoundParallelEdge) {
+        SmallSet<Function *, 16> Visited;
+        FoundParallelEdge = false;
+        for (auto CI = Node->begin(), CE = Node->end(); CI != CE; CI++) {
+          if (!Visited.count(CI->second->getFunction()))
+            Visited.insert(CI->second->getFunction());
+          else {
+            FoundParallelEdge = true;
+            Node->removeCallEdge(CI);
+            break;
+          }
+        }
+      }
+    }
+  }
+};
+
+template <>
+struct GraphTraits<CallGraphDOTInfo *>
+    : public GraphTraits<const CallGraphNode *> {
+  static NodeRef getEntryNode(CallGraphDOTInfo *CGInfo) {
+    // Start at the external node!
+    return CGInfo->getCallGraph()->getExternalCallingNode();
+  }
+
+  typedef std::pair<const Function *const, std::unique_ptr<CallGraphNode>>
+      PairTy;
+  static const CallGraphNode *CGGetValuePtr(const PairTy &P) {
+    return P.second.get();
+  }
+
+  // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
+  typedef mapped_iterator<CallGraph::const_iterator, decltype(&CGGetValuePtr)>
+      nodes_iterator;
+
+  static nodes_iterator nodes_begin(CallGraphDOTInfo *CGInfo) {
+    return nodes_iterator(CGInfo->getCallGraph()->begin(), &CGGetValuePtr);
+  }
+  static nodes_iterator nodes_end(CallGraphDOTInfo *CGInfo) {
+    return nodes_iterator(CGInfo->getCallGraph()->end(), &CGGetValuePtr);
+  }
+};
+
+template <>
+struct DOTGraphTraits<CallGraphDOTInfo *> : public DefaultDOTGraphTraits {
+
+  SmallSet<User *, 16> VisitedCallSites;
+
   DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
 
-  static std::string getGraphName(CallGraph *Graph) { return "Call graph"; }
+  static std::string getGraphName(CallGraphDOTInfo *CGInfo) {
+    return "Call graph: " +
+           std::string(CGInfo->getModule()->getModuleIdentifier());
+  }
+
+  static bool isNodeHidden(const CallGraphNode *Node) {
+    if (FullCallGraph)
+      return false;
+
+    if (Node->getFunction())
+      return false;
+
+    return true;
+  }
+
+  std::string getNodeLabel(const CallGraphNode *Node,
+                           CallGraphDOTInfo *CGInfo) {
+    if (Node == CGInfo->getCallGraph()->getExternalCallingNode())
+      return "external caller";
+
+    if (Node == CGInfo->getCallGraph()->getCallsExternalNode())
+      return "external callee";
 
-  std::string getNodeLabel(CallGraphNode *Node, CallGraph *Graph) {
     if (Function *Func = Node->getFunction())
       return Func->getName();
 
     return "external node";
   }
-};
 
-struct AnalysisCallGraphWrapperPassTraits {
-  static CallGraph *getGraph(CallGraphWrapperPass *P) {
-    return &P->getCallGraph();
+  static const CallGraphNode *CGGetValuePtr(CallGraphNode::CallRecord P) {
+    return P.second;
+  }
+
+  // nodes_iterator/begin/end - Allow iteration over all nodes in the graph
+  typedef mapped_iterator<CallGraphNode::const_iterator,
+                          decltype(&CGGetValuePtr)>
+      nodes_iterator;
+
+  std::string getEdgeAttributes(const CallGraphNode *Node, nodes_iterator I,
+                                CallGraphDOTInfo *CGInfo) {
+    if (!EstimateEdgeWeight)
+      return "";
+
+    Function *Caller = Node->getFunction();
+    if (Caller == nullptr || Caller->isDeclaration())
+      return "";
+
+    Function *Callee = (*I)->getFunction();
+    if (Callee == nullptr)
+      return "";
+
+    uint64_t Counter = 0;
+    if (FullCallGraph) {
+      // looks for next call site between Caller and Callee
+      for (User *U : Callee->users()) {
+        auto CS = CallSite(U);
+        if (CS.getCaller() == Caller) {
+          if (VisitedCallSites.count(U))
+            continue;
+          VisitedCallSites.insert(U);
+          Counter = getNumOfCalls(CS, CGInfo->LookupBFI);
+          break;
+        }
+      }
+    } else {
+      Counter = getNumOfCalls(*Caller, *Callee, CGInfo->LookupBFI);
+    }
+
+    const unsigned MaxEdgeWidth = 3;
+
+    double Width =
+        1 + (MaxEdgeWidth - 1) * (double(Counter) / CGInfo->getMaxEdgeCount());
+    std::string Attrs = "label=\"" + std::to_string(Counter) +
+                        "\" penwidth=" + std::to_string(Width);
+
+    return Attrs;
+  }
+
+  std::string getNodeAttributes(const CallGraphNode *Node,
+                                CallGraphDOTInfo *CGInfo) {
+    Function *F = Node->getFunction();
+    if (F == nullptr || F->isDeclaration())
+      return "";
+
+    std::string attrs = "";
+    if (ShowHeatColors) {
+      uint64_t freq = CGInfo->getFreq(F);
+      std::string color = getHeatColor(freq, CGInfo->getMaxFreq());
+      std::string edgeColor = (freq <= (CGInfo->getMaxFreq() / 2))
+                                  ? getHeatColor(0)
+                                  : getHeatColor(1);
+
+      attrs = "color=\"" + edgeColor + "ff\", style=filled, fillcolor=\"" +
+              color + "80\"";
+    }
+    return attrs;
   }
 };
 
-} // end llvm namespace
+} // namespace llvm
 
 namespace {
 
-struct CallGraphViewer
-    : public DOTGraphTraitsModuleViewer<CallGraphWrapperPass, true, CallGraph *,
-                                        AnalysisCallGraphWrapperPassTraits> {
+// Viewer
+
+class CallGraphViewer : public ModulePass {
+public:
   static char ID;
+  CallGraphViewer() : ModulePass(ID) {}
 
-  CallGraphViewer()
-      : DOTGraphTraitsModuleViewer<CallGraphWrapperPass, true, CallGraph *,
-                                   AnalysisCallGraphWrapperPassTraits>(
-            "callgraph", ID) {
-    initializeCallGraphViewerPass(*PassRegistry::getPassRegistry());
-  }
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+  bool runOnModule(Module &M) override;
 };
 
-struct CallGraphDOTPrinter : public DOTGraphTraitsModulePrinter<
-                              CallGraphWrapperPass, true, CallGraph *,
-                              AnalysisCallGraphWrapperPassTraits> {
+void CallGraphViewer::getAnalysisUsage(AnalysisUsage &AU) const {
+  ModulePass::getAnalysisUsage(AU);
+  AU.addRequired<BlockFrequencyInfoWrapperPass>();
+  AU.setPreservesAll();
+}
+
+bool CallGraphViewer::runOnModule(Module &M) {
+  auto LookupBFI = [this](Function &F) {
+    return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
+  };
+
+  CallGraph CG(M);
+  CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
+
+  std::string Title =
+      DOTGraphTraits<CallGraphDOTInfo *>::getGraphName(&CFGInfo);
+  ViewGraph(&CFGInfo, "callgraph", true, Title);
+
+  return false;
+}
+
+// DOT Printer
+
+class CallGraphDOTPrinter : public ModulePass {
+public:
   static char ID;
+  CallGraphDOTPrinter() : ModulePass(ID) {}
 
-  CallGraphDOTPrinter()
-      : DOTGraphTraitsModulePrinter<CallGraphWrapperPass, true, CallGraph *,
-                                    AnalysisCallGraphWrapperPassTraits>(
-            "callgraph", ID) {
-    initializeCallGraphDOTPrinterPass(*PassRegistry::getPassRegistry());
-  }
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+  bool runOnModule(Module &M) override;
 };
 
+void CallGraphDOTPrinter::getAnalysisUsage(AnalysisUsage &AU) const {
+  ModulePass::getAnalysisUsage(AU);
+  AU.addRequired<BlockFrequencyInfoWrapperPass>();
+  AU.setPreservesAll();
+}
+
+bool CallGraphDOTPrinter::runOnModule(Module &M) {
+  auto LookupBFI = [this](Function &F) {
+    return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
+  };
+
+  std::string Filename =
+      (std::string(M.getModuleIdentifier()) + ".callgraph.dot");
+  errs() << "Writing '" << Filename << "'...";
+
+  std::error_code EC;
+  raw_fd_ostream File(Filename, EC, sys::fs::F_Text);
+
+  CallGraph CG(M);
+  CallGraphDOTInfo CFGInfo(&M, &CG, LookupBFI);
+
+  if (!EC)
+    WriteGraph(File, &CFGInfo);
+  else
+    errs() << "  error opening file for writing!";
+  errs() << "\n";
+
+  return false;
+}
+
 } // end anonymous namespace
 
 char CallGraphViewer::ID = 0;
index 8abc0e7..b1beb72 100644 (file)
@@ -38,13 +38,12 @@ struct DOTGraphTraits<DomTreeNode*> : public DefaultDOTGraphTraits {
     if (!BB)
       return "Post dominance root node";
 
-
     if (isSimple())
-      return DOTGraphTraits<const Function*>
-        ::getSimpleNodeLabel(BB, BB->getParent());
+      return DOTGraphTraits<CFGDOTInfo*>
+        ::getSimpleNodeLabel(BB, nullptr);
     else
-      return DOTGraphTraits<const Function*>
-        ::getCompleteNodeLabel(BB, BB->getParent());
+      return DOTGraphTraits<CFGDOTInfo*>
+        ::getCompleteNodeLabel(BB, nullptr);
   }
 };
 
diff --git a/lib/Analysis/HeatUtils.cpp b/lib/Analysis/HeatUtils.cpp
new file mode 100644 (file)
index 0000000..c328c48
--- /dev/null
@@ -0,0 +1,130 @@
+//===-- HeatUtils.cpp - Utility for printing heat colors --------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// Utility for printing heat colors based on heuristics or profiling
+// information.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/HeatUtils.h"
+#include "llvm/IR/Instructions.h"
+
+namespace llvm {
+
+static const unsigned heatSize = 100;
+static const std::string heatPalette[heatSize] = {
+    "#3d50c3", "#4055c8", "#4358cb", "#465ecf", "#4961d2", "#4c66d6", "#4f69d9",
+    "#536edd", "#5572df", "#5977e3", "#5b7ae5", "#5f7fe8", "#6282ea", "#6687ed",
+    "#6a8bef", "#6c8ff1", "#7093f3", "#7396f5", "#779af7", "#7a9df8", "#7ea1fa",
+    "#81a4fb", "#85a8fc", "#88abfd", "#8caffe", "#8fb1fe", "#93b5fe", "#96b7ff",
+    "#9abbff", "#9ebeff", "#a1c0ff", "#a5c3fe", "#a7c5fe", "#abc8fd", "#aec9fc",
+    "#b2ccfb", "#b5cdfa", "#b9d0f9", "#bbd1f8", "#bfd3f6", "#c1d4f4", "#c5d6f2",
+    "#c7d7f0", "#cbd8ee", "#cedaeb", "#d1dae9", "#d4dbe6", "#d6dce4", "#d9dce1",
+    "#dbdcde", "#dedcdb", "#e0dbd8", "#e3d9d3", "#e5d8d1", "#e8d6cc", "#ead5c9",
+    "#ecd3c5", "#eed0c0", "#efcebd", "#f1ccb8", "#f2cab5", "#f3c7b1", "#f4c5ad",
+    "#f5c1a9", "#f6bfa6", "#f7bca1", "#f7b99e", "#f7b599", "#f7b396", "#f7af91",
+    "#f7ac8e", "#f7a889", "#f6a385", "#f5a081", "#f59c7d", "#f4987a", "#f39475",
+    "#f29072", "#f08b6e", "#ef886b", "#ed8366", "#ec7f63", "#e97a5f", "#e8765c",
+    "#e57058", "#e36c55", "#e16751", "#de614d", "#dc5d4a", "#d85646", "#d65244",
+    "#d24b40", "#d0473d", "#cc403a", "#ca3b37", "#c53334", "#c32e31", "#be242e",
+    "#bb1b2c", "#b70d28"};
+
+bool hasProfiling(const Module &M) {
+  for (auto &F : M) {
+    for (auto &BB : F) {
+      auto *TI = BB.getTerminator();
+      if (TI == nullptr)
+        continue;
+      if (TI->getMetadata(llvm::LLVMContext::MD_prof) != nullptr)
+        return true;
+    }
+  }
+  return false;
+}
+
+uint64_t getBlockFreq(const BasicBlock *BB, const BlockFrequencyInfo *BFI,
+                      bool useHeuristic) {
+  uint64_t freqVal = 0;
+  if (!useHeuristic) {
+    Optional<uint64_t> freq = BFI->getBlockProfileCount(BB);
+    if (freq.hasValue())
+      freqVal = freq.getValue();
+  } else {
+    freqVal = BFI->getBlockFreq(BB).getFrequency();
+  }
+  return freqVal;
+}
+
+uint64_t getNumOfCalls(CallSite &CS,
+                       function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
+                       bool useHeuristic) {
+  if (CS.getInstruction()==nullptr) return 0;
+  if (CS.getInstruction()->getParent()==nullptr) return 0;
+  BasicBlock *BB = CS.getInstruction()->getParent();
+  return getBlockFreq(BB, LookupBFI(*CS.getCaller()));
+}
+
+uint64_t getNumOfCalls(Function &callerFunction, Function &calledFunction,
+                       function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
+                       bool useHeuristic) {
+  uint64_t counter = 0;
+  for (User *U : calledFunction.users()) {
+    if (isa<CallInst>(U)) {
+      auto CS = CallSite(U);
+      if (CS.getCaller() == (&callerFunction)) {
+        counter += getNumOfCalls(CS, LookupBFI);
+      }
+    }
+  }
+  return counter;
+}
+
+uint64_t getMaxFreq(const Function &F, const BlockFrequencyInfo *BFI,
+                    bool useHeuristic) {
+  uint64_t maxFreq = 0;
+  for (const BasicBlock &BB : F) {
+    uint64_t freqVal = getBlockFreq(&BB, BFI, useHeuristic);
+    if (freqVal >= maxFreq)
+      maxFreq = freqVal;
+  }
+  return maxFreq;
+}
+
+uint64_t getMaxFreq(Module &M,
+                    function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
+                    bool useHeuristic) {
+  uint64_t maxFreq = 0;
+  for (auto &F : M) {
+    if (F.isDeclaration())
+      continue;
+    uint64_t localMaxFreq = getMaxFreq(F, LookupBFI(F), useHeuristic);
+    if (localMaxFreq >= maxFreq)
+      maxFreq = localMaxFreq;
+  }
+  return maxFreq;
+}
+
+std::string getHeatColor(uint64_t freq, uint64_t maxFreq) {
+  if (freq > maxFreq)
+    freq = maxFreq;
+  unsigned colorId =
+      unsigned(round((double(freq) / maxFreq) * (heatSize - 1.0)));
+  return heatPalette[colorId];
+}
+
+std::string getHeatColor(double percent) {
+  if (percent > 1.0)
+    percent = 1.0;
+  if (percent < 0.0)
+    percent = 0.0;
+  unsigned colorId = unsigned(round(percent * (heatSize - 1.0)));
+  return heatPalette[colorId];
+}
+
+} // namespace llvm
index 5986b8c..8fe4507 100644 (file)
@@ -47,11 +47,11 @@ struct DOTGraphTraits<RegionNode*> : public DefaultDOTGraphTraits {
       BasicBlock *BB = Node->getNodeAs<BasicBlock>();
 
       if (isSimple())
-        return DOTGraphTraits<const Function*>
-          ::getSimpleNodeLabel(BB, BB->getParent());
+        return DOTGraphTraits<CFGDOTInfo*>
+          ::getSimpleNodeLabel(BB, nullptr);
       else
-        return DOTGraphTraits<const Function*>
-          ::getCompleteNodeLabel(BB, BB->getParent());
+        return DOTGraphTraits<CFGDOTInfo*>
+          ::getCompleteNodeLabel(BB, nullptr);
     }
 
     return "Not implemented";
index 0a8d40a..8d7dac5 100644 (file)
@@ -43,6 +43,8 @@ MODULE_PASS("called-value-propagation", CalledValuePropagationPass())
 MODULE_PASS("constmerge", ConstantMergePass())
 MODULE_PASS("cross-dso-cfi", CrossDSOCFIPass())
 MODULE_PASS("deadargelim", DeadArgumentEliminationPass())
+MODULE_PASS("dot-cfg", CFGPrinterPass())
+MODULE_PASS("dot-cfg-only", CFGOnlyPrinterPass())
 MODULE_PASS("elim-avail-extern", EliminateAvailableExternallyPass())
 MODULE_PASS("forceattrs", ForceFunctionAttrsPass())
 MODULE_PASS("function-import", FunctionImportPass())
@@ -76,6 +78,8 @@ MODULE_PASS("strip-dead-prototypes", StripDeadPrototypesPass())
 MODULE_PASS("synthetic-counts-propagation", SyntheticCountsPropagation())
 MODULE_PASS("wholeprogramdevirt", WholeProgramDevirtPass())
 MODULE_PASS("verify", VerifierPass())
+MODULE_PASS("view-cfg", CFGViewerPass())
+MODULE_PASS("view-cfg-only", CFGOnlyViewerPass())
 #undef MODULE_PASS
 
 #ifndef CGSCC_ANALYSIS
@@ -151,8 +155,6 @@ FUNCTION_PASS("correlated-propagation", CorrelatedValuePropagationPass())
 FUNCTION_PASS("dce", DCEPass())
 FUNCTION_PASS("div-rem-pairs", DivRemPairsPass())
 FUNCTION_PASS("dse", DSEPass())
-FUNCTION_PASS("dot-cfg", CFGPrinterPass())
-FUNCTION_PASS("dot-cfg-only", CFGOnlyPrinterPass())
 FUNCTION_PASS("early-cse", EarlyCSEPass(/*UseMemorySSA=*/false))
 FUNCTION_PASS("early-cse-memssa", EarlyCSEPass(/*UseMemorySSA=*/true))
 FUNCTION_PASS("ee-instrument", EntryExitInstrumenterPass(/*PostInlining=*/false))
@@ -214,8 +216,6 @@ FUNCTION_PASS("verify<domtree>", DominatorTreeVerifierPass())
 FUNCTION_PASS("verify<loops>", LoopVerifierPass())
 FUNCTION_PASS("verify<memoryssa>", MemorySSAVerifierPass())
 FUNCTION_PASS("verify<regions>", RegionInfoVerifierPass())
-FUNCTION_PASS("view-cfg", CFGViewerPass())
-FUNCTION_PASS("view-cfg-only", CFGOnlyViewerPass())
 #undef FUNCTION_PASS
 
 #ifndef LOOP_ANALYSIS
index 994d0d2..4841d16 100644 (file)
@@ -896,7 +896,7 @@ bool NewGVN::isBackedge(BasicBlock *From, BasicBlock *To) const {
 
 #ifndef NDEBUG
 static std::string getBlockName(const BasicBlock *B) {
-  return DOTGraphTraits<const Function *>::getSimpleNodeLabel(B, nullptr);
+  return DOTGraphTraits<CFGDOTInfo *>::getSimpleNodeLabel(B, nullptr);
 }
 #endif
 
diff --git a/llvm/Analysis/HeatUtils.h b/llvm/Analysis/HeatUtils.h
new file mode 100644 (file)
index 0000000..8cb03b9
--- /dev/null
@@ -0,0 +1,54 @@
+//===-- HeatUtils.h - Utility for printing heat colors ----------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// Utility for printing heat colors based on heuristics or profiling
+// information.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_HEATUTILS_H
+#define LLVM_ANALYSIS_HEATUTILS_H
+
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/CallSite.h"
+
+#include <string>
+
+namespace llvm {
+
+bool hasProfiling(const Module &M);
+
+uint64_t getBlockFreq(const BasicBlock *BB, const BlockFrequencyInfo *BFI,
+                      bool useHeuristic = true);
+
+uint64_t getNumOfCalls(Function &callerFunction, Function &calledFunction,
+                       function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
+                       bool useHeuristic = true);
+
+uint64_t getNumOfCalls(CallSite &callsite,
+                       function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
+                       bool useHeuristic = true);
+
+uint64_t getMaxFreq(const Function &F, const BlockFrequencyInfo *BFI,
+                    bool useHeuristic = true);
+
+uint64_t getMaxFreq(Module &M,
+                    function_ref<BlockFrequencyInfo *(Function &)> LookupBFI,
+                    bool useHeuristic = true);
+
+std::string getHeatColor(uint64_t freq, uint64_t maxFreq);
+
+std::string getHeatColor(double percent);
+
+} // namespace llvm
+
+#endif
index 386e444..6964f28 100644 (file)
@@ -1,5 +1,13 @@
 ;RUN: opt < %s -analyze -dot-cfg-only 2>/dev/null
 ;RUN: opt < %s -analyze -passes=dot-cfg-only 2>/dev/null
+;RUN: opt < %s -analyze -dot-cfg-only \
+;RUN:          -cfg-heat-colors=true -cfg-weights=true 2>/dev/null
+;RUN: opt < %s -analyze -dot-cfg-only \
+;RUN:          -cfg-heat-colors=false -cfg-weights=false 2>/dev/null
+;RUN: opt < %s -analyze -dot-cfg \
+;RUN:          -cfg-heat-colors=true -cfg-weights=true 2>/dev/null
+;RUN: opt < %s -analyze -dot-cfg \
+;RUN:          -cfg-heat-colors=false -cfg-weights=false 2>/dev/null
 ;PR 1497
 
 define void @foo() {