Skip to content

Commit

Permalink
allow external absint to hold custom data in codeinst.inferred (#53300
Browse files Browse the repository at this point in the history
)

It has been possible for external abstract interpreters to keep custom
data in `codeinst.inferred` (together /w overloading `inlining_policy`).
After #52233, when such external absint uses
`InternalCodeCache`, this data is passed to `jl_ir_flag_inferred`,
leading to segfaults in assertion builds.

This commit resolves the issue by omitting `jl_ir_flag_inferred` checks
when the `cache_owner` is external. Nonetheless, a better resolution
might be necessary. It suggests that the current design of `code_owner`
and `InternalCodeCache` for the external cache system is somewhat
flawed. A conceivable approach could involve:
- Adding a layer similar to `inlining_policy` in
`CC.get(::WorldView{InternalCodeCache})` to enable safe redirection of
custom data to the native interpreter's implementation.
- Prohibiting custom data in the `inferred` field and directing such
data to be kept in `analysis_results`.

(cherry picked from commit 93876c9)
  • Loading branch information
aviatesk authored and KristofferC committed Feb 26, 2024
1 parent 7dad444 commit 913f7d5
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 9 deletions.
9 changes: 3 additions & 6 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -373,20 +373,17 @@ end
function NativeInterpreter(world::UInt = get_world_counter();
inf_params::InferenceParams = InferenceParams(),
opt_params::OptimizationParams = OptimizationParams())
curr_max_world = get_world_counter()
# Sometimes the caller is lazy and passes typemax(UInt).
# we cap it to the current world age for correctness
if world == typemax(UInt)
world = get_world_counter()
world = curr_max_world
end

# If they didn't pass typemax(UInt) but passed something more subtly
# incorrect, fail out loudly.
@assert world <= get_world_counter()

@assert world <= curr_max_world
method_table = CachedMethodTable(InternalMethodTable(world))

inf_cache = Vector{InferenceResult}() # Initially empty cache

return NativeInterpreter(world, method_table, inf_cache, inf_params, opt_params)
end

Expand Down
7 changes: 5 additions & 2 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,11 @@ STATIC_INLINE jl_value_t *_jl_rettype_inferred(jl_value_t *owner, jl_method_inst
if (jl_atomic_load_relaxed(&codeinst->min_world) <= min_world &&
max_world <= jl_atomic_load_relaxed(&codeinst->max_world) &&
jl_egal(codeinst->owner, owner)) {
jl_value_t *code = jl_atomic_load_relaxed(&codeinst->inferred);
if (code && (code == jl_nothing || jl_ir_flag_inferred(code)))
jl_value_t *inferred = jl_atomic_load_relaxed(&codeinst->inferred);
if (inferred && ((inferred == jl_nothing) || (
// allow whatever code instance external abstract interpreter produced
// since `jl_ir_flag_inferred` is specific to the native interpreter
codeinst->owner != jl_nothing || jl_ir_flag_inferred(inferred))))
return (jl_value_t*)codeinst;
}
codeinst = jl_atomic_load_relaxed(&codeinst->next);
Expand Down
33 changes: 32 additions & 1 deletion test/compiler/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,6 @@ end
Core.eval(Core.Compiler, quote f(;a=1) = a end)
@test_throws MethodError Core.Compiler.f(;b=2)


# Custom lookup function
# ======================

Expand Down Expand Up @@ -469,3 +468,35 @@ let # generate cache
@test occursin("j_sin_", s)
@test !occursin("j_cos_", s)
end

# custom inferred data
# ====================

@newinterp CustomDataInterp
struct CustomDataInterpToken end
CC.cache_owner(::CustomDataInterp) = CustomDataInterpToken()
struct CustomData
inferred
CustomData(@nospecialize inferred) = new(inferred)
end
function CC.transform_result_for_cache(interp::CustomDataInterp,
mi::Core.MethodInstance, valid_worlds::CC.WorldRange, result::CC.InferenceResult)
inferred_result = @invoke CC.transform_result_for_cache(interp::CC.AbstractInterpreter,
mi::Core.MethodInstance, valid_worlds::CC.WorldRange, result::CC.InferenceResult)
return CustomData(inferred_result)
end
function CC.inlining_policy(interp::CustomDataInterp, @nospecialize(src),
@nospecialize(info::CC.CallInfo), stmt_flag::UInt32)
if src isa CustomData
src = src.inferred
end
return @invoke CC.inlining_policy(interp::CC.AbstractInterpreter, src::Any,
info::CC.CallInfo, stmt_flag::UInt32)
end
let src = code_typed((Int,); interp=CustomDataInterp()) do x
return sin(x) + cos(x)
end |> only |> first
@test count(isinvoke(:sin), src.code) == 1
@test count(isinvoke(:cos), src.code) == 1
@test count(isinvoke(:+), src.code) == 0
end
76 changes: 76 additions & 0 deletions test/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1776,6 +1776,82 @@ let newinterp_path = abspath("compiler/newinterp.jl")
@test found
end
end

write(joinpath(load_path, "CustomAbstractInterpreterCaching2.jl"), :(module CustomAbstractInterpreterCaching2
import SimpleModule: basic_caller, basic_callee

module Custom
const CC = Core.Compiler
include("$($newinterp_path)")
@newinterp PrecompileInterpreter
struct CustomData
inferred
CustomData(@nospecialize inferred) = new(inferred)
end
function CC.transform_result_for_cache(interp::PrecompileInterpreter,
mi::Core.MethodInstance, valid_worlds::CC.WorldRange, result::CC.InferenceResult)
inferred_result = @invoke CC.transform_result_for_cache(interp::CC.AbstractInterpreter,
mi::Core.MethodInstance, valid_worlds::CC.WorldRange, result::CC.InferenceResult)
return CustomData(inferred_result)
end
function CC.inlining_policy(interp::PrecompileInterpreter, @nospecialize(src),
@nospecialize(info::CC.CallInfo), stmt_flag::UInt32)
if src isa CustomData
src = src.inferred
end
return @invoke CC.inlining_policy(interp::CC.AbstractInterpreter, src::Any,
info::CC.CallInfo, stmt_flag::UInt32)
end
end

Base.return_types((Float64,)) do x
basic_caller(x)
end
Base.return_types((Float64,); interp=Custom.PrecompileInterpreter()) do x
basic_caller(x)
end
Base.return_types((Vector{Float64},)) do x
sum(x)
end
Base.return_types((Vector{Float64},); interp=Custom.PrecompileInterpreter()) do x
sum(x)
end
end) |> string)
Base.compilecache(Base.PkgId("CustomAbstractInterpreterCaching2"))
@eval let
using CustomAbstractInterpreterCaching2
cache_owner = Core.Compiler.cache_owner(
CustomAbstractInterpreterCaching2.Custom.PrecompileInterpreter())
let m = only(methods(CustomAbstractInterpreterCaching.basic_callee))
mi = only(Base.specializations(m))
ci = mi.cache
@test isdefined(ci, :next)
@test ci.owner === nothing
@test ci.max_world == typemax(UInt)
ci = ci.next
@test !isdefined(ci, :next)
@test ci.owner === cache_owner
@test ci.max_world == typemax(UInt)
end
let m = only(methods(sum, (Vector{Float64},)))
found = false
for mi = Base.specializations(m)
if mi isa Core.MethodInstance && mi.specTypes == Tuple{typeof(sum),Vector{Float64}}
ci = mi.cache
@test isdefined(ci, :next)
@test ci.owner === cache_owner
@test ci.max_world == typemax(UInt)
ci = ci.next
@test !isdefined(ci, :next)
@test ci.owner === nothing
@test ci.max_world == typemax(UInt)
found = true
break
end
end
@test found
end
end
end
end

Expand Down

0 comments on commit 913f7d5

Please sign in to comment.