From 5d03fcd948ad862458f2b7b09cc97ad66c08ac1f Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Mon, 29 Aug 2022 21:04:18 +0900 Subject: [PATCH] inference: revive `CachedMethodTable` mechanism `CachedMethodTable` was removed in #44240 as we couldn't confirm any performance improvement then. However it turns out the optimization was critical in some real world cases (e.g. #46492), so this commit revives the mechanism with the following tweaks that should make it more effective: - create method table cache per inference (rather than per local inference on a function call as in the previous implementation) - only use the cache mechanism for abstract types (since for concrete types the lookup result is already cached at the next level) As a result, the following snippet reported in #46492 recovers the compilation performance: ```julia using ControlSystems a_2 = [-5 -3; 2 -9] C_212 = ss(a_2, [1; 2], [1 0; 0 1], [0; 0]) @time norm(C_212) ``` > on master ``` julia> @time norm(C_212) 364.489044 seconds (724.44 M allocations: 92.524 GiB, 6.01% gc time, 100.00% compilation time) 0.5345224838248489 ``` > on this commit ``` julia> @time norm(C_212) 26.539016 seconds (62.09 M allocations: 5.537 GiB, 5.55% gc time, 100.00% compilation time) 0.5345224838248489 ``` (cherry picked from commit 844574411fc77f5de1528a2ab30b9457238959cf) --- base/compiler/abstractinterpretation.jl | 4 +- base/compiler/compiler.jl | 2 +- base/compiler/methodtable.jl | 80 ++++++++++++++++++------- base/compiler/types.jl | 19 +++--- 4 files changed, 68 insertions(+), 37 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 5c1133b4d40ec..98bcfa4f18661 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -282,7 +282,7 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth if result === missing return FailedMethodMatch("For one of the union split cases, too many methods 
matched") end - matches, overlayed = result + (; matches, overlayed) = result nonoverlayed &= !overlayed push!(infos, MethodMatchInfo(matches)) for m in matches @@ -323,7 +323,7 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth # (assume this will always be true, so we don't compute / update valid age in this case) return FailedMethodMatch("Too many methods matched") end - matches, overlayed = result + (; matches, overlayed) = result fullmatch = _any(match->(match::MethodMatch).fully_covers, matches) return MethodMatches(matches.matches, MethodMatchInfo(matches), diff --git a/base/compiler/compiler.jl b/base/compiler/compiler.jl index 6991e2d38437b..ed88e8c22178f 100644 --- a/base/compiler/compiler.jl +++ b/base/compiler/compiler.jl @@ -123,10 +123,10 @@ something(x::Any, y...) = x ############ include("compiler/cicache.jl") +include("compiler/methodtable.jl") include("compiler/types.jl") include("compiler/utilities.jl") include("compiler/validation.jl") -include("compiler/methodtable.jl") include("compiler/inferenceresult.jl") include("compiler/inferencestate.jl") diff --git a/base/compiler/methodtable.jl b/base/compiler/methodtable.jl index 7aa686009c1af..8b3968332e2e8 100644 --- a/base/compiler/methodtable.jl +++ b/base/compiler/methodtable.jl @@ -2,6 +2,27 @@ abstract type MethodTableView; end +struct MethodLookupResult + # Really Vector{Core.MethodMatch}, but it's easier to represent this as + # and work with Vector{Any} on the C side. + matches::Vector{Any} + valid_worlds::WorldRange + ambig::Bool +end +length(result::MethodLookupResult) = length(result.matches) +function iterate(result::MethodLookupResult, args...) + r = iterate(result.matches, args...) 
+ r === nothing && return nothing + match, state = r + return (match::MethodMatch, state) +end +getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch + +struct MethodMatchResult + matches::MethodLookupResult + overlayed::Bool +end + """ struct InternalMethodTable <: MethodTableView @@ -23,25 +44,21 @@ struct OverlayMethodTable <: MethodTableView mt::Core.MethodTable end -struct MethodLookupResult - # Really Vector{Core.MethodMatch}, but it's easier to represent this as - # and work with Vector{Any} on the C side. - matches::Vector{Any} - valid_worlds::WorldRange - ambig::Bool -end -length(result::MethodLookupResult) = length(result.matches) -function iterate(result::MethodLookupResult, args...) - r = iterate(result.matches, args...) - r === nothing && return nothing - match, state = r - return (match::MethodMatch, state) +""" + struct CachedMethodTable <: MethodTableView + +Overlays another method table view with an additional local fast path cache that +can respond to repeated, identical queries faster than the original method table. +""" +struct CachedMethodTable{T} <: MethodTableView + cache::IdDict{Any, Union{Missing, MethodMatchResult}} + table::T end -getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch +CachedMethodTable(table::T) where T = CachedMethodTable{T}(IdDict{Any, Union{Missing, MethodMatchResult}}(), table) """ findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) -> - (matches::MethodLookupResult, overlayed::Bool) or missing + MethodMatchResult(matches::MethodLookupResult, overlayed::Bool) or missing Find all methods in the given method table `view` that are applicable to the given signature `sig`. If no applicable methods are found, an empty result is returned. 
@@ -51,7 +68,7 @@ If the number of applicable methods exceeded the specified limit, `missing` is r function findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=Int(typemax(Int32))) result = _findall(sig, nothing, table.world, limit) result === missing && return missing - return result, false + return MethodMatchResult(result, false) end function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int=Int(typemax(Int32))) @@ -60,18 +77,20 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int nr = length(result) if nr ≥ 1 && result[nr].fully_covers # no need to fall back to the internal method table - return result, true + return MethodMatchResult(result, true) end # fall back to the internal method table fallback_result = _findall(sig, nothing, table.world, limit) fallback_result === missing && return missing # merge the fallback match results with the internal method table - return MethodLookupResult( - vcat(result.matches, fallback_result.matches), - WorldRange( - max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world), - min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)), - result.ambig | fallback_result.ambig), !isempty(result) + return MethodMatchResult( + MethodLookupResult( + vcat(result.matches, fallback_result.matches), + WorldRange( + max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world), + min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)), + result.ambig | fallback_result.ambig), + !isempty(result)) end function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt, limit::Int) @@ -85,6 +104,17 @@ function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0) end +function findall(@nospecialize(sig::Type), table::CachedMethodTable; limit::Int=typemax(Int)) + if 
isconcretetype(sig) + # for concrete types, the result is already cached at the next level + return findall(sig, table.table; limit) + end + box = Core.Box(sig) + return get!(table.cache, sig) do + findall(box.contents, table.table; limit) + end +end + """ findsup(sig::Type, view::MethodTableView) -> (match::MethodMatch, valid_worlds::WorldRange, overlayed::Bool) or nothing @@ -129,6 +159,10 @@ function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, return match, valid_worlds end +# This query is not cached +findsup(@nospecialize(sig::Type), table::CachedMethodTable) = findsup(sig, table.table) + isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface") isoverlayed(::InternalMethodTable) = false isoverlayed(::OverlayMethodTable) = true +isoverlayed(mt::CachedMethodTable) = isoverlayed(mt.table) diff --git a/base/compiler/types.jl b/base/compiler/types.jl index a0884bd86d1d3..3652a15c3e7bd 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -318,6 +318,8 @@ struct NativeInterpreter <: AbstractInterpreter cache::Vector{InferenceResult} # The world age we're working inside of world::UInt + # method table used for lookups during inference at this world age + method_table::CachedMethodTable{InternalMethodTable} # Parameters for inference and optimization inf_params::InferenceParams @@ -327,27 +329,21 @@ struct NativeInterpreter <: AbstractInterpreter inf_params = InferenceParams(), opt_params = OptimizationParams(), ) + cache = Vector{InferenceResult}() # Initially empty cache + # Sometimes the caller is lazy and passes typemax(UInt). # we cap it to the current world age if world == typemax(UInt) world = get_world_counter() end + method_table = CachedMethodTable(InternalMethodTable(world)) + # If they didn't pass typemax(UInt) but passed something more subtly # incorrect, fail out loudly. 
@assert world <= get_world_counter() - return new( - # Initially empty cache - Vector{InferenceResult}(), - - # world age counter - world, - - # parameters for inference and optimization - inf_params, - opt_params, - ) + return new(cache, world, method_table, inf_params, opt_params) end end @@ -396,6 +392,7 @@ External `AbstractInterpreter` can optionally return `OverlayMethodTable` here to incorporate customized dispatches for the overridden methods. """ method_table(interp::AbstractInterpreter) = InternalMethodTable(get_world_counter(interp)) +method_table(interp::NativeInterpreter) = interp.method_table """ By default `AbstractInterpreter` implements the following inference bail out logic: