1 //===- unittests/IR/PassBuilderCallbacksTest.cpp - PB Callback Tests --===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
10 #include <gmock/gmock.h>
11 #include <gtest/gtest.h>
12 #include <llvm/Analysis/CGSCCPassManager.h>
13 #include <llvm/Analysis/LoopAnalysisManager.h>
14 #include <llvm/AsmParser/Parser.h>
15 #include <llvm/IR/LLVMContext.h>
16 #include <llvm/IR/PassManager.h>
17 #include <llvm/Passes/PassBuilder.h>
18 #include <llvm/Support/SourceMgr.h>
19 #include <llvm/Transforms/Scalar/LoopPassManager.h>
24 /// Provide an ostream operator for StringRef.
26 /// For convenience we provide a custom matcher below for the getName()
27 /// functions of IRUnits and analysis results, which most of the time return a
28 /// StringRef. The HasName matcher below uses this operator to print names.
29 static std::ostream &operator<<(std::ostream &O, StringRef S) {
35 using testing::DoDefault;
36 using testing::Return;
37 using testing::Expectation;
38 using testing::Invoke;
39 using testing::WithArgs;
42 /// \brief A CRTP base for analysis mock handles
44 /// This class reconciles mocking with the value semantics implementation of the
45 /// AnalysisManager. Analysis mock handles should derive from this class and
46 /// call \c setDefaults() in their constructor for wiring up the defaults defined
47 /// by this base with their mock run() and invalidate() implementations.
48 template <typename DerivedT, typename IRUnitT,
49 typename AnalysisManagerT = AnalysisManager<IRUnitT>,
50 typename... ExtraArgTs>
51 class MockAnalysisHandleBase {
// The analysis type registered with the AnalysisManager. It only stores a
// pointer to the mock handle and forwards run() to it, so every copy of the
// analysis object talks to the same mock.
53 class Analysis : public AnalysisInfoMixin<Analysis> {
54 friend AnalysisInfoMixin<Analysis>;
55 friend MockAnalysisHandleBase;
// Key identifying this analysis; defined out of line after this class template.
56 static AnalysisKey Key;
60 Analysis(DerivedT &Handle) : Handle(&Handle) {
// Guard against instantiating the CRTP base with the wrong DerivedT.
61 static_assert(std::is_base_of<MockAnalysisHandleBase, DerivedT>::value,
62 "Must pass the derived type to this template!");
67 friend MockAnalysisHandleBase;
// The analysis result also keeps a pointer to the mock handle, so that
// invalidation queries on the result reach the mock's invalidate().
71 Result(DerivedT &Handle) : Handle(&Handle) {}
74 // Forward invalidation events to the mock handle.
75 bool invalidate(IRUnitT &IR, const PreservedAnalyses &PA,
76 typename AnalysisManagerT::Invalidator &Inv) {
77 return Handle->invalidate(IR, PA, Inv);
// Running the analysis is delegated to the mock's run(), so tests can set
// expectations on it and decide which Result is returned.
81 Result run(IRUnitT &IR, AnalysisManagerT &AM, ExtraArgTs... ExtraArgs) {
82 return Handle->run(IR, AM, ExtraArgs...);
// Build an analysis object, or a result, bound to the derived mock handle.
86 Analysis getAnalysis() { return Analysis(static_cast<DerivedT &>(*this)); }
87 typename Analysis::Result getResult() {
88 return typename Analysis::Result(static_cast<DerivedT &>(*this));
92 // FIXME: MSVC seems unable to handle a lambda argument to Invoke from within
93 // the template, so we use a boring static function.
// Default invalidate() action: report the result as invalidated unless this
// analysis itself, or all analyses on IRUnitT, are preserved.
94 static bool invalidateCallback(IRUnitT &IR, const PreservedAnalyses &PA,
95 typename AnalysisManagerT::Invalidator &Inv) {
96 auto PAC = PA.template getChecker<Analysis>();
97 return !PAC.preserved() &&
98 !PAC.template preservedSet<AllAnalysesOn<IRUnitT>>();
101 /// Derived classes should call this in their constructor to set up default
102 /// mock actions. (We can't do this in our constructor because this has to
103 /// run after the DerivedT is constructed.)
// By default, run() returns a Result bound to this handle and invalidate()
// defers to invalidateCallback.
105 ON_CALL(static_cast<DerivedT &>(*this),
106 run(_, _, testing::Matcher<ExtraArgTs>(_)...))
107 .WillByDefault(Return(this->getResult()));
108 ON_CALL(static_cast<DerivedT &>(*this), invalidate(_, _, _))
109 .WillByDefault(Invoke(&invalidateCallback));
113 /// \brief A CRTP base for pass mock handles
115 /// This class reconciles mocking with the value semantics implementation of the
116 /// PassManager. Pass mock handles should derive from this class and
117 /// call \c setDefaults() in their constructor for wiring up the defaults
118 /// defined by this base with their mock run() implementation.
// Out-of-line definition of the static AnalysisKey declared in
// MockAnalysisHandleBase::Analysis (not part of the pass-handle base described
// by the comment just above). Each instantiation of the base gets its own key.
119 template <typename DerivedT, typename IRUnitT, typename AnalysisManagerT,
120 typename... ExtraArgTs>
121 AnalysisKey MockAnalysisHandleBase<DerivedT, IRUnitT, AnalysisManagerT,
122 ExtraArgTs...>::Analysis::Key;
// CRTP base for pass mock handles: DerivedT declares the mocked run(), and this
// base supplies the Pass object handed to pass managers plus default actions.
124 template <typename DerivedT, typename IRUnitT,
125 typename AnalysisManagerT = AnalysisManager<IRUnitT>,
126 typename... ExtraArgTs>
127 class MockPassHandleBase {
// The pass added to a pass manager. It only stores a pointer to the mock
// handle, so any copy of the pass still forwards to the same mock.
129 class Pass : public PassInfoMixin<Pass> {
130 friend MockPassHandleBase;
134 Pass(DerivedT &Handle) : Handle(&Handle) {
// Guard against instantiating the CRTP base with the wrong DerivedT.
135 static_assert(std::is_base_of<MockPassHandleBase, DerivedT>::value,
136 "Must pass the derived type to this template!");
// Delegate to the mocked run() so tests can set expectations on it.
140 PreservedAnalyses run(IRUnitT &IR, AnalysisManagerT &AM,
141 ExtraArgTs... ExtraArgs) {
142 return Handle->run(IR, AM, ExtraArgs...);
// Build a pass object bound to the derived mock handle.
146 Pass getPass() { return Pass(static_cast<DerivedT &>(*this)); }
149 /// Derived classes should call this in their constructor to set up default
150 /// mock actions. (We can't do this in our constructor because this has to
151 /// run after the DerivedT is constructed.)
// By default the mocked run() preserves all analyses.
153 ON_CALL(static_cast<DerivedT &>(*this),
154 run(_, _, testing::Matcher<ExtraArgTs>(_)...))
155 .WillByDefault(Return(PreservedAnalyses::all()));
159 /// Mock handles for passes for the IRUnits Module, CGSCC, Function, Loop.
160 /// These handles define the appropriate run() mock interface for the respective
/// IRUnit type. The primary template is only declared; each supported IRUnit
/// has an explicit specialization below.
162 template <typename IRUnitT> struct MockPassHandle;
// Loop pass mock: run() additionally receives the standard loop analysis
// results and the LPMUpdater, matching the extra arguments given to the base.
164 struct MockPassHandle<Loop>
165 : MockPassHandleBase<MockPassHandle<Loop>, Loop, LoopAnalysisManager,
166 LoopStandardAnalysisResults &, LPMUpdater &> {
168 PreservedAnalyses(Loop &, LoopAnalysisManager &,
169 LoopStandardAnalysisResults &, LPMUpdater &));
// Install the default mock actions defined by MockPassHandleBase.
170 MockPassHandle() { setDefaults(); }
// Function pass mock: run() takes the function and its analysis manager.
174 struct MockPassHandle<Function>
175 : MockPassHandleBase<MockPassHandle<Function>, Function> {
176 MOCK_METHOD2(run, PreservedAnalyses(Function &, FunctionAnalysisManager &));
// Install the default mock actions defined by MockPassHandleBase.
178 MockPassHandle() { setDefaults(); }
// CGSCC pass mock: run() additionally receives the LazyCallGraph and the
// CGSCCUpdateResult, matching the extra arguments given to the base.
182 struct MockPassHandle<LazyCallGraph::SCC>
183 : MockPassHandleBase<MockPassHandle<LazyCallGraph::SCC>, LazyCallGraph::SCC,
184 CGSCCAnalysisManager, LazyCallGraph &,
185 CGSCCUpdateResult &> {
187 PreservedAnalyses(LazyCallGraph::SCC &, CGSCCAnalysisManager &,
188 LazyCallGraph &G, CGSCCUpdateResult &UR));
// Install the default mock actions defined by MockPassHandleBase.
190 MockPassHandle() { setDefaults(); }
// Module pass mock: run() takes the module and its analysis manager.
194 struct MockPassHandle<Module>
195 : MockPassHandleBase<MockPassHandle<Module>, Module> {
196 MOCK_METHOD2(run, PreservedAnalyses(Module &, ModuleAnalysisManager &));
// Install the default mock actions defined by MockPassHandleBase.
198 MockPassHandle() { setDefaults(); }
201 /// Mock handles for analyses for the IRUnits Module, CGSCC, Function, Loop.
202 /// These handles define the appropriate run() and invalidate() mock interfaces
203 /// for the respective IRUnit type.
// The primary template is only declared; each supported IRUnit has an explicit
// specialization below.
204 template <typename IRUnitT> struct MockAnalysisHandle;
// Loop analysis mock: run() additionally receives the standard loop analysis
// results; invalidate() receives the loop analysis manager's Invalidator.
206 struct MockAnalysisHandle<Loop>
207 : MockAnalysisHandleBase<MockAnalysisHandle<Loop>, Loop,
209 LoopStandardAnalysisResults &> {
211 MOCK_METHOD3_T(run, typename Analysis::Result(Loop &, LoopAnalysisManager &,
212 LoopStandardAnalysisResults &));
214 MOCK_METHOD3_T(invalidate, bool(Loop &, const PreservedAnalyses &,
215 LoopAnalysisManager::Invalidator &));
// Name the constructor with the injected-class-name. Spelling it as the
// template-id MockAnalysisHandle<Loop>() is ill-formed since C++20 (CWG 2237)
// and is diagnosed by newer compilers. This also matches the MockPassHandle
// specializations. The constructor installs the base's default mock actions.
217 MockAnalysisHandle() { this->setDefaults(); }
// Function analysis mock: run() takes the function and its analysis manager;
// invalidate() receives the function analysis manager's Invalidator.
221 struct MockAnalysisHandle<Function>
222 : MockAnalysisHandleBase<MockAnalysisHandle<Function>, Function> {
223 MOCK_METHOD2(run, Analysis::Result(Function &, FunctionAnalysisManager &));
225 MOCK_METHOD3(invalidate, bool(Function &, const PreservedAnalyses &,
226 FunctionAnalysisManager::Invalidator &));
// Name the constructor with the injected-class-name. Spelling it as the
// template-id MockAnalysisHandle<Function>() is ill-formed since C++20
// (CWG 2237). The constructor installs the base's default mock actions.
228 MockAnalysisHandle() { setDefaults(); }
// CGSCC analysis mock: run() additionally receives the LazyCallGraph;
// invalidate() receives the CGSCC analysis manager's Invalidator.
232 struct MockAnalysisHandle<LazyCallGraph::SCC>
233 : MockAnalysisHandleBase<MockAnalysisHandle<LazyCallGraph::SCC>,
234 LazyCallGraph::SCC, CGSCCAnalysisManager,
236 MOCK_METHOD3(run, Analysis::Result(LazyCallGraph::SCC &,
237 CGSCCAnalysisManager &, LazyCallGraph &));
239 MOCK_METHOD3(invalidate, bool(LazyCallGraph::SCC &, const PreservedAnalyses &,
240 CGSCCAnalysisManager::Invalidator &));
// Name the constructor with the injected-class-name. Spelling it as the
// template-id MockAnalysisHandle<LazyCallGraph::SCC>() is ill-formed since
// C++20 (CWG 2237). The constructor installs the base's default mock actions.
242 MockAnalysisHandle() { setDefaults(); }
// Module analysis mock: run() takes the module and its analysis manager;
// invalidate() receives the module analysis manager's Invalidator.
246 struct MockAnalysisHandle<Module>
247 : MockAnalysisHandleBase<MockAnalysisHandle<Module>, Module> {
248 MOCK_METHOD2(run, Analysis::Result(Module &, ModuleAnalysisManager &));
250 MOCK_METHOD3(invalidate, bool(Module &, const PreservedAnalyses &,
251 ModuleAnalysisManager::Invalidator &));
// Name the constructor with the injected-class-name. Spelling it as the
// template-id MockAnalysisHandle<Module>() is ill-formed since C++20
// (CWG 2237). The constructor installs the base's default mock actions.
253 MockAnalysisHandle() { setDefaults(); }
/// Parse the textual IR in \p IR into a module in context \p C and return it.
/// NOTE(review): Err is presumably an SMDiagnostic declared on a line not
/// visible here (llvm/Support/SourceMgr.h is included). This helper does not
/// report the diagnostic collected in Err — confirm that callers tolerate that.
256 static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
258 return parseAssemblyString(IR, Err, C);
// Primary template; only the PassManager partial specialization below is
// defined.
261 template <typename PassManagerT> class PassBuilderCallbacksTest;
263 /// This test fixture is shared between all the actual tests below and
264 /// takes care of setting up appropriate defaults.
266 /// The template specialization serves to extract the IRUnit and AM types from
267 /// the given PassManagerT.
268 template <typename TestIRUnitT, typename... ExtraPassArgTs,
269 typename... ExtraAnalysisArgTs>
270 class PassBuilderCallbacksTest<PassManager<
271 TestIRUnitT, AnalysisManager<TestIRUnitT, ExtraAnalysisArgTs...>,
272 ExtraPassArgTs...>> : public testing::Test {
274 using IRUnitT = TestIRUnitT;
275 using AnalysisManagerT = AnalysisManager<TestIRUnitT, ExtraAnalysisArgTs...>;
277 PassManager<TestIRUnitT, AnalysisManagerT, ExtraPassArgTs...>;
278 using AnalysisT = typename MockAnalysisHandle<IRUnitT>::Analysis;
281 std::unique_ptr<Module> M;
// The top-level pass manager is always a module pass manager, whatever this
// fixture's PassManagerT is. Analysis managers for every IR unit are kept here
// so that they can be cross-registered in the constructor.
284 ModulePassManager PM;
285 LoopAnalysisManager LAM;
286 FunctionAnalysisManager FAM;
287 CGSCCAnalysisManager CGAM;
288 ModuleAnalysisManager AM;
// Mock pass and mock analysis for this fixture's IR unit.
290 MockPassHandle<IRUnitT> PassHandle;
291 MockAnalysisHandle<IRUnitT> AnalysisHandle;
/// Action for a mocked pass run(): query the mock analysis on \p U, which
/// triggers the analysis mock's run(), and then preserve all analyses.
293 static PreservedAnalyses getAnalysisResult(IRUnitT &U, AnalysisManagerT &AM,
294 ExtraAnalysisArgTs &&... Args) {
295 (void)AM.template getResult<AnalysisT>(
296 U, std::forward<ExtraAnalysisArgTs>(Args)...);
297 return PreservedAnalyses::all();
// Sets up the fixture from the IR string below and constructs the pass and
// analysis managers. NOTE(review): the initializer consuming the IR string is
// not visible here, and the `true` arguments are presumably the managers'
// DebugLogging flag — confirm.
300 PassBuilderCallbacksTest()
302 "declare void @bar()\n"
303 "define void @foo(i32 %n) {\n"
307 " %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]\n"
308 " %iv.next = add i32 %iv, 1\n"
309 " tail call void @bar()\n"
310 " %cmp = icmp eq i32 %iv, %n\n"
311 " br i1 %cmp, label %exit, label %loop\n"
315 PM(true), LAM(true), FAM(true), CGAM(true), AM(true) {
317 /// Register a callback for analysis registration.
319 /// The callback is a function taking a reference to an AnalysisManager
320 /// object. When called, the callee gets to register its own analyses with
321 /// this PassBuilder instance.
// Note: the lambda's AM parameter shadows the fixture member AM.
322 PB.registerAnalysisRegistrationCallback([this](AnalysisManagerT &AM) {
323 // Register our mock analysis
324 AM.registerPass([this] { return AnalysisHandle.getAnalysis(); });
327 /// Register a callback for pipeline parsing.
329 /// During parsing of a textual pipeline, the PassBuilder will call these
330 /// callbacks for each encountered pass name that it does not know. This
331 /// includes both simple pass names as well as names of sub-pipelines. In
332 /// the latter case, the InnerPipeline is not empty.
// Note: the lambda's PM parameter (a PassManagerT) shadows the fixture member
// PM; inside the lambda, PM is the pass manager being populated by the parser.
333 PB.registerPipelineParsingCallback(
334 [this](StringRef Name, PassManagerT &PM,
335 ArrayRef<PassBuilder::PipelineElement> InnerPipeline) {
336 /// Handle parsing of the names of analysis utilities such as
337 /// require<test-analysis> and invalidate<test-analysis> for our
338 /// analysis mock handle.
339 if (parseAnalysisUtilityPasses<AnalysisT>("test-analysis", Name, PM))
342 /// Parse the name of our pass mock handle
343 if (Name == "test-transform") {
344 PM.addPass(PassHandle.getPass());
350 /// Register builtin analyses and cross-register the analysis proxies
351 PB.registerModuleAnalyses(AM);
352 PB.registerCGSCCAnalyses(CGAM);
353 PB.registerFunctionAnalyses(FAM);
354 PB.registerLoopAnalyses(LAM);
355 PB.crossRegisterProxies(LAM, FAM, CGAM, AM);
359 /// Define a custom matcher for objects which support a 'getName' method.
361 /// LLVM often has IR objects or analysis objects which expose a name
362 /// and in tests it is convenient to match these by name for readability.
363 /// Usually, this name is either a StringRef or a plain std::string. This
364 /// matcher supports any type exposing a getName() method of this form whose
365 /// return value is compatible with a std::ostream. For StringRef, this uses
366 /// the shift operator defined above.
368 /// It should be used as:
370 /// HasName("my_function")
372 /// No namespace or other qualification is required.
373 MATCHER_P(HasName, Name, "") {
// Always describe the actual name, so that mismatches are easy to diagnose.
374 *result_listener << "has name '" << arg.getName() << "'";
// Match when the expected name equals the object's name.
375 return Name == arg.getName();
// Fixture instantiations, one per IR unit. The PassManager type argument
// selects the IR unit, the analysis manager and the extra argument types.
378 using ModuleCallbacksTest = PassBuilderCallbacksTest<ModulePassManager>;
379 using CGSCCCallbacksTest = PassBuilderCallbacksTest<CGSCCPassManager>;
380 using FunctionCallbacksTest = PassBuilderCallbacksTest<FunctionPassManager>;
381 using LoopCallbacksTest = PassBuilderCallbacksTest<LoopPassManager>;
383 /// Test parsing of the name of our mock pass for all IRUnits.
385 /// The mock pass is made to request our mock analysis (via getAnalysisResult)
/// and then preserve all analyses.
386 TEST_F(ModuleCallbacksTest, Passes) {
// The module is matched by name; "<string>" is presumably the identifier that
// parseAssemblyString assigns to modules parsed from a string.
387 EXPECT_CALL(AnalysisHandle, run(HasName("<string>"), _));
388 EXPECT_CALL(PassHandle, run(HasName("<string>"), _))
389 .WillOnce(Invoke(getAnalysisResult));
391 StringRef PipelineText = "test-transform";
392 ASSERT_TRUE(PB.parsePassPipeline(PM, PipelineText, true))
393 << "Pipeline was: " << PipelineText;
// Expect the mock pass and the mock analysis it queries to each run once on
// @foo.
397 TEST_F(FunctionCallbacksTest, Passes) {
398 EXPECT_CALL(AnalysisHandle, run(HasName("foo"), _));
399 EXPECT_CALL(PassHandle, run(HasName("foo"), _))
400 .WillOnce(Invoke(getAnalysisResult));
402 StringRef PipelineText = "test-transform";
403 ASSERT_TRUE(PB.parsePassPipeline(PM, PipelineText, true))
404 << "Pipeline was: " << PipelineText;
// Expect the mock pass and the mock analysis it queries to each run once on
// the loop named "loop".
408 TEST_F(LoopCallbacksTest, Passes) {
409 EXPECT_CALL(AnalysisHandle, run(HasName("loop"), _, _));
410 EXPECT_CALL(PassHandle, run(HasName("loop"), _, _, _))
// The pass mock has an extra LPMUpdater argument that the analysis query does
// not take, so only the first three arguments are forwarded.
411 .WillOnce(WithArgs<0, 1, 2>(Invoke(getAnalysisResult)));
413 StringRef PipelineText = "test-transform";
414 ASSERT_TRUE(PB.parsePassPipeline(PM, PipelineText, true))
415 << "Pipeline was: " << PipelineText;
// Expect the mock pass and the mock analysis it queries to each run once on
// the SCC whose name is "(foo)".
419 TEST_F(CGSCCCallbacksTest, Passes) {
420 EXPECT_CALL(AnalysisHandle, run(HasName("(foo)"), _, _));
421 EXPECT_CALL(PassHandle, run(HasName("(foo)"), _, _, _))
// The pass mock has an extra CGSCCUpdateResult argument that the analysis
// query does not take, so only the first three arguments are forwarded.
422 .WillOnce(WithArgs<0, 1, 2>(Invoke(getAnalysisResult)));
424 StringRef PipelineText = "test-transform";
425 ASSERT_TRUE(PB.parsePassPipeline(PM, PipelineText, true))
426 << "Pipeline was: " << PipelineText;
430 /// Test parsing of the names of analysis utilities for our mock analysis
/// for all IRUnits.
433 /// We first require<>, then invalidate<> it, expecting the analysis to be run
434 /// once and subsequently invalidated.
435 TEST_F(ModuleCallbacksTest, AnalysisUtilities) {
436 EXPECT_CALL(AnalysisHandle, run(HasName("<string>"), _));
437 EXPECT_CALL(AnalysisHandle, invalidate(HasName("<string>"), _, _));
439 StringRef PipelineText = "require<test-analysis>,invalidate<test-analysis>";
440 ASSERT_TRUE(PB.parsePassPipeline(PM, PipelineText, true))
441 << "Pipeline was: " << PipelineText;
// Same require<>/invalidate<> scenario, on the SCC named "(foo)".
// NOTE(review): named PassUtilities, unlike the Module/Function
// AnalysisUtilities variants; kept so existing test filters still match.
445 TEST_F(CGSCCCallbacksTest, PassUtilities) {
446 EXPECT_CALL(AnalysisHandle, run(HasName("(foo)"), _, _));
447 EXPECT_CALL(AnalysisHandle, invalidate(HasName("(foo)"), _, _));
449 StringRef PipelineText = "require<test-analysis>,invalidate<test-analysis>";
450 ASSERT_TRUE(PB.parsePassPipeline(PM, PipelineText, true))
451 << "Pipeline was: " << PipelineText;
// Same require<>/invalidate<> scenario, on function @foo.
455 TEST_F(FunctionCallbacksTest, AnalysisUtilities) {
456 EXPECT_CALL(AnalysisHandle, run(HasName("foo"), _));
457 EXPECT_CALL(AnalysisHandle, invalidate(HasName("foo"), _, _));
459 StringRef PipelineText = "require<test-analysis>,invalidate<test-analysis>";
460 ASSERT_TRUE(PB.parsePassPipeline(PM, PipelineText, true))
461 << "Pipeline was: " << PipelineText;
// Same require<>/invalidate<> scenario, on the loop named "loop".
// NOTE(review): named PassUtilities, unlike the Module/Function
// AnalysisUtilities variants; kept so existing test filters still match.
465 TEST_F(LoopCallbacksTest, PassUtilities) {
466 EXPECT_CALL(AnalysisHandle, run(HasName("loop"), _, _));
467 EXPECT_CALL(AnalysisHandle, invalidate(HasName("loop"), _, _));
469 StringRef PipelineText = "require<test-analysis>,invalidate<test-analysis>";
471 ASSERT_TRUE(PB.parsePassPipeline(PM, PipelineText, true))
472 << "Pipeline was: " << PipelineText;
476 /// Test parsing of the top-level pipeline.
478 /// The ParseTopLevelPipeline callback takes over parsing of the entire pipeline
479 /// from PassBuilder if it encounters an unknown pipeline entry at the top level
480 /// (i.e., the first entry on the pipeline).
481 /// This test parses a pipeline named 'another-pipeline', whose only elements
482 /// may be the test-transform pass or the analysis utilities.
483 TEST_F(ModuleCallbacksTest, ParseTopLevelPipeline) {
484 PB.registerParseTopLevelPipelineCallback([this](
485 ModulePassManager &MPM, ArrayRef<PassBuilder::PipelineElement> Pipeline,
486 bool VerifyEachPass, bool DebugLogging) {
487 auto &FirstName = Pipeline.front().Name;
488 auto &InnerPipeline = Pipeline.front().InnerPipeline;
489 if (FirstName == "another-pipeline") {
490 for (auto &E : InnerPipeline) {
// Populate MPM, the pass manager handed to this callback, not the fixture
// member PM. In this test they are presumably the same object, because
// parsePassPipeline is called with PM; filling PM here would silently target
// the wrong pipeline for any other caller.
491 if (parseAnalysisUtilityPasses<AnalysisT>("test-analysis", E.Name, MPM))
494 if (E.Name == "test-transform") {
495 MPM.addPass(PassHandle.getPass());
// Positive case: the mock pass runs once and queries the mock analysis once;
// then invalidate<test-analysis> invalidates it.
504 EXPECT_CALL(AnalysisHandle, run(HasName("<string>"), _));
505 EXPECT_CALL(PassHandle, run(HasName("<string>"), _))
506 .WillOnce(Invoke(getAnalysisResult));
507 EXPECT_CALL(AnalysisHandle, invalidate(HasName("<string>"), _, _));
509 StringRef PipelineText =
510 "another-pipeline(test-transform,invalidate<test-analysis>)";
511 ASSERT_TRUE(PB.parsePassPipeline(PM, PipelineText, true))
512 << "Pipeline was: " << PipelineText;
515 /// Test the negative case
516 PipelineText = "another-pipeline(instcombine)";
517 ASSERT_FALSE(PB.parsePassPipeline(PM, PipelineText, true))
518 << "Pipeline was: " << PipelineText;
520 } // end anonymous namespace