diff --git a/base/inference.jl b/base/inference.jl index ad52c02a2efcb..298a84d679dab 100644 --- a/base/inference.jl +++ b/base/inference.jl @@ -2875,9 +2875,7 @@ function isinlineable(m::Method, src::CodeInfo) end end if !inlineable - body = Expr(:block) - body.args = src.code - inlineable = inline_worthy(body, cost) + inlineable = inline_worthy_stmts(src.code, cost) end return inlineable end @@ -3661,7 +3659,10 @@ end # static parameters are ok if all the static parameter values are leaf types, # meaning they are fully known. # `ft` is the type of the function. `f` is the exact function if known, or else `nothing`. -function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::InferenceState) +# `pending_stmts` is an array of statements from functions inlined so far, so +# we can estimate the total size of the enclosing function after inlining. +function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::InferenceState, + pending_stmts) argexprs = e.args if (f === typeassert || ft ⊑ typeof(typeassert)) && length(atypes)==3 @@ -3932,6 +3933,34 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference invoke_data) end + if !isa(ast, Array{Any,1}) + ast = ccall(:jl_uncompress_ast, Any, (Any, Any), method, ast) + else + ast = copy_exprargs(ast) + end + ast = ast::Array{Any,1} + + # `promote` is a tuple-returning function that is very important to inline + if isdefined(Main, :Base) && isdefined(Main.Base, :promote) && + length(sv.src.slottypes) > 0 && sv.src.slottypes[1] ⊑ typeof(getfield(Main.Base, :promote)) + # check for non-isbits Tuple return + if sv.bestguess ⊑ Tuple && !isbits(widenconst(sv.bestguess)) + # See if inlining this call would change the enclosing function + # from inlineable to not inlineable. + # This heuristic is applied to functions that return non-bits + # tuples, since we want to be able to inline those functions to + # avoid the tuple allocation. + current_stmts = vcat(sv.src.code, pending_stmts) + if inline_worthy_stmts(current_stmts) + append!(current_stmts, ast) + if !inline_worthy_stmts(current_stmts) + return invoke_NF(argexprs0, e.typ, atypes, sv, atype_unlimited, + invoke_data) + end + end + end + end + # create the backedge if isa(frame, InferenceState) && !frame.inferred && frame.cached # in this case, the actual backedge linfo hasn't been computed @@ -3954,13 +3983,6 @@ function inlineable(f::ANY, ft::ANY, e::Expr, atypes::Vector{Any}, sv::Inference nm = length(unwrap_unionall(metharg).parameters) - if !isa(ast, Array{Any,1}) - ast = ccall(:jl_uncompress_ast, Any, (Any, Any), method, ast) - else - ast = copy_exprargs(ast) - end - ast = ast::Array{Any,1} - body = Expr(:block) body.args = ast propagate_inbounds = src.propagate_inbounds @@ -4187,10 +4209,13 @@ function inline_ignore(ex::ANY) return isa(ex, Expr) && is_meta_expr(ex::Expr) end +function inline_worthy_stmts(stmts::Vector{Any}, cost::Integer = 1000) + body = Expr(:block) + body.args = stmts + return inline_worthy(body, cost) +end + function inline_worthy(body::Expr, cost::Integer=1000) # precondition: 0 < cost; nominal cost = 1000 - if popmeta!(body, :noinline)[1] - return false - end symlim = 1000 + 5_000_000 ÷ cost nstmt = 0 for stmt in body.args @@ -4238,17 +4263,15 @@ end function inlining_pass!(sv::InferenceState) eargs = sv.src.code i = 1 + stmtbuf = [] while i <= length(eargs) ei = eargs[i] if isa(ei, Expr) - res = inlining_pass(ei, sv) - eargs[i] = res[1] - if isa(res[2], Array) - sts = res[2]::Array{Any,1} - for j = 1:length(sts) - insert!(eargs, i, sts[j]) - i += 1 - end + eargs[i] = inlining_pass(ei, sv, stmtbuf, 1) + if !isempty(stmtbuf) + splice!(eargs, i:i-1, stmtbuf) + i += length(stmtbuf) + empty!(stmtbuf) end end i += 1 @@ -4257,16 +4280,17 @@ end const corenumtype = Union{Int32, Int64, Float32, Float64} -function inlining_pass(e::Expr, sv::InferenceState) +# return inlined replacement for `e`, inserting new needed statements +# at index `ins` in `stmts`. +function inlining_pass(e::Expr, sv::InferenceState, stmts, ins) if e.head === :method # avoid running the inlining pass on function definitions - return (e, ()) + return e end eargs = e.args if length(eargs) < 1 - return (e, ()) + return e end - stmts = [] arg1 = eargs[1] isccall = false i0 = 1 @@ -4281,6 +4305,7 @@ function inlining_pass(e::Expr, sv::InferenceState) i0 = 5 end has_stmts = false # needed to preserve order-of-execution + prev_stmts_length = length(stmts) for _i = length(eargs):-1:i0 if isccall && _i == 3 i = 1 @@ -4303,40 +4328,33 @@ function inlining_pass(e::Expr, sv::InferenceState) else argloc = eargs end - res = inlining_pass(ei, sv) - res1 = res[1] - res2 = res[2] - has_new_stmts = isa(res2, Array) && !isempty(res2::Array{Any,1}) + sl0 = length(stmts) + res = inlining_pass(ei, sv, stmts, ins) + ns = length(stmts) - sl0 # number of new statements just added if isccallee - restype = exprtype(res1, sv.src, sv.mod) + restype = exprtype(res, sv.src, sv.mod) if isa(restype, Const) argloc[i] = restype.val - if !effect_free(res1, sv.src, sv.mod, false) - insert!(stmts, 1, res1) - end - if has_new_stmts - prepend!(stmts, res2::Array{Any,1}) + if !effect_free(res, sv.src, sv.mod, false) + insert!(stmts, ins+ns, res) end # Assume this is the last argument to process break end end - if has_stmts && !effect_free(res1, sv.src, sv.mod, false) - restype = exprtype(res1, sv.src, sv.mod) + if has_stmts && !effect_free(res, sv.src, sv.mod, false) + restype = exprtype(res, sv.src, sv.mod) vnew = newvar!(sv, restype) argloc[i] = vnew - unshift!(stmts, Expr(:(=), vnew, res1)) + insert!(stmts, ins+ns, Expr(:(=), vnew, res)) else - argloc[i] = res1 - end - if has_new_stmts - res2 = res2::Array{Any,1} - prepend!(stmts, res2) - if !has_stmts && !(_i == i0) - for stmt in res2 - if !effect_free(stmt, sv.src, sv.mod, true) - has_stmts = true - end + argloc[i] = res + end + if !has_stmts && ns > 0 && !(_i == i0) + for s = ins:ins+ns-1 + stmt = stmts[s] + if !effect_free(stmt, sv.src, sv.mod, true) + has_stmts = true; break end end end @@ -4351,7 +4369,7 @@ function inlining_pass(e::Expr, sv::InferenceState) end end if e.head !== :call - return (e, stmts) + return e end ft = exprtype(arg1, sv.src, sv.mod) @@ -4363,10 +4381,12 @@ function inlining_pass(e::Expr, sv::InferenceState) else f = nothing if !( isleaftype(ft) || ft<:Type ) - return (e, stmts) + return e end end + ins += (length(stmts) - prev_stmts_length) + if sv.params.inlining if isdefined(Main, :Base) && ((isdefined(Main.Base, :^) && f === Main.Base.:^) || @@ -4390,19 +4410,13 @@ function inlining_pass(e::Expr, sv::InferenceState) exprtype(a1, sv.src, sv.mod) ⊑ basenumtype) if square e.args = Any[GlobalRef(Main.Base,:*), a1, a1] - res = inlining_pass(e, sv) + res = inlining_pass(e, sv, stmts, ins) else e.args = Any[GlobalRef(Main.Base,:*), Expr(:call, GlobalRef(Main.Base,:*), a1, a1), a1] e.args[2].typ = e.typ - res = inlining_pass(e, sv) - end - if isa(res, Tuple) - if isa(res[2], Array) && !isempty(res[2]) - append!(stmts, res[2]) - end - res = res[1] + res = inlining_pass(e, sv, stmts, ins) end - return (res, stmts) + return res end end end @@ -4413,13 +4427,14 @@ function inlining_pass(e::Expr, sv::InferenceState) ata[1] = ft for i = 2:length(e.args) a = exprtype(e.args[i], sv.src, sv.mod) - (a === Bottom || isvarargtype(a)) && return (e, stmts) + (a === Bottom || isvarargtype(a)) && return e ata[i] = a end - res = inlineable(f, ft, e, ata, sv) + res = inlineable(f, ft, e, ata, sv, stmts) if isa(res,Tuple) if isa(res[2],Array) && !isempty(res[2]) - append!(stmts,res[2]) + splice!(stmts, ins:ins-1, res[2]) + ins += length(res[2]) end res = res[1] end @@ -4431,7 +4446,7 @@ function inlining_pass(e::Expr, sv::InferenceState) e = res::Expr f = _apply; ft = abstract_eval_constant(f) else - return (res,stmts) + return res end end @@ -4453,7 +4468,7 @@ function inlining_pass(e::Expr, sv::InferenceState) newargs[i-2] = Any[ mk_getfield(aarg,j,tp[j]) for j=1:length(tp) ] else # not all args expandable - return (e,stmts) + return e end end e.args = [Any[e.args[2]]; newargs...] @@ -4468,14 +4483,14 @@ function inlining_pass(e::Expr, sv::InferenceState) else f = nothing if !( isleaftype(ft) || ft<:Type ) - return (e,stmts) + return e end end else - return (e,stmts) + return e end end - return (e,stmts) + return e end const compiler_temp_sym = Symbol("#temp#") @@ -4576,7 +4591,8 @@ normslot(s::TypedSlot) = SlotNumber(slot_id(s)) function get_replacement(table, var::Union{SlotNumber, SSAValue}, init::ANY, nargs, slottypes, ssavaluetypes) #if isa(init, QuoteNode) # this can cause slight code size increases # return init - if isa(init, Expr) && init.head === :static_parameter + if (isa(init, Expr) && init.head === :static_parameter) || isa(init, corenumtype) || + init === () || init === nothing return init elseif isa(init, Slot) && is_argument(nargs, init::Slot) # the transformation is not ideal if the assignment