Skip to content

Commit

Permalink
Inline invoke (take 3)
Browse files Browse the repository at this point in the history
Fix #9608
  • Loading branch information
yuyichao committed Sep 11, 2016
1 parent ab3e756 commit cdafa58
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 33 deletions.
178 changes: 147 additions & 31 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -632,9 +632,9 @@ function invoke_tfunc(f::ANY, types::ANY, argtype::ANY, sv::InferenceState)
if is(entry, nothing)
return Any
end
meth = entry.func
(ti, env) = ccall(:jl_match_method, Any, (Any, Any, Any),
argtype, meth.sig, meth.tvars)::SimpleVector
meth = (entry::TypeMapEntry).func
(ti, env) = ccall(:jl_match_method, Ref{SimpleVector}, (Any, Any, Any),
argtype, meth.sig, meth.tvars)
return typeinf_edge(meth::Method, ti, env, sv)[2]
end

Expand Down Expand Up @@ -2335,14 +2335,32 @@ end

#### post-inference optimizations ####

function inline_as_constant(val::ANY, argexprs, linfo::LambdaInfo)
immutable InvokeData
mt::MethodTable
entry::TypeMapEntry
types0
fexpr
texpr
end

function inline_as_constant(val::ANY, argexprs, linfo::LambdaInfo,
invoke_data::ANY)
invoke_fexpr, invoke_texpr = if invoke_data === nothing
nothing, nothing
else
invoke_data = invoke_data::InvokeData
invoke_data.fexpr, invoke_data.texpr
end
# check if any arguments aren't effect_free and need to be kept around
stmts = Any[]
stmts = invoke_fexpr === nothing ? [] : Any[invoke_fexpr]
for i = 1:length(argexprs)
arg = argexprs[i]
if !effect_free(arg, linfo, false)
push!(stmts, arg)
end
if i == 1 && !(invoke_texpr === nothing)
push!(stmts, invoke_texpr)
end
end
return (QuoteNode(val), stmts)
end
Expand All @@ -2357,11 +2375,29 @@ function countunionsplit(atypes::Vector{Any})
return nu
end

function get_spec_lambda(atypes::ANY, invoke_data::ANY)
if invoke_data === nothing
return ccall(:jl_get_spec_lambda, Any, (Any,), atypes)
else
invoke_data = invoke_data::InvokeData
# TODO compute intersection and throws an error
atypes <: invoke_data.types0 || return nothing
return ccall(:jl_get_invoke_lambda, Any, (Any, Any, Any),
invoke_data.mt, invoke_data.entry, atypes)
end
end

function invoke_NF(argexprs, etype::ANY, atypes, sv, enclosing,
atype_unlimited::ANY)
atype_unlimited::ANY, invoke_data::ANY)
# converts a :call to :invoke
nu = countunionsplit(atypes)
nu > MAX_UNION_SPLITTING && return NF
invoke_fexpr, invoke_texpr = if invoke_data === nothing
nothing, nothing
else
invoke_data = invoke_data::InvokeData
invoke_data.fexpr, invoke_data.texpr
end

if nu > 1
spec_hit = nothing
Expand All @@ -2373,22 +2409,31 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, enclosing,
ex.typ = etype
stmts = []
arg_hoisted = false
arg0_hoisted = false
for i = length(atypes):-1:1
if i == 1 && !(invoke_texpr === nothing)
unshift!(stmts, invoke_texpr)
arg_hoisted = true
end
ti = atypes[i]
if arg_hoisted || isa(ti, Union)
aei = ex.args[i]
if !effect_free(aei, enclosing, false)
arg_hoisted = true
newvar = newvar!(sv, ti)
insert!(stmts, 1, :($newvar = $aei))
unshift!(stmts, :($newvar = $aei))
ex.args[i] = newvar
if i == 1
arg0_hoisted = true
end
end
end
end
invoke_fexpr === nothing || unshift!(stmts, invoke_fexpr)
function splitunion(atypes::Vector{Any}, i::Int)
if i == 0
local sig = argtypes_to_type(atypes)
local li = ccall(:jl_get_spec_lambda, Any, (Any,), sig)
local li = get_spec_lambda(sig, invoke_data)
li === nothing && return false
local stmt = []
push!(stmt, Expr(:(=), linfo_var, li))
Expand Down Expand Up @@ -2455,13 +2500,24 @@ function invoke_NF(argexprs, etype::ANY, atypes, sv, enclosing,
return (ret_var, stmts)
end
else
local cache_linfo = ccall(:jl_get_spec_lambda, Any, (Any,), atype_unlimited)
local cache_linfo = get_spec_lambda(atype_unlimited, invoke_data)
cache_linfo === nothing && return NF
unshift!(argexprs, cache_linfo)
ex = Expr(:invoke)
ex.args = argexprs
ex.typ = etype
return ex
if invoke_texpr === nothing
if invoke_fexpr === nothing
return ex
else
return ex, Any[invoke_fexpr]
end
end
newvar = newvar!(sv, atypes[1])
stmts = Any[invoke_fexpr, :($newvar = $(argexprs[1])),
invoke_texpr]
argexprs[1] = newvar
return ex, stmts
end
return NF
end
Expand Down Expand Up @@ -2515,45 +2571,98 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
end
end
end
if isa(f, IntrinsicFunction) || ft IntrinsicFunction ||
invoke_data = nothing
invoke_fexpr = nothing
invoke_texpr = nothing
if f === Core.invoke && length(atypes) >= 3
ft = widenconst(atypes[2])
invoke_tt = widenconst(atypes[3])
if !isleaftype(ft) || !isleaftype(invoke_tt) || !isType(invoke_tt)
return NF
end
if !(isa(invoke_tt.parameters[1], Type) &&
invoke_tt.parameters[1] <: Tuple)
return NF
end
invoke_tt_params = invoke_tt.parameters[1].parameters
invoke_types = Tuple{ft, invoke_tt_params...}
invoke_entry = ccall(:jl_gf_invoke_lookup, Any, (Any,), invoke_types)
invoke_entry === nothing && return NF
invoke_fexpr = argexprs[1]
invoke_texpr = argexprs[3]
if effect_free(invoke_fexpr, enclosing, false)
invoke_fexpr = nothing
end
if effect_free(invoke_texpr, enclosing, false)
invoke_fexpr = nothing
end
invoke_data = InvokeData(ft.name.mt, invoke_entry,
invoke_types, invoke_fexpr, invoke_texpr)
atype0 = atypes[2]
argexpr0 = argexprs[2]
atypes = atypes[4:end]
argexprs = argexprs[4:end]
unshift!(atypes, atype0)
unshift!(argexprs, argexpr0)
f = isdefined(ft, :instance) ? ft.instance : nothing
elseif isa(f, IntrinsicFunction) || ft IntrinsicFunction ||
isa(f, Builtin) || ft Builtin
return NF
end

local atype_unlimited = argtypes_to_type(atypes)
atype_unlimited = argtypes_to_type(atypes)
if !(invoke_data === nothing)
invoke_data = invoke_data::InvokeData
# TODO emit a type check and proceed for this case
atype_unlimited <: invoke_data.types0 || return NF
end
if !sv.inlining
return invoke_NF(argexprs, e.typ, atypes, sv, enclosing, atype_unlimited)
return invoke_NF(argexprs, e.typ, atypes, sv, enclosing,
atype_unlimited, invoke_data)
end

if length(atype_unlimited.parameters) - 1 > MAX_TUPLETYPE_LEN
atype = limit_tuple_type(atype_unlimited)
else
atype = atype_unlimited
end
meth = _methods_by_ftype(atype, 1)
if meth === false || length(meth) != 1
return invoke_NF(argexprs, e.typ, atypes, sv, enclosing, atype_unlimited)
if invoke_data === nothing
meth = _methods_by_ftype(atype, 1)
if meth === false || length(meth) != 1
return invoke_NF(argexprs, e.typ, atypes, sv, enclosing,
atype_unlimited, invoke_data)
end
meth = meth[1]::SimpleVector
metharg = meth[1]::Type
methsp = meth[2]::SimpleVector
method = meth[3]::Method
else
invoke_data = invoke_data::InvokeData
method = invoke_data.entry.func
(metharg, methsp) = ccall(:jl_match_method, Ref{SimpleVector},
(Any, Any, Any),
atype_unlimited, method.sig, method.tvars)
methsp = methsp::SimpleVector
end
meth = meth[1]::SimpleVector
metharg = meth[1]::Type
methsp = meth[2]
method = meth[3]::Method
# check whether call can be inlined to just a quoted constant value
if isa(f, widenconst(ft)) && !method.isstaged && (method.lambda_template.pure || f === return_type) &&
(isType(e.typ) || isa(e.typ,Const))
if isType(e.typ)
if !has_typevars(e.typ.parameters[1])
return inline_as_constant(e.typ.parameters[1], argexprs, enclosing)
return inline_as_constant(e.typ.parameters[1], argexprs,
enclosing, invoke_data)
end
else
assert(isa(e.typ,Const))
return inline_as_constant(e.typ.val, argexprs, enclosing)
return inline_as_constant(e.typ.val, argexprs, enclosing,
invoke_data)
end
end

methsig = method.sig
if !(atype <: metharg)
return invoke_NF(argexprs, e.typ, atypes, sv, enclosing, atype_unlimited)
return invoke_NF(argexprs, e.typ, atypes, sv, enclosing,
atype_unlimited, invoke_data)
end

argexprs0 = argexprs
Expand Down Expand Up @@ -2583,18 +2692,22 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference

(linfo, ty, inferred) = typeinf(method, metharg, methsp, false)
if linfo === nothing || !inferred
return invoke_NF(argexprs0, e.typ, atypes, sv, enclosing, atype_unlimited)
return invoke_NF(argexprs0, e.typ, atypes, sv, enclosing,
atype_unlimited, invoke_data)
end
if linfo !== nothing && linfo.jlcall_api == 2
# in this case function can be inlined to a constant
return inline_as_constant(linfo.constval, argexprs, enclosing)
return inline_as_constant(linfo.constval, argexprs, enclosing,
invoke_data)
elseif linfo !== nothing && !linfo.inlineable
return invoke_NF(argexprs0, e.typ, atypes, sv, enclosing, atype_unlimited)
return invoke_NF(argexprs0, e.typ, atypes, sv, enclosing,
atype_unlimited, invoke_data)
elseif linfo === nothing || linfo.code === nothing
(linfo, ty, inferred) = typeinf(method, metharg, methsp, true)
end
if linfo === nothing || !inferred || !linfo.inlineable || (ast = linfo.code) === nothing
return invoke_NF(argexprs0, e.typ, atypes, sv, enclosing, atype_unlimited)
return invoke_NF(argexprs0, e.typ, atypes, sv, enclosing,
atype_unlimited, invoke_data)
end

spvals = Any[]
Expand All @@ -2608,8 +2721,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
end
end

methargs = metharg.parameters
nm = length(methargs)
nm = length(metharg.parameters)

if !isa(ast, Array{Any,1})
ast = ccall(:jl_uncompress_ast, Any, (Any,Any), linfo, ast)
Expand All @@ -2623,14 +2735,17 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
propagate_inbounds = linfo.propagate_inbounds

# see if each argument occurs only once in the body expression
stmts = Any[]
prelude_stmts = Any[]
stmts = []
prelude_stmts = []
stmts_free = true # true = all entries of stmts are effect_free

for i=na:-1:1 # stmts_free needs to be calculated in reverse-argument order
#args_i = args[i]
aei = argexprs[i]
aeitype = argtype = widenconst(exprtype(aei, enclosing))
if i == 1 && !(invoke_texpr === nothing)
unshift!(prelude_stmts, invoke_texpr)
end

# ok for argument to occur more than once if the actual argument
# is a symbol or constant, or is not affected by previous statements
Expand Down Expand Up @@ -2664,6 +2779,7 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference
end
end
end
invoke_fexpr === nothing || unshift!(prelude_stmts, invoke_fexpr)

# re-number the SSAValues and copy their type-info to the new ast
ssavalue_types = linfo.ssavaluetypes
Expand Down
51 changes: 49 additions & 2 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -1967,11 +1967,10 @@ JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_datatype_t *types)
jl_value_t *jl_gf_invoke(jl_tupletype_t *types0, jl_value_t **args, size_t nargs)
{
jl_svec_t *tpenv = jl_emptysvec;
jl_tupletype_t *newsig = NULL;
jl_tupletype_t *tt = NULL;
jl_tupletype_t *types = NULL;
jl_tupletype_t *sig = NULL;
JL_GC_PUSH5(&types, &tpenv, &newsig, &sig, &tt);
JL_GC_PUSH4(&types, &tpenv, &sig, &tt);
jl_value_t *gf = args[0];
types = (jl_datatype_t*)jl_argtype_with_function(gf, (jl_tupletype_t*)types0);
jl_methtable_t *mt = jl_gf_mtable(gf);
Expand Down Expand Up @@ -2022,6 +2021,54 @@ jl_value_t *jl_gf_invoke(jl_tupletype_t *types0, jl_value_t **args, size_t nargs
return jl_call_method_internal(mfunc, args, nargs);
}

JL_DLLEXPORT jl_value_t *jl_get_invoke_lambda(jl_methtable_t *mt,
jl_typemap_entry_t *entry,
jl_tupletype_t *tt)
{
if (!jl_is_leaf_type((jl_value_t*)tt))
return jl_nothing;
jl_method_t *method = entry->func.method;
jl_typemap_entry_t *tm = NULL;
if (method->invokes.unknown != NULL) {
tm = jl_typemap_assoc_by_type(method->invokes, tt, NULL, 0, 1,
jl_cachearg_offset(mt));
if (tm) {
return (jl_value_t*)tm->func.linfo;
}
}

JL_LOCK(&method->writelock);
if (method->invokes.unknown != NULL) {
tm = jl_typemap_assoc_by_type(method->invokes, tt, NULL, 0, 1,
jl_cachearg_offset(mt));
if (tm) {
jl_lambda_info_t *mfunc = tm->func.linfo;
JL_UNLOCK(&method->writelock);
return (jl_value_t*)mfunc;
}
}
jl_svec_t *tpenv = jl_emptysvec;
jl_tupletype_t *sig = NULL;
JL_GC_PUSH2(&tpenv, &sig);
if (entry->tvars != jl_emptysvec) {
jl_value_t *ti =
jl_lookup_match((jl_value_t*)tt, (jl_value_t*)entry->sig, &tpenv, entry->tvars);
assert(ti != (jl_value_t*)jl_bottom_type);
(void)ti;
}
sig = join_tsig(tt, entry->sig);
jl_method_t *func = entry->func.method;

if (func->invokes.unknown == NULL)
func->invokes.unknown = jl_nothing;

jl_lambda_info_t *mfunc = cache_method(mt, &func->invokes, entry->func.value,
sig, tt, entry, tpenv, 1);
JL_GC_POP();
JL_UNLOCK(&method->writelock);
return (jl_value_t*)mfunc;
}

static jl_function_t *jl_new_generic_function_with_supertype(jl_sym_t *name, jl_module_t *module, jl_datatype_t *st, int iskw)
{
// type name is function name prefixed with #
Expand Down

0 comments on commit cdafa58

Please sign in to comment.