Skip to content

Commit

Permalink
inference: revive CachedMethodTable mechanism (#46535)
Browse files Browse the repository at this point in the history
`CachedMethodTable` was removed within #44240 as we couldn't confirm any
performance improvement then. However it turns out the optimization was
critical in some real world cases (e.g. #46492), so this commit revives
the mechanism with the following tweaks that should make it more effective:
- create method table cache per inference (rather than per local
  inference on a function call as on the previous implementation)
- only use cache mechanism for abstract types (since we already cache
  lookup result at the next level as for concrete types)

As a result, the following snippet reported at #46492 recovers the
compilation performance:
```julia
using ControlSystems
a_2 = [-5 -3; 2 -9]
C_212 = ss(a_2, [1; 2], [1 0; 0 1], [0; 0])
@time norm(C_212)
```

> on master
```
julia> @time norm(C_212)
364.489044 seconds (724.44 M allocations: 92.524 GiB, 6.01% gc time, 100.00% compilation time)
0.5345224838248489
```

> on this commit
```
julia> @time norm(C_212)
 26.539016 seconds (62.09 M allocations: 5.537 GiB, 5.55% gc time, 100.00% compilation time)
0.5345224838248489
```
  • Loading branch information
aviatesk authored Aug 31, 2022
1 parent 97c853a commit f066855
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 37 deletions.
4 changes: 2 additions & 2 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth
if result === missing
return FailedMethodMatch("For one of the union split cases, too many methods matched")
end
matches, overlayed = result
(; matches, overlayed) = result
nonoverlayed &= !overlayed
push!(infos, MethodMatchInfo(matches))
for m in matches
Expand Down Expand Up @@ -334,7 +334,7 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth
# (assume this will always be true, so we don't compute / update valid age in this case)
return FailedMethodMatch("Too many methods matched")
end
matches, overlayed = result
(; matches, overlayed) = result
fullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
return MethodMatches(matches.matches,
MethodMatchInfo(matches),
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,11 @@ something(x::Any, y...) = x
############

include("compiler/cicache.jl")
include("compiler/methodtable.jl")
include("compiler/effects.jl")
include("compiler/types.jl")
include("compiler/utilities.jl")
include("compiler/validation.jl")
include("compiler/methodtable.jl")

function argextype end # imported by EscapeAnalysis
function stmt_effect_free end # imported by EscapeAnalysis
Expand Down
80 changes: 57 additions & 23 deletions base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,27 @@

abstract type MethodTableView; end

struct MethodLookupResult
# Really Vector{Core.MethodMatch}, but it's easier to represent this as
# and work with Vector{Any} on the C side.
matches::Vector{Any}
valid_worlds::WorldRange
ambig::Bool
end
length(result::MethodLookupResult) = length(result.matches)
function iterate(result::MethodLookupResult, args...)
r = iterate(result.matches, args...)
r === nothing && return nothing
match, state = r
return (match::MethodMatch, state)
end
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch

struct MethodMatchResult
matches::MethodLookupResult
overlayed::Bool
end

"""
struct InternalMethodTable <: MethodTableView
Expand All @@ -23,25 +44,21 @@ struct OverlayMethodTable <: MethodTableView
mt::Core.MethodTable
end

struct MethodLookupResult
# Really Vector{Core.MethodMatch}, but it's easier to represent this as
# and work with Vector{Any} on the C side.
matches::Vector{Any}
valid_worlds::WorldRange
ambig::Bool
end
length(result::MethodLookupResult) = length(result.matches)
function iterate(result::MethodLookupResult, args...)
r = iterate(result.matches, args...)
r === nothing && return nothing
match, state = r
return (match::MethodMatch, state)
"""
struct CachedMethodTable <: MethodTableView
Overlays another method table view with an additional local fast path cache that
can respond to repeated, identical queries faster than the original method table.
"""
struct CachedMethodTable{T} <: MethodTableView
cache::IdDict{Any, Union{Missing, MethodMatchResult}}
table::T
end
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch
CachedMethodTable(table::T) where T = CachedMethodTable{T}(IdDict{Any, Union{Missing, MethodMatchResult}}(), table)

"""
findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) ->
(matches::MethodLookupResult, overlayed::Bool) or missing
MethodMatchResult(matches::MethodLookupResult, overlayed::Bool) or missing
Find all methods in the given method table `view` that are applicable to the given signature `sig`.
If no applicable methods are found, an empty result is returned.
Expand All @@ -51,7 +68,7 @@ If the number of applicable methods exceeded the specified limit, `missing` is r
function findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=Int(typemax(Int32)))
result = _findall(sig, nothing, table.world, limit)
result === missing && return missing
return result, false
return MethodMatchResult(result, false)
end

function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int=Int(typemax(Int32)))
Expand All @@ -60,18 +77,20 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
nr = length(result)
if nr 1 && result[nr].fully_covers
# no need to fall back to the internal method table
return result, true
return MethodMatchResult(result, true)
end
# fall back to the internal method table
fallback_result = _findall(sig, nothing, table.world, limit)
fallback_result === missing && return missing
# merge the fallback match results with the internal method table
return MethodLookupResult(
vcat(result.matches, fallback_result.matches),
WorldRange(
max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
result.ambig | fallback_result.ambig), !isempty(result)
return MethodMatchResult(
MethodLookupResult(
vcat(result.matches, fallback_result.matches),
WorldRange(
max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
result.ambig | fallback_result.ambig),
!isempty(result))
end

function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt, limit::Int)
Expand All @@ -85,6 +104,17 @@ function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable},
return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0)
end

function findall(@nospecialize(sig::Type), table::CachedMethodTable; limit::Int=typemax(Int))
if isconcretetype(sig)
# as for concrete types, we cache result at on the next level
return findall(sig, table.table; limit)
end
box = Core.Box(sig)
return get!(table.cache, sig) do
findall(box.contents, table.table; limit)
end
end

"""
findsup(sig::Type, view::MethodTableView) ->
(match::MethodMatch, valid_worlds::WorldRange, overlayed::Bool) or nothing
Expand Down Expand Up @@ -129,6 +159,10 @@ function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable},
return match, valid_worlds
end

# This query is not cached
findsup(@nospecialize(sig::Type), table::CachedMethodTable) = findsup(sig, table.table)

isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface")
isoverlayed(::InternalMethodTable) = false
isoverlayed(::OverlayMethodTable) = true
isoverlayed(mt::CachedMethodTable) = isoverlayed(mt.table)
19 changes: 8 additions & 11 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ struct NativeInterpreter <: AbstractInterpreter
cache::Vector{InferenceResult}
# The world age we're working inside of
world::UInt
# method table to lookup for during inference on this world age
method_table::CachedMethodTable{InternalMethodTable}

# Parameters for inference and optimization
inf_params::InferenceParams
Expand All @@ -167,27 +169,21 @@ struct NativeInterpreter <: AbstractInterpreter
inf_params = InferenceParams(),
opt_params = OptimizationParams(),
)
cache = Vector{InferenceResult}() # Initially empty cache

# Sometimes the caller is lazy and passes typemax(UInt).
# we cap it to the current world age
if world == typemax(UInt)
world = get_world_counter()
end

method_table = CachedMethodTable(InternalMethodTable(world))

# If they didn't pass typemax(UInt) but passed something more subtly
# incorrect, fail out loudly.
@assert world <= get_world_counter()

return new(
# Initially empty cache
Vector{InferenceResult}(),

# world age counter
world,

# parameters for inference and optimization
inf_params,
opt_params,
)
return new(cache, world, method_table, inf_params, opt_params)
end
end

Expand Down Expand Up @@ -251,6 +247,7 @@ External `AbstractInterpreter` can optionally return `OverlayMethodTable` here
to incorporate customized dispatches for the overridden methods.
"""
method_table(interp::AbstractInterpreter) = InternalMethodTable(get_world_counter(interp))
method_table(interp::NativeInterpreter) = interp.method_table

"""
By default `AbstractInterpreter` implements the following inference bail out logic:
Expand Down

0 comments on commit f066855

Please sign in to comment.