Skip to content

Commit

Permalink
Merge pull request #22744 from JuliaLang/jb/isstaged
Browse files Browse the repository at this point in the history
RFC: slightly refactor representation of generated methods
  • Loading branch information
JeffBezanson authored Jul 11, 2017
2 parents 267cc44 + 9a41c33 commit d91583d
Show file tree
Hide file tree
Showing 13 changed files with 57 additions and 48 deletions.
10 changes: 5 additions & 5 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ function InferenceState(linfo::MethodInstance,
# prepare an InferenceState object for inferring lambda
# create copies of the CodeInfo definition, and any fields that type-inference might modify
m = linfo.def::Method
if m.isstaged
if isdefined(m, :generator)
try
# user code might throw errors – ignore them
src = get_staged(linfo)
Expand Down Expand Up @@ -1996,7 +1996,7 @@ function pure_eval_call(f::ANY, argtypes::ANY, atype::ANY, sv::InferenceState)
meth = meth[1]::SimpleVector
method = meth[3]::Method
# TODO: check pure on the inferred thunk
if method.isstaged || !method.pure
if isdefined(method, :generator) || !method.pure
return false
end

Expand Down Expand Up @@ -2790,7 +2790,7 @@ function code_for_method(method::Method, atypes::ANY, sparams::SimpleVector, wor
if world < min_world(method)
return nothing
end
if method.isstaged && !isleaftype(atypes)
if isdefined(method, :generator) && !isleaftype(atypes)
# don't call staged functions on abstract types.
# (see issues #8504, #10230)
# we can't guarantee that their type behavior is monotonic.
Expand Down Expand Up @@ -4208,7 +4208,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
end

# check whether call can be inlined to just a quoted constant value
if isa(f, widenconst(ft)) && !method.isstaged
if isa(f, widenconst(ft)) && !isdefined(method, :generator)
if f === return_type
if isconstType(e.typ)
return inline_as_constant(e.typ.parameters[1], argexprs, sv, invoke_data)
Expand Down Expand Up @@ -4250,7 +4250,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
# some gf have special tfunc, meaning they wouldn't have been inferred yet
# check the same conditions from abstract_call to detect this case
force_infer = false
if !method.isstaged
if !isdefined(method, :generator)
if method.module == _topmod(method.module) || (isdefined(Main, :Base) && method.module == Main.Base)
la = length(atypes)
if (la==3 && (method.name == :getindex || method.name == :next)) ||
Expand Down
4 changes: 2 additions & 2 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ function length(mt::MethodTable)
end
isempty(mt::MethodTable) = (mt.defs === nothing)

uncompressed_ast(m::Method) = uncompressed_ast(m, m.source)
uncompressed_ast(m::Method) = uncompressed_ast(m, isdefined(m,:source) ? m.source : m.generator.inferred)
uncompressed_ast(m::Method, s::CodeInfo) = s
uncompressed_ast(m::Method, s::Array{UInt8,1}) = ccall(:jl_uncompress_ast, Any, (Any, Any), m, s)::CodeInfo

Expand Down Expand Up @@ -799,7 +799,7 @@ code_native(::IO, ::ANY, ::Symbol) = error("illegal code_native call") # resolve

# give a decent error message if we try to instantiate a staged function on non-leaf types
function func_for_method_checked(m::Method, types::ANY)
if m.isstaged && !isleaftype(types)
if isdefined(m,:generator) && !isdefined(m,:source) && !isleaftype(types)
error("cannot call @generated function `", m, "` ",
"with abstract argument types: ", types)
end
Expand Down
27 changes: 18 additions & 9 deletions base/serialize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -398,8 +398,16 @@ function serialize(s::AbstractSerializer, meth::Method)
serialize(s, meth.ambig)
serialize(s, meth.nargs)
serialize(s, meth.isva)
serialize(s, meth.isstaged)
serialize(s, uncompressed_ast(meth, meth.source))
if isdefined(meth, :source)
serialize(s, uncompressed_ast(meth, meth.source))
else
serialize(s, nothing)
end
if isdefined(meth, :generator)
serialize(s, uncompressed_ast(meth, meth.generator.inferred))
else
serialize(s, nothing)
end
nothing
end

Expand Down Expand Up @@ -783,8 +791,8 @@ function deserialize(s::AbstractSerializer, ::Type{Method})
ambig = deserialize(s)::Union{Array{Any,1}, Void}
nargs = deserialize(s)::Int32
isva = deserialize(s)::Bool
isstaged = deserialize(s)::Bool
template = deserialize(s)::CodeInfo
template = deserialize(s)
generator = deserialize(s)
if makenew
meth.module = mod
meth.name = name
Expand All @@ -793,16 +801,17 @@ function deserialize(s::AbstractSerializer, ::Type{Method})
meth.sig = sig
meth.sparam_syms = sparam_syms
meth.ambig = ambig
meth.isstaged = isstaged
meth.nargs = nargs
meth.isva = isva
# TODO: compress template
meth.source = template
meth.pure = template.pure
if isstaged
if template !== nothing
meth.source = template
meth.pure = template.pure
end
if generator !== nothing
linfo = ccall(:jl_new_method_instance_uninit, Ref{Core.MethodInstance}, ())
linfo.specTypes = Tuple
linfo.inferred = template
linfo.inferred = generator
linfo.def = meth
meth.generator = linfo
end
Expand Down
2 changes: 1 addition & 1 deletion base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ end
function show(io::IO, l::Core.MethodInstance)
def = l.def
if isa(def, Method)
if def.isstaged && l === def.generator
if isdefined(def, :generator) && l === def.generator
print(io, "MethodInstance generator for ")
show(io, def)
else
Expand Down
20 changes: 13 additions & 7 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1301,7 +1301,7 @@ static jl_method_instance_t *jl_get_unspecialized(jl_method_instance_t *method)
{
// one unspecialized version of a function can be shared among all cached specializations
jl_method_t *def = method->def.method;
if (def->isstaged) {
if (def->source == NULL) {
return method;
}
if (def->unspecialized == NULL) {
Expand Down Expand Up @@ -1340,7 +1340,7 @@ jl_generic_fptr_t jl_generate_fptr(jl_method_instance_t *li, void *_F, size_t wo
}
jl_method_instance_t *unspec = NULL;
if (jl_is_method(li->def.method)) {
if (!li->def.method->isstaged && li->def.method->unspecialized) {
if (li->def.method->unspecialized) {
unspec = li->def.method->unspecialized;
}
if (!F || !jl_can_finalize_function(F)) {
Expand All @@ -1349,7 +1349,11 @@ jl_generic_fptr_t jl_generate_fptr(jl_method_instance_t *li, void *_F, size_t wo
// and return its fptr instead
if (!unspec)
unspec = jl_get_unspecialized(li); // get-or-create the unspecialized version to cache the result
jl_code_info_t *src = unspec->def.method->isstaged ? jl_code_for_staged(unspec) : (jl_code_info_t*)unspec->def.method->source;
jl_code_info_t *src = (jl_code_info_t*)unspec->def.method->source;
if (src == NULL) {
assert(unspec->def.method->generator);
src = jl_code_for_staged(unspec);
}
fptr.fptr = unspec->fptr;
fptr.jlcall_api = unspec->jlcall_api;
if (fptr.fptr && fptr.jlcall_api) {
Expand Down Expand Up @@ -1466,7 +1470,8 @@ void jl_extern_c(jl_function_t *f, jl_value_t *rt, jl_value_t *argt, char *name)
extern "C" JL_DLLEXPORT
void *jl_get_llvmf_defn(jl_method_instance_t *linfo, size_t world, bool getwrapper, bool optimize, const jl_cgparams_t params)
{
if (jl_is_method(linfo->def.method) && linfo->def.method->source == NULL) {
if (jl_is_method(linfo->def.method) && linfo->def.method->source == NULL &&
linfo->def.method->generator == NULL) {
// not a generic function
return NULL;
}
Expand All @@ -1476,7 +1481,7 @@ void *jl_get_llvmf_defn(jl_method_instance_t *linfo, size_t world, bool getwrapp
if (!src || (jl_value_t*)src == jl_nothing) {
src = jl_type_infer(&linfo, world, 0);
if (!src && jl_is_method(linfo->def.method))
src = linfo->def.method->isstaged ? jl_code_for_staged(linfo) : (jl_code_info_t*)linfo->def.method->source;
src = linfo->def.method->generator ? jl_code_for_staged(linfo) : (jl_code_info_t*)linfo->def.method->source;
}
if ((jl_value_t*)src == jl_nothing)
src = NULL;
Expand Down Expand Up @@ -1546,7 +1551,8 @@ void *jl_get_llvmf_defn(jl_method_instance_t *linfo, size_t world, bool getwrapp
extern "C" JL_DLLEXPORT
void *jl_get_llvmf_decl(jl_method_instance_t *linfo, size_t world, bool getwrapper, const jl_cgparams_t params)
{
if (jl_is_method(linfo->def.method) && linfo->def.method->source == NULL) {
if (jl_is_method(linfo->def.method) && linfo->def.method->source == NULL &&
linfo->def.method->generator == NULL) {
// not a generic function
return NULL;
}
Expand All @@ -1563,7 +1569,7 @@ void *jl_get_llvmf_decl(jl_method_instance_t *linfo, size_t world, bool getwrapp
jl_code_info_t *src = NULL;
src = jl_type_infer(&linfo, world, 0);
if (!src) {
src = linfo->def.method->isstaged ? jl_code_for_staged(linfo) : (jl_code_info_t*)linfo->def.method->source;
src = linfo->def.method->generator ? jl_code_for_staged(linfo) : (jl_code_info_t*)linfo->def.method->source;
}
decls = jl_compile_linfo(&linfo, src, world, &params);
linfo->functionObjectsDecls = decls;
Expand Down
2 changes: 0 additions & 2 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,6 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
union jl_typemap_t *tf = &m->specializations;
jl_serialize_value(s, tf->unknown);
jl_serialize_value(s, (jl_value_t*)m->name);
write_int8(s->s, m->isstaged);
jl_serialize_value(s, (jl_value_t*)m->file);
write_int32(s->s, m->line);
if (external_mt)
Expand Down Expand Up @@ -1396,7 +1395,6 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_
jl_gc_wb(m, m->specializations.unknown);
m->name = (jl_sym_t*)jl_deserialize_value(s, NULL);
jl_gc_wb(m, m->name);
m->isstaged = read_int8(s->s);
m->file = (jl_sym_t*)jl_deserialize_value(s, NULL);
m->line = read_int32(s->s);
m->min_world = jl_world_counter;
Expand Down
9 changes: 4 additions & 5 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,6 @@ static void jl_cacheable_sig(
int *const need_guard_entries,
int *const makesimplesig)
{
int8_t isstaged = definition->isstaged;
assert(jl_is_tuple_type(type));
size_t i, np = jl_nparams(type);
for (i = 0; i < np; i++) {
Expand All @@ -593,7 +592,7 @@ static void jl_cacheable_sig(
continue;
}

if (isstaged) {
if (definition->generator) {
// staged functions can't be optimized
continue;
}
Expand Down Expand Up @@ -697,7 +696,7 @@ JL_DLLEXPORT int jl_is_cacheable_sig(
// compute whether this type signature is a possible return value from jl_cacheable_sig
//return jl_cacheable_sig(type, NULL, definition->sig, definition, NULL, NULL);

if (definition->isstaged)
if (definition->generator)
// staged functions can't be optimized
// so assume the caller was intelligent about calling us
return 1;
Expand Down Expand Up @@ -847,7 +846,7 @@ static jl_method_instance_t *cache_method(jl_methtable_t *mt, union jl_typemap_t
// in general, here we want to find the biggest type that's not a
// supertype of any other method signatures. so far we are conservative
// and the types we find should be bigger.
if (!definition->isstaged && jl_nparams(type) > mt->max_args
if (definition->generator == NULL && jl_nparams(type) > mt->max_args
&& jl_va_tuple_kind((jl_datatype_t*)decl) == JL_VARARG_UNBOUND) {
size_t i, nspec = mt->max_args + 2;
jl_svec_t *limited = jl_alloc_svec(nspec);
Expand Down Expand Up @@ -1625,7 +1624,7 @@ jl_llvm_functions_t jl_compile_for_dispatch(jl_method_instance_t **pli, size_t w
jl_options.compile_enabled == JL_OPTIONS_COMPILE_MIN) {
// copy fptr from the template method definition
jl_method_t *def = li->def.method;
if (jl_is_method(def) && !def->isstaged && def->unspecialized) {
if (jl_is_method(def) && def->unspecialized) {
if (def->unspecialized->jlcall_api == 2) {
li->functionObjectsDecls.functionObject = NULL;
li->functionObjectsDecls.specFunctionObject = NULL;
Expand Down
9 changes: 5 additions & 4 deletions src/interpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -640,14 +640,15 @@ jl_value_t *jl_interpret_call(jl_method_instance_t *lam, jl_value_t **args, uint
return lam->inferred_const;
jl_code_info_t *src = (jl_code_info_t*)lam->inferred;
if (!src || (jl_value_t*)src == jl_nothing) {
if (lam->def.method->isstaged) {
if (lam->def.method->source) {
src = (jl_code_info_t*)lam->def.method->source;
}
else {
assert(lam->def.method->generator);
src = jl_code_for_staged(lam);
lam->inferred = (jl_value_t*)src;
jl_gc_wb(lam, src);
}
else {
src = (jl_code_info_t*)lam->def.method->source;
}
}
if (src && (jl_value_t*)src != jl_nothing) {
src = jl_uncompress_ast(lam->def.method, (jl_array_t*)src);
Expand Down
6 changes: 2 additions & 4 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -1920,7 +1920,7 @@ void jl_init_types(void)
jl_method_type =
jl_new_datatype(jl_symbol("Method"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(19,
jl_perm_symsvec(18,
"name",
"module",
"file",
Expand All @@ -1938,9 +1938,8 @@ void jl_init_types(void)
"nargs",
"called",
"isva",
"isstaged",
"pure"),
jl_svec(19,
jl_svec(18,
jl_sym_type,
jl_module_type,
jl_sym_type,
Expand All @@ -1958,7 +1957,6 @@ void jl_init_types(void)
jl_int32_type,
jl_int32_type,
jl_bool_type,
jl_bool_type,
jl_bool_type),
0, 1, 9);

Expand Down
3 changes: 1 addition & 2 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ typedef struct _jl_method_t {
jl_svec_t *sparam_syms; // symbols giving static parameter names
jl_value_t *source; // original code template (jl_code_info_t, but may be compressed), null for builtins
struct _jl_method_instance_t *unspecialized; // unspecialized executable method instance, or null
struct _jl_method_instance_t *generator; // executable code-generating function if isstaged
struct _jl_method_instance_t *generator; // executable code-generating function if available
jl_array_t *roots; // pointers in generated code (shared to reduce memory), or null

// cache of specializations of this method for invoke(), i.e.
Expand All @@ -256,7 +256,6 @@ typedef struct _jl_method_t {
int32_t nargs;
int32_t called; // bit flags: whether each of the first 8 arguments is called
uint8_t isva;
uint8_t isstaged;
uint8_t pure;

// hidden fields:
Expand Down
2 changes: 1 addition & 1 deletion src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ STATIC_INLINE jl_value_t *jl_compile_method_internal(jl_generic_fptr_t *fptr,
fptr->fptr = meth->unspecialized_ducttape;
fptr->jlcall_api = 1;
if (!fptr->fptr) {
if (jl_is_method(meth->def.method) && !meth->def.method->isstaged && meth->def.method->unspecialized) {
if (jl_is_method(meth->def.method) && meth->def.method->unspecialized) {
fptr->fptr = meth->def.method->unspecialized->fptr;
fptr->jlcall_api = meth->def.method->unspecialized->jlcall_api;
if (fptr->jlcall_api == 2) {
Expand Down
7 changes: 3 additions & 4 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,8 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo)
jl_value_t *linenum = NULL;
jl_svec_t *sparam_vals = env;
jl_method_instance_t *generator = linfo->def.method->generator;
assert(generator != NULL);
assert(linfo != generator);
assert(linfo->def.method->isstaged);
jl_code_info_t *func = NULL;
JL_GC_PUSH4(&ex, &linenum, &sparam_vals, &func);
jl_ptls_t ptls = jl_get_ptls_states();
Expand All @@ -289,7 +289,7 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo)

jl_array_t *argnames = jl_alloc_vec_any(linfo->def.method->nargs);
jl_array_ptr_set(ex->args, 0, argnames);
jl_fill_argnames((jl_array_t*)linfo->def.method->source, argnames);
jl_fill_argnames((jl_array_t*)generator->inferred, argnames);

// build the rest of the body to pass to expand
jl_expr_t *scopeblock = jl_exprn(jl_symbol("scope-block"), 1);
Expand Down Expand Up @@ -466,7 +466,6 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t *module)
m->line = 0;
m->called = 0xff;
m->invokes.unknown = NULL;
m->isstaged = 0;
m->isva = 0;
m->nargs = 0;
m->traced = 0;
Expand Down Expand Up @@ -499,7 +498,6 @@ static jl_method_t *jl_new_method(
m->sparam_syms = sparam_syms;
root = (jl_value_t*)m;
m->min_world = ++jl_world_counter;
m->isstaged = isstaged;
m->name = name;
m->sig = (jl_value_t*)sig;
m->isva = isva;
Expand All @@ -510,6 +508,7 @@ static jl_method_t *jl_new_method(
m->generator = jl_get_specialized(m, (jl_value_t*)jl_anytuple_type, jl_emptysvec);
jl_gc_wb(m, m->generator);
m->generator->inferred = (jl_value_t*)m->source;
m->source = NULL;
}

#ifdef RECORD_METHOD_ORDER
Expand Down
4 changes: 2 additions & 2 deletions src/precompile.c
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ static void _compile_all_deq(jl_array_t *found)
jl_printf(JL_STDERR, " %d / %d\r", found_i + 1, found_l);
jl_typemap_entry_t *ml = (jl_typemap_entry_t*)jl_array_ptr_ref(found, found_i);
jl_method_t *m = ml->func.method;
if (m->isstaged) // TODO: generic implementations of generated functions
if (m->source == NULL) // TODO: generic implementations of generated functions
continue;
linfo = m->unspecialized;
if (!linfo) {
Expand Down Expand Up @@ -274,7 +274,7 @@ static int compile_all_enq__(jl_typemap_entry_t *ml, void *env)
jl_array_t *found = (jl_array_t*)env;
// method definition -- compile template field
jl_method_t *m = ml->func.method;
if (!m->isstaged &&
if (m->source &&
(!m->unspecialized ||
(m->unspecialized->functionObjectsDecls.functionObject == NULL &&
m->unspecialized->jlcall_api != 2 &&
Expand Down

0 comments on commit d91583d

Please sign in to comment.