Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] add world age bounds to CodeInfo #27073

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ mutable struct InferenceState

# info on the state of inference and the linfo
src::CodeInfo
min_valid::UInt
max_valid::UInt
nargs::Int
stmt_types::Vector{Any}
stmt_edges::Vector{Any}
Expand Down Expand Up @@ -129,17 +127,17 @@ mutable struct InferenceState
inmodule = linfo.def::Module
end

# TODO: we should be setting this correctly somewhere else
if cached && !toplevel
min_valid = min_world(linfo.def)
max_valid = max_world(linfo.def)
set_min_world!(src, min_world(linfo.def))
set_max_world!(src, max_world(linfo.def))
else
min_valid = typemax(UInt)
max_valid = typemin(UInt)
set_min_world!(src, typemax(UInt))
set_max_world!(src, typemin(UInt))
end
frame = new(
params, result, linfo,
sp, inmodule, 0,
src, min_valid, max_valid,
sp, inmodule, 0, src,
nargs, s_types, s_edges,
Union{}, W, 1, n,
cur_hand, handler_at, n_handlers,
Expand Down Expand Up @@ -170,16 +168,16 @@ _topmod(sv::InferenceState) = _topmod(sv.mod)

# work towards converging the valid age range for sv
function update_valid_age!(min_valid::UInt, max_valid::UInt, sv::InferenceState)
sv.min_valid = max(sv.min_valid, min_valid)
sv.max_valid = min(sv.max_valid, max_valid)
set_min_world!(sv.src, max(min_world(sv.src), min_valid))
set_max_world!(sv.src, min(max_world(sv.src), max_valid))
@assert(!isa(sv.linfo.def, Method) ||
!sv.cached ||
sv.min_valid <= sv.params.world <= sv.max_valid,
min_world(sv.src) <= sv.params.world <= max_world(sv.src),
"invalid age range update")
nothing
end

update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(edge.min_valid, edge.max_valid, sv)
update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(min_world(edge.src), max_world(edge.src), sv)
update_valid_age!(li::MethodInstance, sv::InferenceState) = update_valid_age!(min_world(li), max_world(li), sv)

function record_ssa_assign(ssa_id::Int, @nospecialize(new), frame::InferenceState)
Expand Down
25 changes: 10 additions & 15 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ mutable struct OptimizationState
mod::Module
nargs::Int
next_label::Int # index of the current highest label for this function
min_valid::UInt
max_valid::UInt
params::Params
function OptimizationState(frame::InferenceState)
s_edges = frame.stmt_edges[1]
Expand All @@ -26,8 +24,7 @@ mutable struct OptimizationState
return new(frame.linfo, frame.result.vargs,
s_edges::Vector{Any},
src, frame.mod, frame.nargs,
next_label, frame.min_valid, frame.max_valid,
frame.params)
next_label, frame.params)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo,
params::Params)
Expand Down Expand Up @@ -57,9 +54,7 @@ mutable struct OptimizationState
return new(linfo, result_vargs,
s_edges::Vector{Any},
src, inmodule, nargs,
next_label,
min_world(linfo), max_world(linfo),
params)
next_label, params)
end
end

Expand All @@ -82,11 +77,11 @@ function newvar!(sv::OptimizationState, @nospecialize(typ))
end

function update_valid_age!(min_valid::UInt, max_valid::UInt, sv::OptimizationState)
sv.min_valid = max(sv.min_valid, min_valid)
sv.max_valid = min(sv.max_valid, max_valid)
set_min_world!(sv.src, max(min_world(sv.src), min_valid))
set_max_world!(sv.src, min(max_world(sv.src), max_valid))
@assert(!isa(sv.linfo.def, Method) ||
(sv.min_valid == typemax(UInt) && sv.max_valid == typemin(UInt)) ||
sv.min_valid <= sv.params.world <= sv.max_valid,
(min_world(sv.src) == typemax(UInt) && max_world(sv.src) == typemin(UInt)) ||
min_world(sv.src) <= sv.params.world <= max_world(sv.src),
"invalid age range update")
nothing
end
Expand Down Expand Up @@ -323,8 +318,8 @@ function optimize(me::InferenceState)
reindex_labels!(opt)
end
end
me.min_valid = opt.min_valid
me.max_valid = opt.max_valid
set_min_world!(me.src, min_world(opt.src))
set_max_world!(me.src, max_world(opt.src))
end

# convert all type information into the form consumed by the code-generator
Expand Down Expand Up @@ -412,8 +407,8 @@ function finish(me::InferenceState)
if me.cached
toplevel = !isa(me.linfo.def, Method)
if !toplevel
min_valid = me.min_valid
max_valid = me.max_valid
min_valid = min_world(me.src)
max_valid = max_world(me.src)
else
min_valid = UInt(0)
max_valid = UInt(0)
Expand Down
22 changes: 12 additions & 10 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ const COMPILER_TEMP_SYM = Symbol("#temp#")
# add the real backedges
function finalize_backedges(frame::InferenceState)
toplevel = !isa(frame.linfo.def, Method)
if !toplevel && (frame.cached || frame.parent !== nothing) && frame.max_valid == typemax(UInt)
if !toplevel && (frame.cached || frame.parent !== nothing) && max_world(frame.src) == typemax(UInt)
caller = frame.linfo
for edges in frame.stmt_edges
i = 1
Expand Down Expand Up @@ -173,6 +173,8 @@ function typeinf_code(linfo::MethodInstance, optimize::Bool, cached::Bool,
tree.slotflags = fill(0x00, Int(method.nargs))
tree.slottypes = nothing
tree.ssavaluetypes = 0
set_min_world!(tree, min_world(linfo))
set_max_world!(tree, max_world(linfo))
tree.inferred = true
tree.ssaflags = UInt8[]
tree.pure = true
Expand Down Expand Up @@ -423,11 +425,11 @@ function typeinf(frame::InferenceState)
typeinf_work(caller)
no_active_ips_in_callers = false
end
if caller.min_valid < frame.min_valid
caller.min_valid = frame.min_valid
if min_world(caller.src) < min_world(frame.src)
set_min_world!(caller.src, min_world(frame.src))
end
if caller.max_valid > frame.max_valid
caller.max_valid = frame.max_valid
if max_world(caller.src) > max_world(frame.src)
set_max_world!(caller.src, max_world(frame.src))
end
end
end
Expand All @@ -448,16 +450,16 @@ function typeinf(frame::InferenceState)
# complete the computation of the src optimizations
for caller in frame.callers_in_cycle
optimize(caller)
if frame.min_valid < caller.min_valid
frame.min_valid = caller.min_valid
if min_world(frame.src) < min_world(caller.src)
set_min_world!(frame.src, min_world(caller.src))
end
if frame.max_valid > caller.max_valid
frame.max_valid = caller.max_valid
if max_world(frame.src) > max_world(caller.src)
set_max_world!(frame.src, max_world(caller.src))
end
end
# update and store in the global cache
for caller in frame.callers_in_cycle
caller.min_valid = frame.min_valid
set_min_world!(caller.src, min_world(frame.src))
end
for caller in frame.callers_in_cycle
finish(caller)
Expand Down
16 changes: 13 additions & 3 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,19 @@ function occurs_more(@nospecialize(e), pred, n)
return 0
end

###########################
# MethodInstance/CodeInfo #
###########################
##################################
# Method/MethodInstance/CodeInfo #
##################################

min_world(m::Method) = reinterpret(UInt, m.min_world)
max_world(m::Method) = typemax(UInt)
min_world(m::MethodInstance) = reinterpret(UInt, m.min_world)
max_world(m::MethodInstance) = reinterpret(UInt, m.max_world)
min_world(c::CodeInfo) = reinterpret(UInt, c.min_world)
max_world(c::CodeInfo) = reinterpret(UInt, c.max_world)

set_min_world!(c::CodeInfo, w::UInt) = (sw = reinterpret(Int, w); c.min_world = sw; sw)
set_max_world!(c::CodeInfo, w::UInt) = (sw = reinterpret(Int, w); c.max_world = sw; sw)

function invoke_api(li::MethodInstance)
return ccall(:jl_invoke_api, Cint, (Any,), li)
Expand Down
2 changes: 1 addition & 1 deletion base/errorshow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ function show_method_candidates(io::IO, ex::MethodError, @nospecialize kwargs=()
end
end
end
if ex.world < min_world(method)
if ex.world < Core.Compiler.min_world(method)
print(iob, " (method too new to be called from this world context.)")
end
# TODO: indicate if it's in the wrong world
Expand Down
5 changes: 0 additions & 5 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1058,11 +1058,6 @@ has_bottom_parameter(t::Union) = has_bottom_parameter(t.a) & has_bottom_paramete
has_bottom_parameter(t::TypeVar) = t.ub == Bottom || has_bottom_parameter(t.ub)
has_bottom_parameter(::Any) = false

min_world(m::Method) = reinterpret(UInt, m.min_world)
max_world(m::Method) = typemax(UInt)
min_world(m::Core.MethodInstance) = reinterpret(UInt, m.min_world)
max_world(m::Core.MethodInstance) = reinterpret(UInt, m.max_world)

"""
propertynames(x, private=false)

Expand Down
10 changes: 8 additions & 2 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -2287,6 +2287,9 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code)
| (code->pure << 0);
write_uint8(s.s, flags);

write_int32(s.s, code->min_world);
write_int32(s.s, code->max_world);

size_t nsyms = jl_array_len(code->slotnames);
assert(nsyms >= m->nargs && nsyms < INT32_MAX); // required by generated functions
write_int32(s.s, nsyms);
Expand All @@ -2299,7 +2302,7 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code)
}

size_t nf = jl_datatype_nfields(jl_code_info_type);
for (i = 0; i < nf - 5; i++) {
for (i = 0; i < nf - 7; i++) {
int copy = (i != 2); // don't copy contents of method_for_inference_limit_heuristics field
jl_serialize_value_(&s, jl_get_nth_field((jl_value_t*)code, i), copy);
}
Expand Down Expand Up @@ -2350,6 +2353,9 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data)
code->propagate_inbounds = !!(flags & (1 << 1));
code->pure = !!(flags & (1 << 0));

code->min_world = read_int32(s.s);
code->max_world = read_int32(s.s);

size_t nslots = read_int32(&src);
jl_array_t *syms = jl_alloc_vec_any(nslots);
code->slotnames = syms;
Expand All @@ -2362,7 +2368,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_array_t *data)
}

size_t nf = jl_datatype_nfields(jl_code_info_type);
for (i = 0; i < nf - 5; i++) {
for (i = 0; i < nf - 7; i++) {
assert(jl_field_isptr(jl_code_info_type, i));
jl_value_t **fld = (jl_value_t**)((char*)jl_data_ptr(code) + jl_field_offset(jl_code_info_type, i));
*fld = jl_deserialize_value(&s, fld);
Expand Down
10 changes: 7 additions & 3 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2023,7 +2023,7 @@ void jl_init_types(void)
jl_code_info_type =
jl_new_datatype(jl_symbol("CodeInfo"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(13,
jl_perm_symsvec(15,
"code",
"codelocs",
"method_for_inference_limit_heuristics",
Expand All @@ -2033,11 +2033,13 @@ void jl_init_types(void)
"ssaflags",
"slotflags",
"slotnames",
"min_world",
"max_world",
"inferred",
"inlineable",
"propagate_inbounds",
"pure"),
jl_svec(13,
jl_svec(15,
jl_array_any_type,
jl_any_type,
jl_any_type,
Expand All @@ -2050,11 +2052,13 @@ void jl_init_types(void)
// If you change them, you'll have to adjust the
// serializer
jl_array_any_type,
jl_long_type,
jl_long_type,
jl_bool_type,
jl_bool_type,
jl_bool_type,
jl_bool_type),
0, 1, 13);
0, 1, 15);

jl_method_type =
jl_new_datatype(jl_symbol("Method"), core,
Expand Down
2 changes: 2 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ typedef struct _jl_code_info_t {
// 7 = has out-of-band info
jl_array_t *slotflags; // local var bit flags
jl_array_t *slotnames; // names of local variables
size_t min_world;
size_t max_world;
uint8_t inferred;
uint8_t inlineable;
uint8_t propagate_inbounds;
Expand Down
4 changes: 4 additions & 0 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
src->codelocs = jl_nothing;
src->linetable = jl_nothing;
src->ssaflags = NULL;
src->min_world = 0;
src->max_world = 0;
src->inferred = 0;
src->pure = 0;
src->inlineable = 0;
Expand Down Expand Up @@ -520,6 +522,8 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src)
jl_array_ptr_set(copy, i, st);
}
src = jl_copy_code_info(src);
src->min_world = m->min_world;
src->max_world = ~(size_t)0;
src->code = copy;
jl_gc_wb(src, copy);
if (gen_only)
Expand Down
4 changes: 2 additions & 2 deletions test/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1303,8 +1303,8 @@ let linfo = get_linfo(Base.convert, Tuple{Type{Int64}, Int32}),
@test opt.src.ssavaluetypes isa Vector{Any}
@test !opt.src.inferred
@test opt.mod === Base
@test opt.max_valid === typemax(UInt)
@test opt.min_valid === Core.Compiler.min_world(opt.linfo) > 2
@test Core.Compiler.max_world(opt.src) === typemax(UInt)
@test Core.Compiler.min_world(opt.src) === Core.Compiler.min_world(opt.linfo) > 2
@test opt.nargs == 3
end

Expand Down