Skip to content

Commit

Permalink
Add roots field to CodeInstance
Browse files Browse the repository at this point in the history
This is designed to eliminate confusion in the serialization by ensuring
that offsets are relative to a "private" roots table. This may allow
more extensive caching of inference results, because it should eliminate
root-indexing conflicts between different instances of the same method.
  • Loading branch information
timholy committed Jul 26, 2019
1 parent 9daaed6 commit a3a71e3
Show file tree
Hide file tree
Showing 11 changed files with 89 additions and 68 deletions.
14 changes: 7 additions & 7 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ function analyze_method!(idx::Int, sig::Signature, @nospecialize(metharg), meths
return spec_lambda(atype_unlimited, sv, invoke_data)
end

isconst, src = find_inferred(mi, atypes, sv, stmttyp)
isconst, src, roots = find_inferred(mi, atypes, sv, stmttyp)
if isconst
add_backedge!(mi, sv)
return ConstantCase(src, method, Any[methsp...], metharg)
Expand All @@ -722,7 +722,7 @@ function analyze_method!(idx::Int, sig::Signature, @nospecialize(metharg), meths
add_backedge!(mi, sv)

if !isa(src, CodeInfo)
src = ccall(:jl_uncompress_ast, Any, (Any, Ptr{Cvoid}, Any), method, C_NULL, src::Vector{UInt8})::CodeInfo
src = ccall(:jl_uncompress_ast, Any, (Any, Ptr{Cvoid}, Any, Any), method, C_NULL, src::Vector{UInt8}, roots)::CodeInfo
end

@timeit "inline IR inflation" begin
Expand Down Expand Up @@ -1294,10 +1294,10 @@ function find_inferred(mi::MethodInstance, @nospecialize(atypes), sv::Optimizati
if isa(inf_result, InferenceResult)
let inferred_src = inf_result.src
if isa(inferred_src, CodeInfo)
return svec(false, inferred_src)
return svec(false, inferred_src, nothing)
end
if isa(inferred_src, Const) && is_inlineable_constant(inferred_src.val)
return svec(true, quoted(inferred_src.val),)
return svec(true, quoted(inferred_src.val), nothing)
end
end
end
Expand All @@ -1306,9 +1306,9 @@ function find_inferred(mi::MethodInstance, @nospecialize(atypes), sv::Optimizati
if linfo isa CodeInstance
if invoke_api(linfo) == 2
# in this case function can be inlined to a constant
return svec(true, quoted(linfo.rettype_const))
return svec(true, quoted(linfo.rettype_const), linfo.localroots)
end
return svec(false, linfo.inferred)
return svec(false, linfo.inferred, linfo.localroots)
end
return svec(false, nothing)
return svec(false, nothing, nothing)
end
19 changes: 13 additions & 6 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ end
# update the MethodInstance and notify the edges
function cache_result(result::InferenceResult, min_valid::UInt, max_valid::UInt)
def = result.linfo.def
toplevel = !isa(result.linfo.def, Method)
toplevel = !isa(def, Method)
if toplevel
min_valid = UInt(0)
max_valid = UInt(0)
Expand Down Expand Up @@ -129,12 +129,19 @@ function cache_result(result::InferenceResult, min_valid::UInt, max_valid::UInt)
end
end
end
if !isa(inferred_result, Union{CodeInfo, Vector{UInt8}})
inferred_result = nothing
if isa(inferred_result, CodeInfo)
ccall(:jl_set_method_inferred, Ref{CodeInstance}, (Any, Any, Any, Any, Any, Int32, UInt, UInt),
result.linfo, widenconst(result.result), rettype_const, inferred_result, C_NULL,
const_flags, min_valid, max_valid)
elseif isa(inferred_result, SimpleVector)
ccall(:jl_set_method_inferred, Ref{CodeInstance}, (Any, Any, Any, Any, Any, Int32, UInt, UInt),
result.linfo, widenconst(result.result), rettype_const, inferred_result[1], inferred_result[2],
const_flags, min_valid, max_valid)
else
ccall(:jl_set_method_inferred, Ref{CodeInstance}, (Any, Any, Any, Any, Any, Int32, UInt, UInt),
result.linfo, widenconst(result.result), rettype_const, nothing, C_NULL,
const_flags, min_valid, max_valid)
end
ccall(:jl_set_method_inferred, Ref{CodeInstance}, (Any, Any, Any, Any, Int32, UInt, UInt),
result.linfo, widenconst(result.result), rettype_const, inferred_result,
const_flags, min_valid, max_valid)
end
result.linfo.inInference = false
nothing
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ function retrieve_code_info(linfo::MethodInstance)
if c === nothing && isdefined(m, :source)
src = m.source
if isa(src, Array{UInt8,1})
c = ccall(:jl_uncompress_ast, Any, (Any, Ptr{Cvoid}, Any), m, C_NULL, src)
c = ccall(:jl_uncompress_ast, Any, (Any, Ptr{Cvoid}, Any, Any), m, C_NULL, src, m.roots)
else
c = copy(src::CodeInfo)
end
Expand Down
4 changes: 2 additions & 2 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -912,8 +912,8 @@ uncompressed_ast(m::Method) = isdefined(m, :source) ? _uncompressed_ast(m, m.sou
isdefined(m, :generator) ? error("Method is @generated; try `code_lowered` instead.") :
error("Code for this Method is not available.")
_uncompressed_ast(m::Method, s::CodeInfo) = copy(s)
_uncompressed_ast(m::Method, s::Array{UInt8,1}) = ccall(:jl_uncompress_ast, Any, (Any, Ptr{Cvoid}, Any), m, C_NULL, s)::CodeInfo
_uncompressed_ast(ci::Core.CodeInstance, s::Array{UInt8,1}) = ccall(:jl_uncompress_ast, Any, (Any, Any, Any), ci.def.def::Method, ci, s)::CodeInfo
_uncompressed_ast(m::Method, s::Array{UInt8,1}) = ccall(:jl_uncompress_ast, Any, (Any, Ptr{Cvoid}, Any, Any), m, C_NULL, s, m.roots)::CodeInfo
_uncompressed_ast(ci::Core.CodeInstance, s::Array{UInt8,1}) = ccall(:jl_uncompress_ast, Any, (Any, Any, Any, Any), ci.def.def::Method, ci, s, ci.localroots)::CodeInfo

function method_instances(@nospecialize(f), @nospecialize(t), world::UInt = typemax(UInt))
tt = signature_type(f, t)
Expand Down
16 changes: 10 additions & 6 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1125,7 +1125,7 @@ jl_code_instance_t *jl_compile_linfo(jl_method_instance_t *mi, jl_code_info_t *s
}
// found inferred code, prep it for codegen
if ((jl_value_t*)src != jl_nothing)
src = jl_uncompress_ast(mi->def.method, codeinst, (jl_array_t*)src);
src = jl_uncompress_ast(mi->def.method, codeinst, (jl_array_t*)src, (jl_array_t*)(codeinst->localroots));
if (!jl_is_code_info(src)) {
src = jl_type_infer(mi, world, 0);
if (!src)
Expand Down Expand Up @@ -1157,7 +1157,7 @@ jl_code_instance_t *jl_compile_linfo(jl_method_instance_t *mi, jl_code_info_t *s
}
}
assert((jl_value_t*)src != jl_nothing);
src = jl_uncompress_ast(mi->def.method, NULL, (jl_array_t*)src);
src = jl_uncompress_ast(mi->def.method, NULL, (jl_array_t*)src, mi->def.method->roots);
}
assert(jl_is_code_info(src));

Expand All @@ -1175,6 +1175,7 @@ jl_code_instance_t *jl_compile_linfo(jl_method_instance_t *mi, jl_code_info_t *s
uncached->functionObjectsDecls.functionObject = NULL;
uncached->functionObjectsDecls.specFunctionObject = NULL;
uncached->inferred = jl_nothing;
uncached->localroots = jl_alloc_vec_any(0);
if (uncached->invoke != jl_fptr_const_return)
uncached->invoke = NULL;
uncached->specptr.fptr = NULL;
Expand Down Expand Up @@ -1240,8 +1241,11 @@ jl_code_instance_t *jl_compile_linfo(jl_method_instance_t *mi, jl_code_info_t *s
jl_options.debug_level > 1) {
// update the stored code
if (codeinst->inferred != (jl_value_t*)src) {
if (jl_is_method(mi->def.method))
src = (jl_code_info_t*)jl_compress_ast(mi->def.method, src);
if (jl_is_method(mi->def.method)) {
jl_svec_t *srcroots = jl_compress_ast(mi->def.method, src);
src = (jl_code_info_t*) jl_svecref(srcroots, 0);
codeinst->localroots = (jl_array_t*) jl_svecref(srcroots, 1);
}
codeinst->inferred = (jl_value_t*)src;
jl_gc_wb(codeinst, src);
}
Expand Down Expand Up @@ -1504,7 +1508,7 @@ void *jl_get_llvmf_defn(jl_method_instance_t *mi, size_t world, bool getwrapper,
jl_code_instance_t *codeinst = (jl_code_instance_t*)ci;
src = (jl_code_info_t*)codeinst->inferred;
if ((jl_value_t*)src != jl_nothing && !jl_is_code_info(src) && jl_is_method(mi->def.method))
src = jl_uncompress_ast(mi->def.method, codeinst, (jl_array_t*)src);
src = jl_uncompress_ast(mi->def.method, codeinst, (jl_array_t*)src, (jl_array_t*)(codeinst->localroots));
jlrettype = codeinst->rettype;
}
if (!src || (jl_value_t*)src == jl_nothing) {
Expand All @@ -1514,7 +1518,7 @@ void *jl_get_llvmf_defn(jl_method_instance_t *mi, size_t world, bool getwrapper,
if (!src && jl_is_method(mi->def.method)) {
src = mi->def.method->generator ? jl_code_for_staged(mi) : (jl_code_info_t*)mi->def.method->source;
if (src && !jl_is_code_info(src) && jl_is_method(mi->def.method))
src = jl_uncompress_ast(mi->def.method, NULL, (jl_array_t*)src);
src = jl_uncompress_ast(mi->def.method, NULL, (jl_array_t*)src, mi->def.method->roots);
}
}
if (src == NULL || (jl_value_t*)src == jl_nothing || !jl_is_code_info(src))
Expand Down
50 changes: 26 additions & 24 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ typedef struct {
ios_t *s;
DUMP_MODES mode;
// method we're compressing for in MODE_IR
jl_method_t *method;
jl_module_t *module;
jl_array_t *roots;
jl_ptls_t ptls;
jl_array_t *loaded_modules_array;
} jl_serializer_state;
Expand Down Expand Up @@ -484,7 +485,7 @@ static int is_ast_node(jl_value_t *v)

static int literal_val_id(jl_serializer_state *s, jl_value_t *v) JL_GC_DISABLED
{
jl_array_t *rs = s->method->roots;
jl_array_t *rs = s->roots;
int i, l = jl_array_len(rs);
if (jl_is_symbol(v) || jl_is_concrete_type(v)) {
for (i = 0; i < l; i++) {
Expand Down Expand Up @@ -535,7 +536,7 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
}

if (s->mode == MODE_IR) {
if (v == (jl_value_t*)s->method->module) {
if (v == (jl_value_t*)s->module) {
write_uint8(s->s, TAG_NEARBYMODULE);
return;
}
Expand Down Expand Up @@ -640,7 +641,7 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
ios_write(s->s, jl_symbol_name((jl_sym_t*)v), l);
}
else if (jl_is_globalref(v)) {
if (s->mode == MODE_IR && jl_globalref_mod(v) == s->method->module) {
if (s->mode == MODE_IR && jl_globalref_mod(v) == s->module) {
write_uint8(s->s, TAG_NEARBYGLOBAL);
jl_serialize_value(s, jl_globalref_name(v));
}
Expand Down Expand Up @@ -1387,7 +1388,7 @@ static jl_value_t *jl_deserialize_datatype(jl_serializer_state *s, int pos, jl_v
assert(0 && "corrupt deserialization state");
abort();
}
assert(s->method == NULL && s->mode != MODE_IR && "no new data-types expected during MODE_IR");
assert(s->roots == NULL && s->mode != MODE_IR && "no new data-types expected during MODE_IR");
assert(pos == backref_list.len - 1 && "nothing should have been deserialized since assigning pos");
backref_list.items[pos] = dt;
dt->size = size;
Expand Down Expand Up @@ -1763,6 +1764,8 @@ static jl_value_t *jl_deserialize_value_code_instance(jl_serializer_state *s, jl
jl_gc_wb(codeinst, codeinst->def);
codeinst->inferred = jl_deserialize_value(s, &codeinst->inferred);
jl_gc_wb(codeinst, codeinst->inferred);
codeinst->localroots = (jl_array_t*) jl_deserialize_value(s, (jl_value_t**) &(codeinst->localroots));
jl_gc_wb(codeinst, codeinst->localroots);
codeinst->rettype_const = jl_deserialize_value(s, &codeinst->rettype_const);
if (codeinst->rettype_const)
jl_gc_wb(codeinst, codeinst->rettype_const);
Expand Down Expand Up @@ -2023,7 +2026,7 @@ static jl_value_t *jl_deserialize_value(jl_serializer_state *s, jl_value_t **loc
tag = read_uint8(s->s);
return deser_tag[tag];
case TAG_BACKREF: JL_FALLTHROUGH; case TAG_SHORT_BACKREF:
assert(s->method == NULL && s->mode != MODE_IR);
assert(s->roots == NULL && s->mode != MODE_IR);
uintptr_t offs = (tag == TAG_BACKREF) ? read_int32(s->s) : read_uint16(s->s);
int isflagref = 0;
isflagref = !!(offs & 1);
Expand All @@ -2042,9 +2045,9 @@ static jl_value_t *jl_deserialize_value(jl_serializer_state *s, jl_value_t **loc
}
return (jl_value_t*)bp;
case TAG_METHODROOT:
return jl_array_ptr_ref(s->method->roots, read_uint8(s->s));
return jl_array_ptr_ref(s->roots, read_uint8(s->s));
case TAG_LONG_METHODROOT:
return jl_array_ptr_ref(s->method->roots, read_uint16(s->s));
return jl_array_ptr_ref(s->roots, read_uint16(s->s));
case TAG_SVEC: JL_FALLTHROUGH; case TAG_LONG_SVEC:
return jl_deserialize_value_svec(s, tag);
case TAG_COMMONSYM:
Expand Down Expand Up @@ -2152,12 +2155,12 @@ static jl_value_t *jl_deserialize_value(jl_serializer_state *s, jl_value_t **loc
case TAG_UINT8:
return jl_box_uint8(read_uint8(s->s));
case TAG_NEARBYGLOBAL:
assert(s->method != NULL);
assert(s->module != NULL);
v = jl_deserialize_value(s, NULL);
return jl_module_globalref(s->method->module, (jl_sym_t*)v);
return jl_module_globalref(s->module, (jl_sym_t*)v);
case TAG_NEARBYMODULE:
assert(s->method != NULL);
return (jl_value_t*)s->method->module;
assert(s->module != NULL);
return (jl_value_t*)s->module;
case TAG_GLOBALREF:
return jl_deserialize_value_globalref(s);
case TAG_SINGLETON:
Expand Down Expand Up @@ -2479,7 +2482,7 @@ JL_DLLEXPORT void jl_init_restored_modules(jl_array_t *init_order)

// --- entry points ---

JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code)
JL_DLLEXPORT jl_svec_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code)
{
JL_TIMING(AST_COMPRESS);
JL_LOCK(&m->writelock); // protect the roots array (Might GC)
Expand All @@ -2490,13 +2493,11 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code)
int en = jl_gc_enable(0); // Might GC
size_t i;

if (m->roots == NULL) {
m->roots = jl_alloc_vec_any(0);
jl_gc_wb(m, m->roots);
}
jl_array_t *localroots = jl_alloc_vec_any(0);
jl_serializer_state s = {
&dest, MODE_IR,
m,
m->module,
localroots,
jl_get_ptls_states(),
NULL
};
Expand Down Expand Up @@ -2552,26 +2553,24 @@ JL_DLLEXPORT jl_array_t *jl_compress_ast(jl_method_t *m, jl_code_info_t *code)
}

ios_flush(s.s);
jl_array_t *v = jl_take_buffer(&dest);
jl_svec_t *v = jl_svec2(jl_take_buffer(&dest), localroots);
ios_close(s.s);
if (jl_array_len(m->roots) == 0) {
m->roots = NULL;
}
JL_GC_PUSH1(&v);
jl_gc_enable(en);
JL_UNLOCK(&m->writelock); // Might GC
JL_GC_POP();
return v;
}

JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_code_instance_t *metadata, jl_array_t *data)
JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_code_instance_t *metadata, jl_array_t *data, jl_array_t *localroots)
{
if (jl_is_code_info(data))
return (jl_code_info_t*)data;
JL_TIMING(AST_UNCOMPRESS);
JL_LOCK(&m->writelock); // protect the roots array (Might GC)
assert(jl_is_method(m));
assert(jl_typeis(data, jl_array_uint8_type));
assert(localroots);
size_t i;
ios_t src;
ios_mem(&src, 0);
Expand All @@ -2580,7 +2579,8 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ast(jl_method_t *m, jl_code_instance_
int en = jl_gc_enable(0); // Might GC
jl_serializer_state s = {
&src, MODE_IR,
m,
m->module,
localroots,
jl_get_ptls_states(),
NULL
};
Expand Down Expand Up @@ -2817,6 +2817,7 @@ JL_DLLEXPORT int jl_save_incremental(const char *fname, jl_array_t *worklist)
jl_serializer_state s = {
&f, MODE_MODULE,
NULL,
NULL,
jl_get_ptls_states(),
mod_array
};
Expand Down Expand Up @@ -3198,6 +3199,7 @@ static jl_value_t *_jl_restore_incremental(ios_t *f, jl_array_t *mod_array)
jl_serializer_state s = {
f, MODE_MODULE,
NULL,
NULL,
ptls,
mod_array
};
Expand Down
13 changes: 8 additions & 5 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t*);
JL_DLLEXPORT jl_code_instance_t* jl_set_method_inferred(
jl_method_instance_t *mi, jl_value_t *rettype,
jl_value_t *inferred_const, jl_value_t *inferred,
jl_array_t *localroots,
int32_t const_flags, size_t min_world, size_t max_world);

jl_datatype_t *jl_mk_builtin_func(jl_datatype_t *dt, const char *name, jl_fptr_args_t fptr) JL_GC_DISABLED
Expand All @@ -155,7 +156,7 @@ jl_datatype_t *jl_mk_builtin_func(jl_datatype_t *dt, const char *name, jl_fptr_a
m->unspecialized = mi;
jl_gc_wb(m, mi);

jl_code_instance_t *codeinst = jl_set_method_inferred(mi, (jl_value_t*)jl_any_type, jl_nothing, jl_nothing,
jl_code_instance_t *codeinst = jl_set_method_inferred(mi, (jl_value_t*)jl_any_type, jl_nothing, jl_nothing, NULL,
0, 1, ~(size_t)0);
codeinst->specptr.fptr1 = fptr;
codeinst->invoke = jl_fptr_args;
Expand Down Expand Up @@ -266,13 +267,14 @@ JL_DLLEXPORT jl_code_instance_t *jl_get_method_inferred(
codeinst = codeinst->next;
}
return jl_set_method_inferred(
mi, rettype, NULL, NULL,
mi, rettype, NULL, NULL, NULL,
0, min_world, max_world);
}

JL_DLLEXPORT jl_code_instance_t *jl_set_method_inferred(
jl_method_instance_t *mi, jl_value_t *rettype,
jl_value_t *inferred_const, jl_value_t *inferred,
jl_array_t *localroots,
int32_t const_flags, size_t min_world, size_t max_world
/*, jl_array_t *edges, int absolute_max*/)
{
Expand All @@ -288,6 +290,7 @@ JL_DLLEXPORT jl_code_instance_t *jl_set_method_inferred(
codeinst->functionObjectsDecls.specFunctionObject = NULL;
codeinst->rettype = rettype;
codeinst->inferred = inferred;
codeinst->localroots = localroots;
//codeinst->edges = NULL;
if ((const_flags & 2) == 0)
inferred_const = NULL;
Expand Down Expand Up @@ -1738,7 +1741,7 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t
if (jl_is_method(def) && def->unspecialized) {
jl_code_instance_t *unspec = def->unspecialized->cache;
if (unspec && unspec->invoke != NULL) {
jl_code_instance_t *codeinst = jl_set_method_inferred(mi, (jl_value_t*)jl_any_type, NULL, NULL,
jl_code_instance_t *codeinst = jl_set_method_inferred(mi, (jl_value_t*)jl_any_type, NULL, NULL, NULL,
0, 1, ~(size_t)0);
codeinst->specptr = unspec->specptr;
codeinst->rettype_const = unspec->rettype_const;
Expand All @@ -1748,7 +1751,7 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t
}
jl_code_info_t *src = jl_code_for_interpreter(mi);
if (!jl_code_requires_compiler(src)) {
jl_code_instance_t *codeinst = jl_set_method_inferred(mi, (jl_value_t*)jl_any_type, NULL, NULL,
jl_code_instance_t *codeinst = jl_set_method_inferred(mi, (jl_value_t*)jl_any_type, NULL, NULL, NULL,
0, 1, ~(size_t)0);
jl_atomic_store_release(&codeinst->invoke, jl_fptr_interpret_call);
return codeinst;
Expand Down Expand Up @@ -1786,7 +1789,7 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t
ucache->invoke != jl_fptr_interpret_call) {
return ucache;
}
jl_code_instance_t *codeinst = jl_set_method_inferred(mi, (jl_value_t*)jl_any_type, NULL, NULL,
jl_code_instance_t *codeinst = jl_set_method_inferred(mi, (jl_value_t*)jl_any_type, NULL, NULL, NULL,
0, 1, ~(size_t)0);
codeinst->specptr = ucache->specptr;
codeinst->rettype_const = ucache->rettype_const;
Expand Down
Loading

0 comments on commit a3a71e3

Please sign in to comment.