diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 1dd57ddacfe41..9588e2160f62a 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -166,11 +166,14 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), # if there's a possibility we could constant-propagate a better result # (hopefully without doing too much work), try to do that now # TODO: it feels like this could be better integrated into abstract_call_method / typeinf_edge - const_rettype = abstract_call_method_with_const_args(interp, rettype, f, argtypes, applicable[nonbot]::MethodMatch, sv, edgecycle) + const_rettype, result = abstract_call_method_with_const_args(interp, rettype, f, argtypes, applicable[nonbot]::MethodMatch, sv, edgecycle) if const_rettype ⊑ rettype # use the better result, if it's a refinement of rettype rettype = const_rettype end + if result !== nothing + info = ConstCallInfo(info, result) + end end if is_unused && !(rettype === Bottom) add_remark!(interp, sv, "Call result type was widened because the return value is unused") @@ -263,7 +266,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp method = match.method nargs::Int = method.nargs method.isva && (nargs -= 1) - length(argtypes) >= nargs || return Any + length(argtypes) >= nargs || return Any, nothing haveconst = false allconst = true # see if any or all of the arguments are constant and propagating constants may be worthwhile @@ -279,21 +282,21 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp break end end - haveconst || improvable_via_constant_propagation(rettype) || return Any + haveconst || improvable_via_constant_propagation(rettype) || return Any, nothing force_inference = method.aggressive_constprop || InferenceParams(interp).aggressive_constant_propagation if !force_inference && nargs > 1 if istopfunction(f, :getindex) || istopfunction(f, :setindex!) arrty = argtypes[2] # don't propagate constant index into indexing of non-constant array if arrty isa Type && arrty <: AbstractArray && !issingletontype(arrty) - return Any + return Any, nothing elseif arrty ⊑ Array - return Any + return Any, nothing end elseif istopfunction(f, :iterate) itrty = argtypes[2] if itrty ⊑ Array - return Any + return Any, nothing end end end @@ -304,7 +307,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp istopfunction(f, :<<) || istopfunction(f, :>>)) # it is almost useless to inline the op of when all the same type, # but highly worthwhile to inline promote of a constant - length(argtypes) > 2 || return Any + length(argtypes) > 2 || return Any, nothing t1 = widenconst(argtypes[2]) all_same = true for i in 3:length(argtypes) @@ -313,18 +316,18 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp break end end - all_same && return Any + all_same && return Any, nothing end if istopfunction(f, :getproperty) || istopfunction(f, :setproperty!) force_inference = true end force_inference |= allconst mi = specialize_method(match, !force_inference) - mi === nothing && return Any + mi === nothing && return Any, nothing mi = mi::MethodInstance # decide if it's likely to be worthwhile if !force_inference && !const_prop_heuristic(interp, method, mi) - return Any + return Any, nothing end inf_cache = get_inference_cache(interp) inf_result = cache_lookup(mi, argtypes, inf_cache) @@ -336,7 +339,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp cyclei = 0 while !(infstate === nothing) if method === infstate.linfo.def && any(infstate.result.overridden_by_const) - return Any + return Any, nothing end if cyclei < length(infstate.callers_in_cycle) cyclei += 1 @@ -349,16 +352,16 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp end inf_result = InferenceResult(mi, argtypes) frame = InferenceState(inf_result, #=cache=#false, interp) - frame === nothing && return Any # this is probably a bad generated function (unsound), but just ignore it + frame === nothing && return Any, nothing # this is probably a bad generated function (unsound), but just ignore it frame.parent = sv push!(inf_cache, inf_result) - typeinf(interp, frame) || return Any + typeinf(interp, frame) || return Any, nothing end result = inf_result.result # if constant inference hits a cycle, just bail out - isa(result, InferenceState) && return Any + isa(result, InferenceState) && return Any, nothing add_backedge!(inf_result.linfo, sv) - return result + return result, inf_result end const RECURSION_UNUSED_MSG = "Bounded recursion detected with unused result. Annotated return type may be wider than true result." diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 565eecd56416a..e9670a47af122 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -33,7 +33,7 @@ end pass to apply its own inlining policy decisions. """ struct DelayedInliningSpec - match::MethodMatch + match::Union{MethodMatch, InferenceResult} atypes::Vector{Any} stmttype::Any end @@ -44,7 +44,11 @@ struct InliningTodo spec::Union{ResolvedInliningSpec, DelayedInliningSpec} end -InliningTodo(mi::MethodInstance, match::MethodMatch, atypes::Vector{Any}, @nospecialize(stmttype)) = InliningTodo(mi, DelayedInliningSpec(match, atypes, stmttype)) +InliningTodo(mi::MethodInstance, match::MethodMatch, + atypes::Vector{Any}, @nospecialize(stmttype)) = InliningTodo(mi, DelayedInliningSpec(match, atypes, stmttype)) + +InliningTodo(result::InferenceResult, atypes::Vector{Any}, @nospecialize(stmttype)) = + InliningTodo(result.linfo, DelayedInliningSpec(result, atypes, stmttype)) struct ConstantCase val::Any @@ -631,7 +635,10 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx:: new_stmt = Expr(:call, argexprs[2], def, state...) state1 = insert_node!(ir, idx, call.rt, new_stmt) new_sig = with_atype(call_sig(ir, new_stmt)::Signature) - if isa(call.info, MethodMatchInfo) || isa(call.info, UnionSplitInfo) + if isa(call.info, ConstCallInfo) + handle_const_call!(ir, state1.id, new_stmt, call.info, new_sig, + call.rt, et, caches, false, todo) + elseif isa(call.info, MethodMatchInfo) || isa(call.info, UnionSplitInfo) info = isa(call.info, MethodMatchInfo) ? MethodMatchInfo[call.info] : call.info.matches # See if we can inline this call to `iterate` @@ -676,9 +683,32 @@ function compileable_specialization(et::Union{EdgeTracker, Nothing}, match::Meth return mi end +function compileable_specialization(et::Union{EdgeTracker, Nothing}, result::InferenceResult) + mi = specialize_method(result.linfo.def, result.linfo.specTypes, + result.linfo.sparam_vals, false, true) + mi !== nothing && et !== nothing && push!(et, mi::MethodInstance) + return mi +end + function resolve_todo(todo::InliningTodo, et::Union{EdgeTracker, Nothing}, caches::InferenceCaches) spec = todo.spec::DelayedInliningSpec - isconst, src = find_inferred(todo.mi, spec.atypes, caches, spec.stmttype) + + #XXX: update_valid_age!(min_valid[1], max_valid[1], sv) + isconst, src = false, nothing + if isa(spec.match, InferenceResult) + let inferred_src = spec.match.src + if isa(inferred_src, CodeInfo) + isconst, src = false, inferred_src + elseif isa(inferred_src, Const) + if !is_inlineable_constant(inferred_src.val) + return compileable_specialization(et, spec.match) + end + isconst, src = true, quoted(inferred_src.val) + end + end + else + isconst, src = find_inferred(todo.mi, spec.atypes, caches, spec.stmttype) + end if isconst && et !== nothing push!(et, todo.mi) @@ -717,6 +747,13 @@ function resolve_todo!(todo::Vector{Pair{Int, Any}}, et::Union{EdgeTracker, Noth todo end +function validate_sparams(sparams::SimpleVector) + for i = 1:length(sparams) + (isa(sparams[i], TypeVar) || isa(sparams[i], Core.TypeofVararg)) && return false + end + return true +end + function analyze_method!(match::MethodMatch, atypes::Vector{Any}, et::Union{EdgeTracker, Nothing}, caches::Union{InferenceCaches, Nothing}, @@ -737,9 +774,8 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any}, # Bail out if any static parameters are left as TypeVar ok = true - for i = 1:length(match.sparams) - (isa(match.sparams[i], TypeVar) || isa(match.sparams[i], Core.TypeofVararg)) && return nothing - end + validate_sparams(match.sparams) || return nothing + if !params.inlining return compileable_specialization(et, match) @@ -1146,6 +1182,28 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int return nothing end +function handle_const_call!(ir::IRCode, idx::Int, stmt::Expr, + info::ConstCallInfo, sig::Signature, @nospecialize(calltype), + et::Union{EdgeTracker, Nothing}, caches::Union{InferenceCaches, Nothing}, + isinvoke::Bool, todo::Vector{Pair{Int, Any}}) + item = InliningTodo(info.result, sig.atypes, calltype) + validate_sparams(item.mi.sparam_vals) || return + mthd_sig = item.mi.def.sig + mistypes = item.mi.specTypes + caches !== nothing && (item = resolve_todo(item, et, caches)) + if sig.atype <: mthd_sig + return handle_single_case!(ir, stmt, idx, item, isinvoke, todo) + else + item === nothing && return + # Union split out the error case + item = UnionSplit(false, sig.atype, Pair{Any, Any}[mistypes => item]) + if isinvoke + stmt.args = rewrite_invoke_exprargs!(stmt.args) + end + push!(todo, idx=>item) + end +end + function assemble_inline_todo!(ir::IRCode, state::InliningState) # todo = (inline_idx, (isva, isinvoke, na), method, spvals, inline_linetable, inline_ir, lie) todo = Pair{Int, Any}[] @@ -1173,6 +1231,15 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) end end + # If inference arrived at this result by using constant propagation, + # it'll performed a specialized analysis for just this case. Use its + # result. + if isa(info, ConstCallInfo) + handle_const_call!(ir, idx, stmt, info, sig, calltype, state.et, + state.caches, invoke_data !== nothing, todo) + continue + end + # Ok, now figure out what method to call if invoke_data !== nothing inline_invoke!(ir, idx, sig, invoke_data, state, todo) @@ -1387,35 +1454,6 @@ function ssa_substitute_op!(@nospecialize(val), arg_replacements::Vector{Any}, end function find_inferred(mi::MethodInstance, atypes::Vector{Any}, caches::InferenceCaches, @nospecialize(rettype)) - if caches.inf_cache !== nothing - # see if the method has a InferenceResult in the current cache - # or an existing inferred code info store in `.inferred` - haveconst = false - for i in 1:length(atypes) - if has_nontrivial_const_info(atypes[i]) - # have new information from argtypes that wasn't available from the signature - haveconst = true - break - end - end - if haveconst || improvable_via_constant_propagation(rettype) - inf_result = cache_lookup(mi, atypes, caches.inf_cache) # Union{Nothing, InferenceResult} - else - inf_result = nothing - end - #XXX: update_valid_age!(min_valid[1], max_valid[1], sv) - if isa(inf_result, InferenceResult) - let inferred_src = inf_result.src - if isa(inferred_src, CodeInfo) - return svec(false, inferred_src) - end - if isa(inferred_src, Const) && is_inlineable_constant(inferred_src.val) - return svec(true, quoted(inferred_src.val),) - end - end - end - end - linfo = get(caches.mi_cache, mi, nothing) if linfo isa CodeInstance if invoke_api(linfo) == 2 diff --git a/base/compiler/stmtinfo.jl b/base/compiler/stmtinfo.jl index 0222a2343f985..762325c0c9579 100644 --- a/base/compiler/stmtinfo.jl +++ b/base/compiler/stmtinfo.jl @@ -82,6 +82,18 @@ struct UnionSplitApplyCallInfo infos::Vector{ApplyCallInfo} end +""" + struct ConstCallInfo + +Precision for this call was improved using constant information. This info +keeps a reference to the result that was used (or created for these) +constant information. +""" +struct ConstCallInfo + call::Any + result::InferenceResult +end + # Stmt infos that are used by external consumers, but not by optimization. # These are not produced by default and must be explicitly opted into by # the AbstractInterpreter. diff --git a/base/compiler/typeutils.jl b/base/compiler/typeutils.jl index eb140008fadcf..f290608f5b8ad 100644 --- a/base/compiler/typeutils.jl +++ b/base/compiler/typeutils.jl @@ -234,7 +234,7 @@ function unswitchtupleunion(u::Union) ts = uniontypes(u) n = -1 for t in ts - if t isa DataType && t.name === Tuple.name && !isvarargtype(t.parameters[end]) + if t isa DataType && t.name === Tuple.name && length(t.parameters) != 0 && !isvarargtype(t.parameters[end]) if n == -1 n = length(t.parameters) elseif n != length(t.parameters)