From f8bd622d12737eedbedff01d28d3c5366a40e71b Mon Sep 17 00:00:00 2001 From: PaperChalice Date: Sun, 24 Dec 2023 15:19:12 +0800 Subject: [PATCH 1/2] [CodeGen] Let `PassBuilder` support machine passes `PassBuilder` would be a better place to parse MIR pipeline. We can reuse the code to support parsing pass with parameters and targets can reuse `registerPassBuilderCallbacks` to register the target specific passes. `PassBuilder` also has ability to check whether a Pass is a machine pass, so it can replace part of the work of `LLVMTargetMachine::getPassNameFromLegacyName`. --- llvm/include/llvm/Passes/PassBuilder.h | 39 ++ llvm/lib/Passes/PassBuilder.cpp | 82 +++ llvm/unittests/MIR/CMakeLists.txt | 2 + .../MIR/PassBuilderCallbacksTest.cpp | 499 ++++++++++++++++++ 4 files changed, 622 insertions(+) create mode 100644 llvm/unittests/MIR/PassBuilderCallbacksTest.cpp diff --git a/llvm/include/llvm/Passes/PassBuilder.h b/llvm/include/llvm/Passes/PassBuilder.h index 61417431f8a8f..6b0ad7e7d11a5 100644 --- a/llvm/include/llvm/Passes/PassBuilder.h +++ b/llvm/include/llvm/Passes/PassBuilder.h @@ -16,6 +16,7 @@ #define LLVM_PASSES_PASSBUILDER_H #include "llvm/Analysis/CGSCCPassManager.h" +#include "llvm/CodeGen/MachinePassManager.h" #include "llvm/IR/PassManager.h" #include "llvm/Passes/OptimizationLevel.h" #include "llvm/Support/Error.h" @@ -165,6 +166,14 @@ class PassBuilder { /// additional analyses. void registerLoopAnalyses(LoopAnalysisManager &LAM); + /// Registers all available machine function analysis passes. + /// + /// This is an interface that can be used to populate a \c + /// MachineFunctionAnalysisManager with all registered function analyses. + /// Callers can still manually register any additional analyses. Callers can + /// also pre-register analyses and this will not override those. + void registerMachineFunctionAnalyses(MachineFunctionAnalysisManager &MFAM); + /// Construct the core LLVM function canonicalization and simplification /// pipeline. /// @@ -352,6 +361,18 @@ class PassBuilder { Error parsePassPipeline(LoopPassManager &LPM, StringRef PipelineText); /// @}} + /// Parse a textual MIR pipeline into the provided \c MachineFunctionPass + /// manager. + /// The format of the textual machine pipeline is a comma separated list of + /// machine pass names: + /// + /// machine-funciton-pass,machine-module-pass,... + /// + /// There is no need to specify the pass nesting, and this function + /// currently cannot handle the pass nesting. + Error parsePassPipeline(MachineFunctionPassManager &MFPM, + StringRef PipelineText); + /// Parse a textual alias analysis pipeline into the provided AA manager. /// /// The format of the textual AA pipeline is a comma separated list of AA @@ -520,6 +541,10 @@ class PassBuilder { const std::function &C) { ModuleAnalysisRegistrationCallbacks.push_back(C); } + void registerAnalysisRegistrationCallback( + const std::function &C) { + MachineFunctionAnalysisRegistrationCallbacks.push_back(C); + } /// @}} /// {{@ Register pipeline parsing callbacks with this pass builder instance. @@ -546,6 +571,11 @@ class PassBuilder { ArrayRef)> &C) { ModulePipelineParsingCallbacks.push_back(C); } + void registerPipelineParsingCallback( + const std::function + &C) { + MachinePipelineParsingCallbacks.push_back(C); + } /// @}} /// Register a callback for a top-level pipeline entry. @@ -616,8 +646,12 @@ class PassBuilder { Error parseCGSCCPass(CGSCCPassManager &CGPM, const PipelineElement &E); Error parseFunctionPass(FunctionPassManager &FPM, const PipelineElement &E); Error parseLoopPass(LoopPassManager &LPM, const PipelineElement &E); + Error parseMachinePass(MachineFunctionPassManager &MFPM, + const PipelineElement &E); bool parseAAPassName(AAManager &AA, StringRef Name); + Error parseMachinePassPipeline(MachineFunctionPassManager &MFPM, + ArrayRef Pipeline); Error parseLoopPassPipeline(LoopPassManager &LPM, ArrayRef Pipeline); Error parseFunctionPassPipeline(FunctionPassManager &FPM, @@ -699,6 +733,11 @@ class PassBuilder { // AA callbacks SmallVector, 2> AAParsingCallbacks; + // Machine pass callbackcs + SmallVector, 2> + MachineFunctionAnalysisRegistrationCallbacks; + SmallVector, 2> + MachinePipelineParsingCallbacks; }; /// This utility template takes care of adding require<> and invalidate<> diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp index bfc97d5464c04..e6acc7e021e01 100644 --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -488,6 +488,12 @@ PassBuilder::PassBuilder(TargetMachine *TM, PipelineTuningOptions PTO, #define CGSCC_ANALYSIS(NAME, CREATE_PASS) \ PIC->addClassToPassName(decltype(CREATE_PASS)::name(), NAME); #include "PassRegistry.def" + +#define MACHINE_FUNCTION_ANALYSIS(NAME, PASS_NAME, CONSTRUCTOR) \ + PIC->addClassToPassName(PASS_NAME::name(), NAME); +#define MACHINE_FUNCTION_PASS(NAME, PASS_NAME, CONSTRUCTOR) \ + PIC->addClassToPassName(PASS_NAME::name(), NAME); +#include "llvm/CodeGen/MachinePassRegistry.def" } } @@ -523,6 +529,17 @@ void PassBuilder::registerFunctionAnalyses(FunctionAnalysisManager &FAM) { C(FAM); } +void PassBuilder::registerMachineFunctionAnalyses( + MachineFunctionAnalysisManager &MFAM) { + +#define MACHINE_FUNCTION_ANALYSIS(NAME, PASS_NAME, CONSTRUCTOR) \ + MFAM.registerPass([&] { return PASS_NAME(); }); +#include "llvm/CodeGen/MachinePassRegistry.def" + + for (auto &C : MachineFunctionAnalysisRegistrationCallbacks) + C(MFAM); +} + void PassBuilder::registerLoopAnalyses(LoopAnalysisManager &LAM) { #define LOOP_ANALYSIS(NAME, CREATE_PASS) \ LAM.registerPass([&] { return CREATE_PASS; }); @@ -1877,6 +1894,33 @@ Error PassBuilder::parseLoopPass(LoopPassManager &LPM, inconvertibleErrorCode()); } +Error PassBuilder::parseMachinePass(MachineFunctionPassManager &MFPM, + const PipelineElement &E) { + StringRef Name = E.Name; + if (!E.InnerPipeline.empty()) + return make_error("invalid pipeline", + inconvertibleErrorCode()); + +#define MACHINE_MODULE_PASS(NAME, PASS_NAME, CONSTRUCTOR) \ + if (Name == NAME) { \ + MFPM.addPass(PASS_NAME()); \ + return Error::success(); \ + } +#define MACHINE_FUNCTION_PASS(NAME, PASS_NAME, CONSTRUCTOR) \ + if (Name == NAME) { \ + MFPM.addPass(PASS_NAME()); \ + return Error::success(); \ + } +#include "llvm/CodeGen/MachinePassRegistry.def" + + for (auto &C : MachinePipelineParsingCallbacks) + if (C(Name, MFPM)) + return Error::success(); + return make_error( + formatv("unknown machine pass '{0}'", Name).str(), + inconvertibleErrorCode()); +} + bool PassBuilder::parseAAPassName(AAManager &AA, StringRef Name) { #define MODULE_ALIAS_ANALYSIS(NAME, CREATE_PASS) \ if (Name == NAME) { \ @@ -1898,6 +1942,15 @@ bool PassBuilder::parseAAPassName(AAManager &AA, StringRef Name) { return false; } +Error PassBuilder::parseMachinePassPipeline( + MachineFunctionPassManager &MFPM, ArrayRef Pipeline) { + for (const auto &Element : Pipeline) { + if (auto Err = parseMachinePass(MFPM, Element)) + return Err; + } + return Error::success(); +} + Error PassBuilder::parseLoopPassPipeline(LoopPassManager &LPM, ArrayRef Pipeline) { for (const auto &Element : Pipeline) { @@ -2057,6 +2110,20 @@ Error PassBuilder::parsePassPipeline(LoopPassManager &CGPM, return Error::success(); } +Error PassBuilder::parsePassPipeline(MachineFunctionPassManager &MFPM, + StringRef PipelineText) { + auto Pipeline = parsePipelineText(PipelineText); + if (!Pipeline || Pipeline->empty()) + return make_error( + formatv("invalid machine pass pipeline '{0}'", PipelineText).str(), + inconvertibleErrorCode()); + + if (auto Err = parseMachinePassPipeline(MFPM, *Pipeline)) + return Err; + + return Error::success(); +} + Error PassBuilder::parseAAPipeline(AAManager &AA, StringRef PipelineText) { // If the pipeline just consists of the word 'default' just replace the AA // manager with our default one. @@ -2151,6 +2218,21 @@ void PassBuilder::printPassNames(raw_ostream &OS) { OS << "Loop analyses:\n"; #define LOOP_ANALYSIS(NAME, CREATE_PASS) printPassName(NAME, OS); #include "PassRegistry.def" + + OS << "Machine module passes (WIP):\n"; +#define MACHINE_MODULE_PASS(NAME, PASS_NAME, CONSTRUCTOR) \ + printPassName(NAME, OS); +#include "llvm/CodeGen/MachinePassRegistry.def" + + OS << "Machine function passes (WIP):\n"; +#define MACHINE_FUNCTION_PASS(NAME, PASS_NAME, CONSTRUCTOR) \ + printPassName(NAME, OS); +#include "llvm/CodeGen/MachinePassRegistry.def" + + OS << "Machine function analyses (WIP):\n"; +#define MACHINE_FUNCTION_ANALYSIS(NAME, PASS_NAME, CONSTRUCTOR) \ + printPassName(NAME, OS); +#include "llvm/CodeGen/MachinePassRegistry.def" } void PassBuilder::registerParseTopLevelPipelineCallback( diff --git a/llvm/unittests/MIR/CMakeLists.txt b/llvm/unittests/MIR/CMakeLists.txt index 3c0e9e43f9afb..f485dcbd971b6 100644 --- a/llvm/unittests/MIR/CMakeLists.txt +++ b/llvm/unittests/MIR/CMakeLists.txt @@ -6,6 +6,7 @@ set(LLVM_LINK_COMPONENTS FileCheck MC MIRParser + Passes Support Target TargetParser @@ -13,6 +14,7 @@ set(LLVM_LINK_COMPONENTS add_llvm_unittest(MIRTests MachineMetadata.cpp + PassBuilderCallbacksTest.cpp ) target_link_libraries(MIRTests PRIVATE LLVMTestingSupport) diff --git a/llvm/unittests/MIR/PassBuilderCallbacksTest.cpp b/llvm/unittests/MIR/PassBuilderCallbacksTest.cpp new file mode 100644 index 0000000000000..749e2f76a6f86 --- /dev/null +++ b/llvm/unittests/MIR/PassBuilderCallbacksTest.cpp @@ -0,0 +1,499 @@ +//===- unittests/MIR/PassBuilderCallbacksTest.cpp - PB Callback Tests -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Testing/Support/Error.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace llvm; + +namespace { +using testing::_; +using testing::AnyNumber; +using testing::DoAll; +using testing::Not; +using testing::Return; +using testing::WithArgs; + +StringRef MIRString = R"MIR( +--- | + define void @test() { + ret void + } +... +--- +name: test +body: | + bb.0 (%ir-block.0): + RET64 +... +)MIR"; + +/// Helper for HasName matcher that returns getName both for IRUnit and +/// for IRUnit pointer wrapper into llvm::Any (wrapped by PassInstrumentation). +template std::string getName(const IRUnitT &IR) { + return std::string(IR.getName()); +} + +template <> std::string getName(const StringRef &name) { + return std::string(name); +} + +template <> std::string getName(const Any &WrappedIR) { + if (const auto *const *M = llvm::any_cast(&WrappedIR)) + return (*M)->getName().str(); + if (const auto *const *F = llvm::any_cast(&WrappedIR)) + return (*F)->getName().str(); + if (const auto *const *MF = + llvm::any_cast(&WrappedIR)) + return (*MF)->getName().str(); + return ""; +} +/// Define a custom matcher for objects which support a 'getName' method. +/// +/// LLVM often has IR objects or analysis objects which expose a name +/// and in tests it is convenient to match these by name for readability. +/// Usually, this name is either a StringRef or a plain std::string. This +/// matcher supports any type exposing a getName() method of this form whose +/// return value is compatible with an std::ostream. For StringRef, this uses +/// the shift operator defined above. +/// +/// It should be used as: +/// +/// HasName("my_function") +/// +/// No namespace or other qualification is required. +MATCHER_P(HasName, Name, "") { + *result_listener << "has name '" << getName(arg) << "'"; + return Name == getName(arg); +} + +MATCHER_P(HasNameRegex, Name, "") { + *result_listener << "has name '" << getName(arg) << "'"; + llvm::Regex r(Name); + return r.match(getName(arg)); +} + +struct MockPassInstrumentationCallbacks { + PassInstrumentationCallbacks Callbacks; + + MockPassInstrumentationCallbacks() { + ON_CALL(*this, runBeforePass(_, _)).WillByDefault(Return(true)); + } + MOCK_METHOD2(runBeforePass, bool(StringRef PassID, llvm::Any)); + MOCK_METHOD2(runBeforeSkippedPass, void(StringRef PassID, llvm::Any)); + MOCK_METHOD2(runBeforeNonSkippedPass, void(StringRef PassID, llvm::Any)); + MOCK_METHOD3(runAfterPass, + void(StringRef PassID, llvm::Any, const PreservedAnalyses &PA)); + MOCK_METHOD2(runAfterPassInvalidated, + void(StringRef PassID, const PreservedAnalyses &PA)); + MOCK_METHOD2(runBeforeAnalysis, void(StringRef PassID, llvm::Any)); + MOCK_METHOD2(runAfterAnalysis, void(StringRef PassID, llvm::Any)); + + void registerPassInstrumentation() { + Callbacks.registerShouldRunOptionalPassCallback( + [this](StringRef P, llvm::Any IR) { + return this->runBeforePass(P, IR); + }); + Callbacks.registerBeforeSkippedPassCallback( + [this](StringRef P, llvm::Any IR) { + this->runBeforeSkippedPass(P, IR); + }); + Callbacks.registerBeforeNonSkippedPassCallback( + [this](StringRef P, llvm::Any IR) { + this->runBeforeNonSkippedPass(P, IR); + }); + Callbacks.registerAfterPassCallback( + [this](StringRef P, llvm::Any IR, const PreservedAnalyses &PA) { + this->runAfterPass(P, IR, PA); + }); + Callbacks.registerAfterPassInvalidatedCallback( + [this](StringRef P, const PreservedAnalyses &PA) { + this->runAfterPassInvalidated(P, PA); + }); + Callbacks.registerBeforeAnalysisCallback([this](StringRef P, llvm::Any IR) { + return this->runBeforeAnalysis(P, IR); + }); + Callbacks.registerAfterAnalysisCallback( + [this](StringRef P, llvm::Any IR) { this->runAfterAnalysis(P, IR); }); + } + + void ignoreNonMockPassInstrumentation(StringRef IRName) { + // Generic EXPECT_CALLs are needed to match instrumentation on unimportant + // parts of a pipeline that we do not care about (e.g. various passes added + // by default by PassBuilder - Verifier pass etc). + // Make sure to avoid ignoring Mock passes/analysis, we definitely want + // to check these explicitly. + EXPECT_CALL(*this, + runBeforePass(Not(HasNameRegex("Mock")), HasName(IRName))) + .Times(AnyNumber()); + EXPECT_CALL( + *this, runBeforeSkippedPass(Not(HasNameRegex("Mock")), HasName(IRName))) + .Times(AnyNumber()); + EXPECT_CALL(*this, runBeforeNonSkippedPass(Not(HasNameRegex("Mock")), + HasName(IRName))) + .Times(AnyNumber()); + EXPECT_CALL(*this, + runAfterPass(Not(HasNameRegex("Mock")), HasName(IRName), _)) + .Times(AnyNumber()); + EXPECT_CALL(*this, runBeforeAnalysis(HasNameRegex("MachineModuleAnalysis"), + HasName(IRName))) + .Times(AnyNumber()); + EXPECT_CALL(*this, + runBeforeAnalysis(Not(HasNameRegex("Mock")), HasName(IRName))) + .Times(AnyNumber()); + EXPECT_CALL(*this, runAfterAnalysis(HasNameRegex("MachineModuleAnalysis"), + HasName(IRName))) + .Times(AnyNumber()); + EXPECT_CALL(*this, + runAfterAnalysis(Not(HasNameRegex("Mock")), HasName(IRName))) + .Times(AnyNumber()); + } +}; + +template class MockAnalysisHandleBase { +public: + class Analysis : public AnalysisInfoMixin { + friend AnalysisInfoMixin; + friend MockAnalysisHandleBase; + static AnalysisKey Key; + + DerivedT *Handle; + + Analysis(DerivedT &Handle) : Handle(&Handle) { + static_assert(std::is_base_of::value, + "Must pass the derived type to this template!"); + } + + public: + class Result { + friend MockAnalysisHandleBase; + + DerivedT *Handle; + + Result(DerivedT &Handle) : Handle(&Handle) {} + + public: + // Forward invalidation events to the mock handle. + bool invalidate(MachineFunction &IR, const PreservedAnalyses &PA, + MachineFunctionAnalysisManager::Invalidator &Inv) { + return Handle->invalidate(IR, PA, Inv); + } + }; + + Result run(MachineFunction &IR, MachineFunctionAnalysisManager::Base &AM) { + return Handle->run(IR, AM); + } + }; + + Analysis getAnalysis() { return Analysis(static_cast(*this)); } + typename Analysis::Result getResult() { + return typename Analysis::Result(static_cast(*this)); + } + static StringRef getName() { return llvm::getTypeName(); } + +protected: + // FIXME: MSVC seems unable to handle a lambda argument to Invoke from within + // the template, so we use a boring static function. + static bool + invalidateCallback(MachineFunction &IR, const PreservedAnalyses &PA, + MachineFunctionAnalysisManager::Invalidator &Inv) { + auto PAC = PA.template getChecker(); + return !PAC.preserved() && + !PAC.template preservedSet>(); + } + + /// Derived classes should call this in their constructor to set up default + /// mock actions. (We can't do this in our constructor because this has to + /// run after the DerivedT is constructed.) + void setDefaults() { + ON_CALL(static_cast(*this), run(_, _)) + .WillByDefault(Return(this->getResult())); + ON_CALL(static_cast(*this), invalidate(_, _, _)) + .WillByDefault(&invalidateCallback); + } +}; + +template class MockPassHandleBase { +public: + class Pass : public MachinePassInfoMixin { + friend MockPassHandleBase; + + DerivedT *Handle; + + Pass(DerivedT &Handle) : Handle(&Handle) { + static_assert(std::is_base_of::value, + "Must pass the derived type to this template!"); + } + + public: + static MachinePassKey Key; + PreservedAnalyses run(MachineFunction &IR, + MachineFunctionAnalysisManager::Base &AM) { + return Handle->run(IR, AM); + } + }; + + static StringRef getName() { return llvm::getTypeName(); } + + Pass getPass() { return Pass(static_cast(*this)); } + +protected: + /// Derived classes should call this in their constructor to set up default + /// mock actions. (We can't do this in our constructor because this has to + /// run after the DerivedT is constructed.) + void setDefaults() { + ON_CALL(static_cast(*this), run(_, _)) + .WillByDefault(Return(PreservedAnalyses::all())); + } +}; + +struct MockAnalysisHandle : public MockAnalysisHandleBase { + MOCK_METHOD2(run, Analysis::Result(MachineFunction &, + MachineFunctionAnalysisManager::Base &)); + + MOCK_METHOD3(invalidate, bool(MachineFunction &, const PreservedAnalyses &, + MachineFunctionAnalysisManager::Invalidator &)); + + MockAnalysisHandle() { setDefaults(); } +}; + +template +MachinePassKey MockPassHandleBase::Pass::Key; + +template +AnalysisKey MockAnalysisHandleBase::Analysis::Key; + +class MockPassHandle : public MockPassHandleBase { +public: + MOCK_METHOD2(run, PreservedAnalyses(MachineFunction &, + MachineFunctionAnalysisManager::Base &)); + + MockPassHandle() { setDefaults(); } +}; + +class MachineFunctionCallbacksTest : public testing::Test { +protected: + static void SetUpTestCase() { + InitializeAllTargetInfos(); + InitializeAllTargets(); + InitializeAllTargetMCs(); + } + + TargetMachine *TM; + + LLVMContext Context; + std::unique_ptr M; + std::unique_ptr MIR; + + MockPassInstrumentationCallbacks CallbacksHandle; + + PassBuilder PB; + ModulePassManager PM; + MachineFunctionPassManager MFPM; + FunctionAnalysisManager FAM; + ModuleAnalysisManager AM; + MachineFunctionAnalysisManager MFAM; + + MockPassHandle PassHandle; + MockAnalysisHandle AnalysisHandle; + + std::unique_ptr parseMIR(const TargetMachine &TM, StringRef MIRCode, + MachineModuleInfo &MMI) { + SMDiagnostic Diagnostic; + std::unique_ptr MBuffer = MemoryBuffer::getMemBuffer(MIRCode); + MIR = createMIRParser(std::move(MBuffer), Context); + if (!MIR) + return nullptr; + + std::unique_ptr Mod = MIR->parseIRModule(); + if (!Mod) + return nullptr; + + Mod->setDataLayout(TM.createDataLayout()); + + if (MIR->parseMachineFunctions(*Mod, MMI)) { + M.reset(); + return nullptr; + } + return Mod; + } + + static PreservedAnalyses + getAnalysisResult(MachineFunction &U, + MachineFunctionAnalysisManager::Base &AM) { + auto &MFAM = static_cast(AM); + MFAM.getResult(U); + return PreservedAnalyses::all(); + } + + void SetUp() override { + std::string Error; + auto TripleName = "x86_64-pc-linux-gnu"; + auto *T = TargetRegistry::lookupTarget(TripleName, Error); + if (!T) + GTEST_SKIP(); + TM = T->createTargetMachine(TripleName, "", "", TargetOptions(), + std::nullopt); + MachineModuleInfo MMI(static_cast(TM)); + M = parseMIR(*TM, MIRString, MMI); + AM.registerPass([&] { + return MachineModuleAnalysis(static_cast(TM)); + }); + } + + MachineFunctionCallbacksTest() + : CallbacksHandle(), PB(nullptr, PipelineTuningOptions(), std::nullopt, + &CallbacksHandle.Callbacks), + PM(), FAM(), AM(), MFAM(FAM, AM) { + + EXPECT_TRUE(&CallbacksHandle.Callbacks == + PB.getPassInstrumentationCallbacks()); + + /// Register a callback for analysis registration. + /// + /// The callback is a function taking a reference to an AnalyisManager + /// object. When called, the callee gets to register its own analyses with + /// this PassBuilder instance. + PB.registerAnalysisRegistrationCallback( + [this](MachineFunctionAnalysisManager &AM) { + // Register our mock analysis + AM.registerPass([this] { return AnalysisHandle.getAnalysis(); }); + }); + + /// Register a callback for pipeline parsing. + /// + /// During parsing of a textual pipeline, the PassBuilder will call these + /// callbacks for each encountered pass name that it does not know. This + /// includes both simple pass names as well as names of sub-pipelines. In + /// the latter case, the InnerPipeline is not empty. + PB.registerPipelineParsingCallback( + [this](StringRef Name, MachineFunctionPassManager &PM) { + if (parseAnalysisUtilityPasses( + "test-analysis", Name, PM)) + return true; + + /// Parse the name of our pass mock handle + if (Name == "test-transform") { + MFPM.addPass(PassHandle.getPass()); + return true; + } + return false; + }); + + /// Register builtin analyses and cross-register the analysis proxies + PB.registerModuleAnalyses(AM); + PB.registerFunctionAnalyses(FAM); + PB.registerMachineFunctionAnalyses(MFAM); + } +}; + +TEST_F(MachineFunctionCallbacksTest, Passes) { + EXPECT_CALL(AnalysisHandle, run(HasName("test"), _)); + EXPECT_CALL(PassHandle, run(HasName("test"), _)).WillOnce(&getAnalysisResult); + + StringRef PipelineText = "test-transform"; + ASSERT_THAT_ERROR(PB.parsePassPipeline(MFPM, PipelineText), Succeeded()) + << "Pipeline was: " << PipelineText; + ASSERT_THAT_ERROR(MFPM.run(*M, MFAM), Succeeded()); +} + +TEST_F(MachineFunctionCallbacksTest, InstrumentedPasses) { + CallbacksHandle.registerPassInstrumentation(); + // Non-mock instrumentation not specifically mentioned below can be ignored. + CallbacksHandle.ignoreNonMockPassInstrumentation(""); + CallbacksHandle.ignoreNonMockPassInstrumentation("test"); + CallbacksHandle.ignoreNonMockPassInstrumentation(""); + + // PassInstrumentation calls should happen in-sequence, in the same order + // as passes/analyses are scheduled. + ::testing::Sequence PISequence; + EXPECT_CALL(CallbacksHandle, + runBeforePass(HasNameRegex("MockPassHandle"), HasName("test"))) + .InSequence(PISequence); + EXPECT_CALL( + CallbacksHandle, + runBeforeNonSkippedPass(HasNameRegex("MockPassHandle"), HasName("test"))) + .InSequence(PISequence); + EXPECT_CALL(CallbacksHandle, + runAfterPass(HasNameRegex("MockPassHandle"), HasName("test"), _)) + .InSequence(PISequence); + + EXPECT_CALL(AnalysisHandle, run(HasName("test"), _)); + EXPECT_CALL(PassHandle, run(HasName("test"), _)).WillOnce(&getAnalysisResult); + + StringRef PipelineText = "test-transform"; + ASSERT_THAT_ERROR(PB.parsePassPipeline(MFPM, PipelineText), Succeeded()) + << "Pipeline was: " << PipelineText; + ASSERT_THAT_ERROR(MFPM.run(*M, MFAM), Succeeded()); +} + +TEST_F(MachineFunctionCallbacksTest, InstrumentedSkippedPasses) { + CallbacksHandle.registerPassInstrumentation(); + // Non-mock instrumentation run here can safely be ignored. + CallbacksHandle.ignoreNonMockPassInstrumentation(""); + CallbacksHandle.ignoreNonMockPassInstrumentation("test"); + CallbacksHandle.ignoreNonMockPassInstrumentation(""); + + // Skip the pass by returning false. + EXPECT_CALL(CallbacksHandle, + runBeforePass(HasNameRegex("MockPassHandle"), HasName("test"))) + .WillOnce(Return(false)); + + EXPECT_CALL( + CallbacksHandle, + runBeforeSkippedPass(HasNameRegex("MockPassHandle"), HasName("test"))) + .Times(1); + + EXPECT_CALL(AnalysisHandle, run(HasName("test"), _)).Times(0); + EXPECT_CALL(PassHandle, run(HasName("test"), _)).Times(0); + + // As the pass is skipped there is no afterPass, beforeAnalysis/afterAnalysis + // as well. + EXPECT_CALL(CallbacksHandle, + runBeforeNonSkippedPass(HasNameRegex("MockPassHandle"), _)) + .Times(0); + EXPECT_CALL(CallbacksHandle, + runAfterPass(HasNameRegex("MockPassHandle"), _, _)) + .Times(0); + EXPECT_CALL(CallbacksHandle, + runAfterPassInvalidated(HasNameRegex("MockPassHandle"), _)) + .Times(0); + EXPECT_CALL(CallbacksHandle, + runAfterPass(HasNameRegex("MockPassHandle"), _, _)) + .Times(0); + EXPECT_CALL(CallbacksHandle, + runBeforeAnalysis(HasNameRegex("MockAnalysisHandle"), _)) + .Times(0); + EXPECT_CALL(CallbacksHandle, + runAfterAnalysis(HasNameRegex("MockAnalysisHandle"), _)) + .Times(0); + + StringRef PipelineText = "test-transform"; + ASSERT_THAT_ERROR(PB.parsePassPipeline(MFPM, PipelineText), Succeeded()) + << "Pipeline was: " << PipelineText; + ASSERT_THAT_ERROR(MFPM.run(*M, MFAM), Succeeded()); +} + +} // end anonymous namespace From c74a50fc363d08d668b9e5f3fcb399b5788d71e7 Mon Sep 17 00:00:00 2001 From: PaperChalice Date: Fri, 12 Jan 2024 12:41:50 +0800 Subject: [PATCH 2/2] Wrap TargetMachine by std::unique_ptr --- llvm/unittests/MIR/PassBuilderCallbacksTest.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/llvm/unittests/MIR/PassBuilderCallbacksTest.cpp b/llvm/unittests/MIR/PassBuilderCallbacksTest.cpp index 749e2f76a6f86..5ab4df1c26df3 100644 --- a/llvm/unittests/MIR/PassBuilderCallbacksTest.cpp +++ b/llvm/unittests/MIR/PassBuilderCallbacksTest.cpp @@ -301,7 +301,7 @@ class MachineFunctionCallbacksTest : public testing::Test { InitializeAllTargetMCs(); } - TargetMachine *TM; + std::unique_ptr TM; LLVMContext Context; std::unique_ptr M; @@ -354,13 +354,14 @@ class MachineFunctionCallbacksTest : public testing::Test { auto *T = TargetRegistry::lookupTarget(TripleName, Error); if (!T) GTEST_SKIP(); - TM = T->createTargetMachine(TripleName, "", "", TargetOptions(), - std::nullopt); - MachineModuleInfo MMI(static_cast(TM)); + TM = std::unique_ptr( + static_cast(T->createTargetMachine( + TripleName, "", "", TargetOptions(), std::nullopt))); + if (!TM) + GTEST_SKIP(); + MachineModuleInfo MMI(TM.get()); M = parseMIR(*TM, MIRString, MMI); - AM.registerPass([&] { - return MachineModuleAnalysis(static_cast(TM)); - }); + AM.registerPass([&] { return MachineModuleAnalysis(TM.get()); }); } MachineFunctionCallbacksTest()