Skip to content

Commit

Permalink
clean up the data flow around typeinf
Browse files Browse the repository at this point in the history
store LambdaInfo objects in the tfunc cache, instead of just ASTs
remove some arguments to jl_type_infer and typeinf
ensure every specialization has a new LambdaInfo; before we sometimes
  erroneously mutated the original LambdaInfo for a definition
  • Loading branch information
JeffBezanson committed Mar 2, 2016
1 parent 4f25c15 commit 2119e87
Show file tree
Hide file tree
Showing 10 changed files with 89 additions and 91 deletions.
96 changes: 61 additions & 35 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ function abstract_call_gf_by_type(f::ANY, argtype::ANY, e)
end
end
#print(m,"\n")
(_tree,rt) = typeinf(linfo, sig, m[2], linfo)
(_tree,rt) = typeinf(linfo, sig, m[2])
rettype = tmerge(rettype, rt)
if is(rettype,Any)
break
Expand Down Expand Up @@ -767,7 +767,7 @@ function invoke_tfunc(f::ANY, types::ANY, argtype::ANY)
if linfo === NF
return Any
end
return typeinf(linfo::LambdaInfo, ti, env, linfo)[2]
return typeinf(linfo::LambdaInfo, ti, env)[2]
end

# `types` is an array of inferred types for expressions in `args`.
Expand Down Expand Up @@ -888,7 +888,7 @@ function pure_eval_call(f::ANY, fargs, argtypes::ANY, sv, e)
return false
end
if !linfo.pure
typeinf(linfo, meth[1], meth[2], linfo)
typeinf(linfo, meth[1], meth[2])
if !linfo.pure
return false
end
Expand Down Expand Up @@ -1341,26 +1341,52 @@ end

is_rest_arg(arg::ANY) = (ccall(:jl_is_rest_arg,Int32,(Any,), arg) != 0)

function typeinf_ext(linfo, atypes::ANY, def)
function typeinf_ext(linfo, atypes::ANY)
global inference_stack
last = inference_stack
inference_stack = EmptyCallStack()
result = typeinf(linfo, atypes, svec(), def, true, true)
(newlinfo,ty) = typeinf(linfo, atypes, svec(), true, true)
inference_stack = last
return result
linfo.inferred = newlinfo !== linfo.def
if linfo.inferred && newlinfo !== linfo
linfo.rettype = ty
# if type inference bails out it returns def.ast
linfo.ast = newlinfo.ast
linfo.nslots = newlinfo.nslots
linfo.ngensym = newlinfo.ngensym
linfo.pure = newlinfo.pure
end
nothing
end

typeinf(linfo,atypes::ANY,sparams::ANY) = typeinf(linfo,atypes,sparams,linfo,true,false)
typeinf(linfo,atypes::ANY,sparams::ANY,def) = typeinf(linfo,atypes,sparams,def,true,false)
# copy a LambdaInfo just enough to make it not share data with li.def
function unshare_linfo(li::LambdaInfo, inplace)
if !inplace && li === li.def
li = ccall(:jl_copy_lambda_info, Any, (Any,), li)::LambdaInfo
end
if !isa(li.ast, Expr)
li.ast = ccall(:jl_uncompress_ast, Any, (Any,Any), li, li.ast)
elseif li.ast === li.def.ast
li.ast = astcopy(li.ast)
end
return li
end

function compress!(li::LambdaInfo)
if isa(li.ast, Expr)
li.ast = ccall(:jl_compress_ast, Any, (Any,Any), li.def, li.ast)
end
li
end

CYCLE_ID = 1

#trace_inf = false
#enable_trace_inf(on) = (global trace_inf=on)

# def is the original unspecialized version of a method. we aggregate all
# linfo.def is the original unspecialized version of a method. we aggregate all
# saved type inference data there.
function typeinf(linfo::LambdaInfo, atypes::ANY, sparams::SimpleVector, def, cop, needtree)
function typeinf(linfo::LambdaInfo, atypes::ANY, sparams::SimpleVector, needtree=false, inplace=false)
if linfo.module === Core && isempty(sparams) && isempty(linfo.sparam_vals)
atypes = Tuple
end
Expand All @@ -1370,6 +1396,7 @@ function typeinf(linfo::LambdaInfo, atypes::ANY, sparams::SimpleVector, def, cop
curtype = Bottom
redo = false
# check cached t-functions
def = linfo.def
tf = def.tfunc
if !is(tf,nothing)
tfarr = tf::Array{Any,1}
Expand All @@ -1390,19 +1417,19 @@ function typeinf(linfo::LambdaInfo, atypes::ANY, sparams::SimpleVector, def, cop
return (nothing, code)
end
else
return code # else code is a tuple (ast, type)
return (code, code.rettype) # else code is a LambdaInfo
end
end
end
end
# TODO: typeinf currently gets stuck without this
if linfo.name === :abstract_interpret || linfo.name === :alloc_elim_pass || linfo.name === :abstract_call_gf
return (linfo.ast, Any)
return (linfo, Any)
end

(fulltree, result, rec) = typeinf_uncached(linfo, atypes, sparams, def, curtype, cop, true)
if fulltree === ()
return (fulltree, result::Type)
(newcode, result, rec) = typeinf_uncached(linfo, atypes, sparams, curtype, true, inplace)
if newcode === nothing
return (newcode, result::Type)
end

if !redo
Expand All @@ -1424,24 +1451,25 @@ function typeinf(linfo::LambdaInfo, atypes::ANY, sparams::SimpleVector, def, cop
tfarr[idx] = atypes
# in the "rec" state this tree will not be used again, so store
# just the return type in place of it.
tfarr[idx+1] = rec ? result : (fulltree,result)
tfarr[idx+1] = rec ? result : newcode
tfarr[idx+2] = rec
else
def.tfunc[tfunc_idx] = rec ? result : (fulltree,result)
def.tfunc[tfunc_idx] = rec ? result : newcode
def.tfunc[tfunc_idx+1] = rec
end

return (fulltree, result::Type)
return (newcode, result::Type)
end

typeinf_uncached(linfo, atypes::ANY, sparams::ANY; optimize=true) =
typeinf_uncached(linfo, atypes, sparams, linfo, Bottom, true, optimize)
typeinf_uncached(linfo, atypes, sparams, Bottom, optimize, false)

# t[n:end]
tupletype_tail(t::ANY, n) = Tuple{t.parameters[n:end]...}

# compute an inferred (optionally optimized) AST without global effects (i.e. updating the cache)
function typeinf_uncached(linfo::LambdaInfo, atypes::ANY, sparams::SimpleVector, def, curtype, cop, optimize)
function typeinf_uncached(linfo::LambdaInfo, atypes::ANY, sparams::SimpleVector, curtype, optimize, inplace)
def = linfo.def
ast0 = def.ast
#if dbg
# print("typeinf ", linfo.name, " ", object_id(ast0), "\n")
Expand Down Expand Up @@ -1510,7 +1538,7 @@ function typeinf_uncached(linfo::LambdaInfo, atypes::ANY, sparams::SimpleVector,
end
CYCLE_ID += 1
#print("*==> ", f.result,"\n")
return ((),f.result,true)
return (nothing,f.result,true)
end
f = f.prev
end
Expand All @@ -1521,11 +1549,8 @@ function typeinf_uncached(linfo::LambdaInfo, atypes::ANY, sparams::SimpleVector,

#if dbg print("typeinf ", linfo.name, " ", atypes, "\n") end

if cop
ast = ccall(:jl_prepare_ast, Any, (Any,), linfo)::Expr
else
ast = linfo.ast
end
linfo = unshare_linfo(linfo, inplace)
ast = linfo.ast

sv = VarInfo(linfo, ast)

Expand Down Expand Up @@ -1599,7 +1624,7 @@ function typeinf_uncached(linfo::LambdaInfo, atypes::ANY, sparams::SimpleVector,
s[1][i] = VarState(lastatype, false)
end
elseif la != 0
return ((), Bottom, false) # wrong number of arguments
return (linfo, Bottom, false) # wrong number of arguments
end

gensym_uses = find_gensym_uses(body)
Expand Down Expand Up @@ -1747,7 +1772,7 @@ function typeinf_uncached(linfo::LambdaInfo, atypes::ANY, sparams::SimpleVector,
# for an example see test/libgit2.jl on 0.5-pre master
# around e.g. commit c072d1ce73345e153e4fddf656cda544013b1219
inference_stack = (inference_stack::CallStack).prev
return (ast0, Any, false)
return (def, Any, false)
end
end
handler_at[l] = cur_hand
Expand Down Expand Up @@ -1805,15 +1830,17 @@ function typeinf_uncached(linfo::LambdaInfo, atypes::ANY, sparams::SimpleVector,
linfo.nslots = length(fulltree.args[2][1])
linfo.ngensym = length(sv.gensym_types)
end
linfo.inferred = true
body = Expr(:block)
body.args = fulltree.args[3].args::Array{Any,1}
linfo.pure = popmeta!(body, :pure)[1]
fulltree = ccall(:jl_compress_ast, Any, (Any,Any), def, fulltree)
linfo.ast = fulltree
compress!(linfo)
linfo.rettype = frame.result
linfo.inferred = true
end

inference_stack = (inference_stack::CallStack).prev
return (fulltree, frame.result, rec)
return (linfo, frame.result, rec)
end

function record_var_type(s::Slot, t::ANY, decls)
Expand Down Expand Up @@ -2243,15 +2270,14 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::VarInfo,
methargs = metharg.parameters
nm = length(methargs)

(ast, ty) = typeinf(linfo, metharg, methsp, linfo, true, true)
if is(ast,())
(linfo, ty) = typeinf(linfo, metharg, methsp, true)
if is(linfo,nothing)
return NF
end
ast = linfo.ast

if !isa(ast,Expr)
ast = ccall(:jl_uncompress_ast, Any, (Any,Any), linfo, ast)
else
ast = astcopy(ast)
end
ast = ast::Expr
vinflist = ast.args[2][1]::Array{Any,1}
Expand Down
13 changes: 4 additions & 9 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,16 +295,11 @@ function code_typed(f::ANY, types::ANY=Tuple; optimize=true)
for x in _methods(f,types,-1)
linfo = func_for_method_checked(x, types)
if optimize
(tree, ty) = Core.Inference.typeinf(linfo, x[1], x[2], linfo,
true, true)
(li, ty) = Core.Inference.typeinf(linfo, x[1], x[2], true)
else
(tree, ty) = Core.Inference.typeinf_uncached(linfo, x[1], x[2],
optimize=false)
(li, ty) = Core.Inference.typeinf_uncached(linfo, x[1], x[2], optimize=false)
end
if !isa(tree, Expr)
tree = ccall(:jl_uncompress_ast, Any, (Any,Any), linfo, tree)
end
push!(asts, tree)
push!(asts, li)
end
asts
end
Expand All @@ -314,7 +309,7 @@ function return_types(f::ANY, types::ANY=Tuple)
rt = []
for x in _methods(f,types,-1)
linfo = func_for_method_checked(x,types)
(tree, ty) = Core.Inference.typeinf(linfo, x[1], x[2])
(_li, ty) = Core.Inference.typeinf(linfo, x[1], x[2])
push!(rt, ty)
end
rt
Expand Down
2 changes: 1 addition & 1 deletion src/alloc.c
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ jl_lambda_info_t *jl_new_lambda_info(jl_value_t *ast, jl_svec_t *tvars, jl_svec_
return li;
}

jl_lambda_info_t *jl_copy_lambda_info(jl_lambda_info_t *linfo)
JL_DLLEXPORT jl_lambda_info_t *jl_copy_lambda_info(jl_lambda_info_t *linfo)
{
jl_lambda_info_t *new_linfo =
jl_new_lambda_info(linfo->ast, linfo->sparam_syms, linfo->sparam_vals, linfo->module);
Expand Down
19 changes: 0 additions & 19 deletions src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -1084,25 +1084,6 @@ JL_DLLEXPORT jl_value_t *jl_copy_ast(jl_value_t *expr)
return expr;
}

// given a new lambda_info with static parameter values, make a copy
// of the tree with declared types evaluated and static parameters passed
// on to all enclosed functions.
// this tree can then be further mutated by optimization passes.
JL_DLLEXPORT jl_value_t *jl_prepare_ast(jl_lambda_info_t *li)
{
jl_value_t *ast = li->ast;
if (ast == NULL) return NULL;
JL_GC_PUSH1(&ast);
if (!jl_is_expr(ast)) {
ast = jl_uncompress_ast(li, ast);
}
else {
ast = jl_copy_ast(ast);
}
JL_GC_POP();
return ast;
}

JL_DLLEXPORT int jl_is_operator(char *sym)
{
jl_ast_context_t *ctx = jl_ast_ctx_enter();
Expand Down
6 changes: 3 additions & 3 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -814,10 +814,10 @@ static void jl_serialize_value_(ios_t *s, jl_value_t *v)
for(i=0; i < l; i += 3) {
if (!jl_is_leaf_type(jl_cellref(tf,i))) {
jl_value_t *ret = jl_cellref(tf,i+1);
if (jl_is_tuple(ret)) {
jl_value_t *ast = jl_fieldref(ret, 0);
if (jl_is_lambda_info(ret)) {
jl_value_t *ast = ((jl_lambda_info_t*)ret)->ast;
if (jl_is_array(ast) && jl_array_len(ast) > 500)
jl_cellset(tf, i+1, jl_fieldref(ret,1));
jl_cellset(tf, i+1, ((jl_lambda_info_t*)ret)->rettype);
}
}
}
Expand Down
29 changes: 9 additions & 20 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ jl_lambda_info_t *jl_get_unspecialized(jl_lambda_info_t *method)
method->unspecialized = jl_add_static_parameters(def, method->sparam_vals, method->specTypes);
jl_gc_wb(method, method->unspecialized);
method->unspecialized->unspecialized = method->unspecialized;
def = method;
return method->unspecialized;
}
if (def->unspecialized == NULL) {
def->unspecialized = jl_add_static_parameters(def, jl_emptysvec, jl_anytuple_type);
Expand Down Expand Up @@ -389,13 +389,9 @@ jl_lambda_info_t *jl_method_cache_insert(jl_methtable_t *mt, jl_tupletype_t *typ
return jl_method_list_insert(pml, type, method, jl_emptysvec, 0, 0, cache_array ? cache_array : (jl_value_t*)mt)->func;
}

/*
run type inference on lambda "li" in-place, for given argument types.
"def" is the original method definition of which this is an instance;
can be equal to "li" if not applicable.
*/
// run type inference on lambda "li" in-place, for given argument types.
int jl_in_inference = 0;
void jl_type_infer(jl_lambda_info_t *li, jl_lambda_info_t *def)
void jl_type_infer(jl_lambda_info_t *li)
{
JL_LOCK(codegen); // Might GC
int last_ii = jl_in_inference;
Expand All @@ -406,25 +402,17 @@ void jl_type_infer(jl_lambda_info_t *li, jl_lambda_info_t *def)
// called
assert(li->inInference == 0);
li->inInference = 1;
jl_value_t *fargs[4];
jl_value_t *fargs[3];
fargs[0] = (jl_value_t*)jl_typeinf_func;
fargs[1] = (jl_value_t*)li;
fargs[2] = (jl_value_t*)li->specTypes;
fargs[3] = (jl_value_t*)def;
#ifdef TRACE_INFERENCE
jl_printf(JL_STDERR,"inference on ");
jl_static_show_func_sig(JL_STDERR, (jl_value_t*)li->specTypes);
jl_printf(JL_STDERR, "\n");
#endif
#ifdef ENABLE_INFERENCE
jl_value_t *newast = jl_apply(fargs, 4);
jl_value_t *defast = def->ast;
li->ast = jl_fieldref(newast, 0);
jl_gc_wb(li, li->ast);
li->rettype = jl_fieldref(newast, 1);
jl_gc_wb(li, li->rettype);
// if type inference bails out it returns def->ast
li->inferred = li->ast != defast;
(void)jl_apply(fargs, 3);
#endif
li->inInference = 0;
}
Expand Down Expand Up @@ -828,7 +816,7 @@ static jl_lambda_info_t *cache_method(jl_methtable_t *mt, jl_tupletype_t *type,
jl_gc_wb(method, method->specializations);
if (jl_options.compile_enabled != JL_OPTIONS_COMPILE_OFF) // don't bother with typeinf if compile is off
if (jl_symbol_name(newmeth->name)[0] != '@') // don't bother with typeinf on macros
jl_type_infer(newmeth, method);
jl_type_infer(newmeth);
}
JL_GC_POP();
JL_UNLOCK(codegen);
Expand Down Expand Up @@ -891,7 +879,7 @@ JL_DLLEXPORT jl_lambda_info_t *jl_instantiate_staged(jl_lambda_info_t *generator
assert(jl_svec_len(generator->sparam_syms) == jl_svec_len(sparam_vals));
assert(generator->unspecialized == NULL && generator->specTypes == jl_anytuple_type);
//if (!generated->inferred)
// jl_type_infer(generator, generator); // this doesn't help all that much
// jl_type_infer(generator); // this doesn't help all that much

ex = jl_exprn(lambda_sym, 2);

Expand Down Expand Up @@ -930,6 +918,7 @@ JL_DLLEXPORT jl_lambda_info_t *jl_instantiate_staged(jl_lambda_info_t *generator
// need to eval macros in the right module, but not give a warning for the `eval` call unless that results in a call to `eval`
jl_lambda_info_t *func = (jl_lambda_info_t*)jl_toplevel_eval_in_warn(generator->module, (jl_value_t*)ex, 1);
func->name = generator->name;
func->def = generator;
JL_GC_POP();
return func;
}
Expand Down Expand Up @@ -1662,7 +1651,7 @@ static void _compile_all_deq(jl_array_t *found)
linfo = jl_get_unspecialized(linfo);
if (!linfo->inferred) {
// force this function to be recompiled
jl_type_infer(linfo, linfo->def);
jl_type_infer(linfo);
linfo->functionObjects.functionObject = NULL;
linfo->functionObjects.specFunctionObject = NULL;
linfo->functionObjects.cFunctionList = NULL;
Expand Down
Loading

0 comments on commit 2119e87

Please sign in to comment.