Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: revive CachedMethodTable mechanism #46535

Merged
merged 1 commit into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth
if result === missing
return FailedMethodMatch("For one of the union split cases, too many methods matched")
end
matches, overlayed = result
(; matches, overlayed) = result
nonoverlayed &= !overlayed
push!(infos, MethodMatchInfo(matches))
for m in matches
Expand Down Expand Up @@ -334,7 +334,7 @@ function find_matching_methods(argtypes::Vector{Any}, @nospecialize(atype), meth
# (assume this will always be true, so we don't compute / update valid age in this case)
return FailedMethodMatch("Too many methods matched")
end
matches, overlayed = result
(; matches, overlayed) = result
fullmatch = _any(match->(match::MethodMatch).fully_covers, matches)
return MethodMatches(matches.matches,
MethodMatchInfo(matches),
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,11 @@ something(x::Any, y...) = x
############

include("compiler/cicache.jl")
include("compiler/methodtable.jl")
include("compiler/effects.jl")
include("compiler/types.jl")
include("compiler/utilities.jl")
include("compiler/validation.jl")
include("compiler/methodtable.jl")

function argextype end # imported by EscapeAnalysis
function stmt_effect_free end # imported by EscapeAnalysis
Expand Down
80 changes: 57 additions & 23 deletions base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,27 @@

abstract type MethodTableView; end

struct MethodLookupResult
# Really Vector{Core.MethodMatch}, but it's easier to represent this as
# and work with Vector{Any} on the C side.
matches::Vector{Any}
valid_worlds::WorldRange
ambig::Bool
end
length(result::MethodLookupResult) = length(result.matches)
function iterate(result::MethodLookupResult, args...)
r = iterate(result.matches, args...)
r === nothing && return nothing
match, state = r
return (match::MethodMatch, state)
end
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch

struct MethodMatchResult
matches::MethodLookupResult
overlayed::Bool
end

"""
struct InternalMethodTable <: MethodTableView

Expand All @@ -23,25 +44,21 @@ struct OverlayMethodTable <: MethodTableView
mt::Core.MethodTable
end

struct MethodLookupResult
# Really Vector{Core.MethodMatch}, but it's easier to represent this as
# and work with Vector{Any} on the C side.
matches::Vector{Any}
valid_worlds::WorldRange
ambig::Bool
end
length(result::MethodLookupResult) = length(result.matches)
function iterate(result::MethodLookupResult, args...)
r = iterate(result.matches, args...)
r === nothing && return nothing
match, state = r
return (match::MethodMatch, state)
"""
struct CachedMethodTable <: MethodTableView

Overlays another method table view with an additional local fast path cache that
can respond to repeated, identical queries faster than the original method table.
"""
struct CachedMethodTable{T} <: MethodTableView
cache::IdDict{Any, Union{Missing, MethodMatchResult}}
table::T
end
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch
CachedMethodTable(table::T) where T = CachedMethodTable{T}(IdDict{Any, Union{Missing, MethodMatchResult}}(), table)

"""
findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) ->
(matches::MethodLookupResult, overlayed::Bool) or missing
MethodMatchResult(matches::MethodLookupResult, overlayed::Bool) or missing

Find all methods in the given method table `view` that are applicable to the given signature `sig`.
If no applicable methods are found, an empty result is returned.
Expand All @@ -51,7 +68,7 @@ If the number of applicable methods exceeded the specified limit, `missing` is r
function findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=Int(typemax(Int32)))
result = _findall(sig, nothing, table.world, limit)
result === missing && return missing
return result, false
return MethodMatchResult(result, false)
end

function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int=Int(typemax(Int32)))
Expand All @@ -60,18 +77,20 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
nr = length(result)
if nr ≥ 1 && result[nr].fully_covers
# no need to fall back to the internal method table
return result, true
return MethodMatchResult(result, true)
end
# fall back to the internal method table
fallback_result = _findall(sig, nothing, table.world, limit)
fallback_result === missing && return missing
# merge the fallback match results with the internal method table
return MethodLookupResult(
vcat(result.matches, fallback_result.matches),
WorldRange(
max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
result.ambig | fallback_result.ambig), !isempty(result)
return MethodMatchResult(
MethodLookupResult(
vcat(result.matches, fallback_result.matches),
WorldRange(
max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
result.ambig | fallback_result.ambig),
!isempty(result))
end

function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt, limit::Int)
Expand All @@ -85,6 +104,17 @@ function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable},
return MethodLookupResult(ms::Vector{Any}, WorldRange(_min_val[], _max_val[]), _ambig[] != 0)
end

function findall(@nospecialize(sig::Type), table::CachedMethodTable; limit::Int=typemax(Int))
if isconcretetype(sig)
# as for concrete types, we cache result at on the next level
return findall(sig, table.table; limit)
end
box = Core.Box(sig)
return get!(table.cache, sig) do
findall(box.contents, table.table; limit)
end
end

"""
findsup(sig::Type, view::MethodTableView) ->
(match::MethodMatch, valid_worlds::WorldRange, overlayed::Bool) or nothing
Expand Down Expand Up @@ -129,6 +159,10 @@ function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable},
return match, valid_worlds
end

# This query is not cached
findsup(@nospecialize(sig::Type), table::CachedMethodTable) = findsup(sig, table.table)

isoverlayed(::MethodTableView) = error("unsatisfied MethodTableView interface")
isoverlayed(::InternalMethodTable) = false
isoverlayed(::OverlayMethodTable) = true
isoverlayed(mt::CachedMethodTable) = isoverlayed(mt.table)
19 changes: 8 additions & 11 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ struct NativeInterpreter <: AbstractInterpreter
cache::Vector{InferenceResult}
# The world age we're working inside of
world::UInt
# method table to lookup for during inference on this world age
method_table::CachedMethodTable{InternalMethodTable}

# Parameters for inference and optimization
inf_params::InferenceParams
Expand All @@ -167,27 +169,21 @@ struct NativeInterpreter <: AbstractInterpreter
inf_params = InferenceParams(),
opt_params = OptimizationParams(),
)
cache = Vector{InferenceResult}() # Initially empty cache

# Sometimes the caller is lazy and passes typemax(UInt).
# we cap it to the current world age
if world == typemax(UInt)
world = get_world_counter()
end

method_table = CachedMethodTable(InternalMethodTable(world))

# If they didn't pass typemax(UInt) but passed something more subtly
# incorrect, fail out loudly.
@assert world <= get_world_counter()

return new(
# Initially empty cache
Vector{InferenceResult}(),

# world age counter
world,

# parameters for inference and optimization
inf_params,
opt_params,
)
return new(cache, world, method_table, inf_params, opt_params)
end
end

Expand Down Expand Up @@ -251,6 +247,7 @@ External `AbstractInterpreter` can optionally return `OverlayMethodTable` here
to incorporate customized dispatches for the overridden methods.
"""
method_table(interp::AbstractInterpreter) = InternalMethodTable(get_world_counter(interp))
method_table(interp::NativeInterpreter) = interp.method_table

"""
By default `AbstractInterpreter` implements the following inference bail out logic:
Expand Down