diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp index f7cecb57f5d35..cb98ed838f5d7 100644 --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -572,8 +572,8 @@ void VTableSlotInfo::addCallSite(Value *VTable, CallBase &CB, struct DevirtModule { Module &M; - function_ref AARGetter; - function_ref LookupDomTree; + ModuleAnalysisManager &MAM; + FunctionAnalysisManager &FAM; ModuleSummaryIndex *const ExportSummary; const ModuleSummaryIndex *const ImportSummary; @@ -589,7 +589,7 @@ struct DevirtModule { ArrayType *const Int8Arr0Ty; const bool RemarksEnabled; - function_ref OREGetter; + std::function OREGetter; MapVector CallSlots; // Calls that have already been optimized. We may add a call to multiple @@ -612,12 +612,11 @@ struct DevirtModule { std::map NumUnsafeUsesForTypeTest; PatternList FunctionsToSkip; - DevirtModule(Module &M, function_ref AARGetter, - function_ref OREGetter, - function_ref LookupDomTree, + DevirtModule(Module &M, ModuleAnalysisManager &MAM, ModuleSummaryIndex *ExportSummary, const ModuleSummaryIndex *ImportSummary) - : M(M), AARGetter(AARGetter), LookupDomTree(LookupDomTree), + : M(M), MAM(MAM), + FAM(MAM.getResult(M).getManager()), ExportSummary(ExportSummary), ImportSummary(ImportSummary), Int8Ty(Type::getInt8Ty(M.getContext())), Int8PtrTy(PointerType::getUnqual(M.getContext())), @@ -625,7 +624,10 @@ struct DevirtModule { Int64Ty(Type::getInt64Ty(M.getContext())), IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)), Int8Arr0Ty(ArrayType::get(Type::getInt8Ty(M.getContext()), 0)), - RemarksEnabled(areRemarksEnabled()), OREGetter(OREGetter) { + RemarksEnabled(areRemarksEnabled()), + OREGetter([&](Function &F) -> OptimizationRemarkEmitter & { + return FAM.getResult(F); + }) { assert(!(ExportSummary && ImportSummary)); FunctionsToSkip.init(SkipFunctionNames); } @@ -739,10 +741,7 @@ struct DevirtModule { // Lower the module using the action and summary passed as command line // arguments. For testing purposes only. - static bool - runForTesting(Module &M, function_ref AARGetter, - function_ref OREGetter, - function_ref LookupDomTree); + static bool runForTesting(Module &M, ModuleAnalysisManager &MAM); }; struct DevirtIndex { @@ -783,25 +782,13 @@ struct DevirtIndex { } // end anonymous namespace PreservedAnalyses WholeProgramDevirtPass::run(Module &M, - ModuleAnalysisManager &AM) { - auto &FAM = AM.getResult(M).getManager(); - auto AARGetter = [&](Function &F) -> AAResults & { - return FAM.getResult(F); - }; - auto OREGetter = [&](Function &F) -> OptimizationRemarkEmitter & { - return FAM.getResult(F); - }; - auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & { - return FAM.getResult(F); - }; + ModuleAnalysisManager &MAM) { if (UseCommandLine) { - if (!DevirtModule::runForTesting(M, AARGetter, OREGetter, LookupDomTree)) + if (!DevirtModule::runForTesting(M, MAM)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); } - if (!DevirtModule(M, AARGetter, OREGetter, LookupDomTree, ExportSummary, - ImportSummary) - .run()) + if (!DevirtModule(M, MAM, ExportSummary, ImportSummary).run()) return PreservedAnalyses::all(); return PreservedAnalyses::none(); } @@ -996,10 +983,7 @@ static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) { return ErrorSuccess(); } -bool DevirtModule::runForTesting( - Module &M, function_ref AARGetter, - function_ref OREGetter, - function_ref LookupDomTree) { +bool DevirtModule::runForTesting(Module &M, ModuleAnalysisManager &MAM) { std::unique_ptr Summary = std::make_unique(/*HaveGVs=*/false); @@ -1024,7 +1008,7 @@ bool DevirtModule::runForTesting( } bool Changed = - DevirtModule(M, AARGetter, OREGetter, LookupDomTree, + DevirtModule(M, MAM, ClSummaryAction == PassSummaryAction::Export ? Summary.get() : nullptr, ClSummaryAction == PassSummaryAction::Import ? Summary.get() @@ -1877,7 +1861,7 @@ bool DevirtModule::tryVirtualConstProp( return false; if (Fn->isDeclaration() || - !computeFunctionBodyMemoryAccess(*Fn, AARGetter(*Fn)) + !computeFunctionBodyMemoryAccess(*Fn, FAM.getResult(*Fn)) .doesNotAccessMemory() || Fn->arg_empty() || !Fn->arg_begin()->use_empty() || Fn->getReturnType() != RetType) @@ -2051,7 +2035,7 @@ void DevirtModule::scanTypeTestUsers( // Search for virtual calls based on %p and add them to DevirtCalls. SmallVector DevirtCalls; SmallVector Assumes; - auto &DT = LookupDomTree(*CI->getFunction()); + auto &DT = FAM.getResult(*CI->getFunction()); findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT); Metadata *TypeId = @@ -2128,7 +2112,7 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { SmallVector LoadedPtrs; SmallVector Preds; bool HasNonCallUses = false; - auto &DT = LookupDomTree(*CI->getFunction()); + auto &DT = FAM.getResult(*CI->getFunction()); findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds, HasNonCallUses, CI, DT);