diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index b126e4dfd79f6e..3e9f38c9ad98b2 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -29,7 +29,7 @@ function is_improvable(@nospecialize(rtype)) end function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), - fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, @nospecialize(atype), + (; fargs, argtypes)::ArgInfo, @nospecialize(atype), sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS) if sv.params.unoptimize_throw_blocks && is_stmt_throw_block(get_curr_ssaflag(sv)) add_remark!(interp, sv, "Skipped call in throw block") @@ -85,7 +85,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), push!(edges, edge) end this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i] - const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false) + arginfo = ArgInfo(fargs, this_argtypes) + const_result = abstract_call_method_with_const_args(interp, result, f, arginfo, match, sv, false) if const_result !== nothing const_rt, const_result = const_result if const_rt !== rt && const_rt ⊑ rt @@ -110,7 +111,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), # try constant propagation with argtypes for this match # this is in preparation for inlining, or improving the return result this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i] - const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false) + arginfo = ArgInfo(fargs, this_argtypes) + const_result = abstract_call_method_with_const_args(interp, result, f, arginfo, match, sv, false) if const_result !== nothing const_this_rt, const_result = const_result if const_this_rt !== this_rt && const_this_rt ⊑ this_rt @@ -523,13 +525,13 @@ struct MethodCallResult end function abstract_call_method_with_const_args(interp::AbstractInterpreter, result::MethodCallResult, - @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, + @nospecialize(f), arginfo::ArgInfo, match::MethodMatch, sv::InferenceState, va_override::Bool) - mi = maybe_get_const_prop_profitable(interp, result, f, argtypes, match, sv) + mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv) mi === nothing && return nothing # try constant prop' inf_cache = get_inference_cache(interp) - inf_result = cache_lookup(mi, argtypes, inf_cache) + inf_result = cache_lookup(mi, arginfo.argtypes, inf_cache) if inf_result === nothing # if there might be a cycle, check to make sure we don't end up # calling ourselves here. @@ -545,7 +547,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul return nothing end end - inf_result = InferenceResult(mi, argtypes, va_override) + inf_result = InferenceResult(mi, arginfo, va_override) if !any(inf_result.overridden_by_const) add_remark!(interp, sv, "[constprop] Could not handle constant info in matching_cache_argtypes") return nothing @@ -565,7 +567,7 @@ end # if there's a possibility we could get a better result (hopefully without doing too much work) # returns `MethodInstance` with constant arguments, returns nothing otherwise function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::MethodCallResult, - @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, + @nospecialize(f), arginfo::ArgInfo, match::MethodMatch, sv::InferenceState) if !InferenceParams(interp).ipo_constant_propagation add_remark!(interp, sv, "[constprop] Disabled by parameter") @@ -580,26 +582,24 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me force || const_prop_entry_heuristic(interp, result, sv) || return nothing nargs::Int = method.nargs method.isva && (nargs -= 1) - length(argtypes) < nargs && return nothing - if !(const_prop_argument_heuristic(interp, argtypes) || const_prop_rettype_heuristic(interp, result.rt)) + length(arginfo.argtypes) < nargs && return nothing + if !const_prop_argument_heuristic(interp, arginfo, sv) add_remark!(interp, sv, "[constprop] Disabled by argument and rettype heuristics") return nothing end - allconst = is_allconst(argtypes) - if !force - if !const_prop_function_heuristic(interp, f, argtypes, nargs, allconst) - add_remark!(interp, sv, "[constprop] Disabled by function heuristic") - return nothing - end + all_overridden = is_all_overridden(arginfo) + if !force && !const_prop_function_heuristic(interp, f, arginfo, nargs, all_overridden, sv) + add_remark!(interp, sv, "[constprop] Disabled by function heuristic") + return nothing end - force |= allconst + force |= all_overridden mi = specialize_method(match; preexisting=!force) if mi === nothing add_remark!(interp, sv, "[constprop] Failed to specialize") return nothing end mi = mi::MethodInstance - if !force && !const_prop_methodinstance_heuristic(interp, match, mi, argtypes, sv) + if !force && !const_prop_methodinstance_heuristic(interp, match, mi, arginfo, sv) add_remark!(interp, sv, "[constprop] Disabled by method instance heuristic") return nothing end @@ -616,12 +616,16 @@ function const_prop_entry_heuristic(interp::AbstractInterpreter, result::MethodC return false end -# see if propagating constants may be worthwhile -function const_prop_argument_heuristic(interp::AbstractInterpreter, argtypes::Vector{Any}) - for a in argtypes - a = widenconditional(a) - if has_nontrivial_const_info(a) && is_const_prop_profitable_arg(a) - return true +# determines heuristically whether if constant propagation can be worthwhile +# by checking if any of given `argtypes` is "interesting" enough to be propagated +function const_prop_argument_heuristic(_::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, _::InferenceState) + for i in 1:length(argtypes) + a = argtypes[i] + if isa(a, Conditional) && fargs !== nothing + is_const_prop_profitable_conditional(a, fargs) && return true + else + a = widenconditional(a) + has_nontrivial_const_info(a) && is_const_prop_profitable_arg(a) && return true end end return false @@ -642,15 +646,32 @@ function is_const_prop_profitable_arg(@nospecialize(arg)) return isa(val, Symbol) || isa(val, Type) || (!isa(val, String) && !ismutable(val)) end -function const_prop_rettype_heuristic(interp::AbstractInterpreter, @nospecialize(rettype)) - return improvable_via_constant_propagation(rettype) +function is_const_prop_profitable_conditional(cnd::Conditional, fargs::Vector{Any}) + slotid = find_constrained_arg(cnd, fargs) + if slotid !== nothing + return true + end + # as a minor optimization, we just check the result is a constant or not, + # since both `has_nontrivial_const_info`/`is_const_prop_profitable_arg` return `true` + # for `Const(::Bool)` + return isa(widenconditional(cnd), Const) end -function is_allconst(argtypes::Vector{Any}) +function find_constrained_arg(cnd::Conditional, fargs::Vector{Any}) + slot = cnd.var + return findfirst(fargs) do @nospecialize(x) + x === slot + end +end + +# checks if all argtypes has additional information other than what `Type` can provide +function is_all_overridden((; fargs, argtypes)::ArgInfo) for a in argtypes - a = widenconditional(a) - if !isa(a, Const) && !isconstType(a) && !isa(a, PartialStruct) && !isa(a, PartialOpaque) - return false + if isa(a, Conditional) && fargs !== nothing + is_const_prop_profitable_conditional(a, fargs) || return false + else + a = widenconditional(a) + is_forwardable_argtype(a) || return false end end return true @@ -663,7 +684,9 @@ function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method: istopfunction(f, :setproperty!) end -function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, nargs::Int, allconst::Bool) +function const_prop_function_heuristic( + interp::AbstractInterpreter, @nospecialize(f), (; argtypes)::ArgInfo, + nargs::Int, all_overridden::Bool, _::InferenceState) if nargs > 1 if istopfunction(f, :getindex) || istopfunction(f, :setindex!) arrty = argtypes[2] @@ -680,7 +703,7 @@ function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecializ end end end - if !allconst && (istopfunction(f, :+) || istopfunction(f, :-) || istopfunction(f, :*) || + if !all_overridden && (istopfunction(f, :+) || istopfunction(f, :-) || istopfunction(f, :*) || istopfunction(f, :(==)) || istopfunction(f, :!=) || istopfunction(f, :<=) || istopfunction(f, :>=) || istopfunction(f, :<) || istopfunction(f, :>) || istopfunction(f, :<<) || istopfunction(f, :>>)) @@ -705,7 +728,7 @@ end # result anyway. function const_prop_methodinstance_heuristic( interp::AbstractInterpreter, match::MethodMatch, mi::MethodInstance, - argtypes::Vector{Any}, sv::InferenceState) + (; argtypes)::ArgInfo, sv::InferenceState) method = match.method if method.is_for_opaque_closure # Not inlining an opaque closure can be very expensive, so be generous @@ -832,7 +855,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n return Any[Vararg{Any}], nothing end @assert !isvarargtype(itertype) - call = abstract_call_known(interp, iteratef, nothing, Any[itft, itertype], sv) + call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[itft, itertype]), sv) stateordonet = call.rt info = call.info # Return Bottom if this is not an iterator. @@ -866,7 +889,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n valtype = getfield_tfunc(stateordonet, Const(1)) push!(ret, valtype) statetype = nstatetype - call = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv) + call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), sv) stateordonet = call.rt stateordonet_widened = widenconst(stateordonet) push!(calls, call) @@ -901,7 +924,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n end valtype = tmerge(valtype, nounion.parameters[1]) statetype = tmerge(statetype, nounion.parameters[2]) - stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv).rt + stateordonet = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), sv).rt stateordonet_widened = widenconst(stateordonet) end if valtype !== Union{} @@ -990,7 +1013,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, sv:: break end end - call = abstract_call(interp, nothing, ct, sv, max_methods) + call = abstract_call(interp, ArgInfo(nothing, ct), sv, max_methods) push!(retinfos, ApplyCallInfo(call.info, arginfo)) res = tmerge(res, call.rt) if bail_out_apply(interp, res, sv) @@ -1054,8 +1077,8 @@ function argtype_tail(argtypes::Vector{Any}, i::Int) return argtypes[i:n] end -function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::Union{Nothing,Vector{Any}}, - argtypes::Vector{Any}, sv::InferenceState, max_methods::Int) +function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs, argtypes)::ArgInfo, + sv::InferenceState, max_methods::Int) @nospecialize f la = length(argtypes) if f === ifelse && fargs isa Vector{Any} && la == 4 @@ -1188,7 +1211,7 @@ function abstract_call_unionall(argtypes::Vector{Any}) return Any end -function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState) +function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, sv::InferenceState) ft′ = argtype_by_index(argtypes, 2) ft = widenconst(ft′) ft === Bottom && return CallMeta(Bottom, false) @@ -1215,15 +1238,18 @@ function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv: # try constant propagation with manual inlinings of some of the heuristics # since some checks within `abstract_call_method_with_const_args` seem a bit costly const_prop_entry_heuristic(interp, result, sv) || return CallMeta(rt, InvokeCallInfo(match, nothing)) - argtypes′ = argtypes[4:end] - const_prop_argument_heuristic(interp, argtypes′) || const_prop_rettype_heuristic(interp, rt) || return CallMeta(rt, InvokeCallInfo(match, nothing)) - pushfirst!(argtypes′, ft) + argtypes′ = argtypes[3:end] + argtypes′[1] = ft + fargs′ = fargs[3:end] + fargs′[1] = fargs[1] + arginfo = ArgInfo(fargs′, argtypes′) + const_prop_argument_heuristic(interp, arginfo, sv) || return CallMeta(rt, InvokeCallInfo(match, nothing)) # # typeintersect might have narrowed signature, but the accuracy gain doesn't seem worth the cost involved with the lattice comparisons # for i in 1:length(argtypes′) # t, a = ti.parameters[i], argtypes′[i] # argtypes′[i] = t ⊑ a ? t : a # end - const_result = abstract_call_method_with_const_args(interp, result, singleton_type(ft′), argtypes′, match, sv, false) + const_result = abstract_call_method_with_const_args(interp, result, singleton_type(ft′), arginfo, match, sv, false) if const_result !== nothing const_rt, const_result = const_result if const_rt !== rt && const_rt ⊑ rt @@ -1235,21 +1261,20 @@ end # call where the function is known exactly function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f), - fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, - sv::InferenceState, + arginfo::ArgInfo, sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS) - + (; fargs, argtypes) = arginfo la = length(argtypes) if isa(f, Builtin) if f === _apply_iterate return abstract_apply(interp, argtypes, sv, max_methods) elseif f === invoke - return abstract_invoke(interp, argtypes, sv) + return abstract_invoke(interp, arginfo, sv) elseif f === modifyfield! return abstract_modifyfield!(interp, argtypes, sv) end - return CallMeta(abstract_call_builtin(interp, f, fargs, argtypes, sv, max_methods), false) + return CallMeta(abstract_call_builtin(interp, f, arginfo, sv, max_methods), false) elseif f === Core.kwfunc if la == 2 ft = widenconst(argtypes[2]) @@ -1282,12 +1307,12 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f), # handle Conditional propagation through !Bool aty = argtypes[2] if isa(aty, Conditional) - call = abstract_call_gf_by_type(interp, f, fargs, Any[Const(f), Bool], Tuple{typeof(f), Bool}, sv) # make sure we've inferred `!(::Bool)` + call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Bool]), Tuple{typeof(f), Bool}, sv) # make sure we've inferred `!(::Bool)` return CallMeta(Conditional(aty.var, aty.elsetype, aty.vtype), call.info) end elseif la == 3 && istopfunction(f, :!==) # mark !== as exactly a negated call to === - rty = abstract_call_known(interp, (===), fargs, argtypes, sv).rt + rty = abstract_call_known(interp, (===), arginfo, sv).rt if isa(rty, Conditional) return CallMeta(Conditional(rty.var, rty.elsetype, rty.vtype), false) # swap if-else elseif isa(rty, Const) @@ -1303,7 +1328,7 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f), fargs = nothing end argtypes = Any[typeof(<:), argtypes[3], argtypes[2]] - return CallMeta(abstract_call_known(interp, <:, fargs, argtypes, sv).rt, false) + return CallMeta(abstract_call_known(interp, <:, ArgInfo(fargs, argtypes), sv).rt, false) elseif la == 2 && (a2 = argtypes[2]; isa(a2, Const)) && (svecval = a2.val; isa(svecval, SimpleVector)) && istopfunction(f, :length) @@ -1326,7 +1351,7 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f), return CallMeta(val === false ? Type : val, MethodResultPure()) end atype = argtypes_to_type(argtypes) - return abstract_call_gf_by_type(interp, f, fargs, argtypes, atype, sv, max_methods) + return abstract_call_gf_by_type(interp, f, arginfo, atype, sv, max_methods) end function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, argtypes::Vector{Any}, sv::InferenceState) @@ -1339,8 +1364,8 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::Part match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt)) info = OpaqueClosureCallInfo(match) if !result.edgecycle - const_result = abstract_call_method_with_const_args(interp, result, closure, argtypes, - match, sv, closure.isva) + const_result = abstract_call_method_with_const_args(interp, result, closure, + ArgInfo(nothing, argtypes), match, sv, closure.isva) if const_result !== nothing const_rettype, const_result = const_result if const_rettype ⊑ rt @@ -1363,9 +1388,9 @@ function most_general_argtypes(closure::PartialOpaque) end # call where the function is any lattice element -function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, +function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS) - #print("call ", e.args[1], argtypes, "\n\n") + argtypes = arginfo.argtypes ft = argtypes[1] f = singleton_type(ft) if isa(ft, PartialOpaque) @@ -1379,9 +1404,9 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{ add_remark!(interp, sv, "Could not identify method table for call") return CallMeta(Any, false) end - return abstract_call_gf_by_type(interp, nothing, fargs, argtypes, argtypes_to_type(argtypes), sv, max_methods) + return abstract_call_gf_by_type(interp, nothing, arginfo, argtypes_to_type(argtypes), sv, max_methods) end - return abstract_call_known(interp, f, fargs, argtypes, sv, max_methods) + return abstract_call_known(interp, f, arginfo, sv, max_methods) end function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool) @@ -1428,7 +1453,7 @@ function abstract_eval_cfunction(interp::AbstractInterpreter, e::Expr, vtypes::V # this may be the wrong world for the call, # but some of the result is likely to be valid anyways # and that may help generate better codegen - abstract_call(interp, nothing, at, sv) + abstract_call(interp, ArgInfo(nothing, at), sv) nothing end @@ -1502,7 +1527,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), if argtypes === nothing t = Bottom else - callinfo = abstract_call(interp, ea, argtypes, sv) + callinfo = abstract_call(interp, ArgInfo(ea, argtypes), sv) sv.stmt_info[sv.currpc] = callinfo.info t = callinfo.rt end diff --git a/base/compiler/inferenceresult.jl b/base/compiler/inferenceresult.jl index a24588ee8f6ab1..978eabc9f02098 100644 --- a/base/compiler/inferenceresult.jl +++ b/base/compiler/inferenceresult.jl @@ -3,20 +3,57 @@ function is_argtype_match(@nospecialize(given_argtype), @nospecialize(cache_argtype), overridden_by_const::Bool) - if isa(given_argtype, Const) || isa(given_argtype, PartialStruct) || isa(given_argtype, PartialOpaque) + if is_forwardable_argtype(given_argtype) return is_lattice_equal(given_argtype, cache_argtype) end return !overridden_by_const end +function is_forwardable_argtype(@nospecialize x) + return isa(x, Const) || + isa(x, Conditional) || + isa(x, PartialStruct) || + isa(x, PartialOpaque) +end + # In theory, there could be a `cache` containing a matching `InferenceResult` # for the provided `linfo` and `given_argtypes`. The purpose of this function is # to return a valid value for `cache_lookup(linfo, argtypes, cache).argtypes`, # so that we can construct cache-correct `InferenceResult`s in the first place. -function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector, va_override::Bool) +function matching_cache_argtypes( + linfo::MethodInstance, (; fargs, argtypes)::ArgInfo, va_override::Bool) @assert isa(linfo.def, Method) # ensure the next line works nargs::Int = linfo.def.nargs - given_argtypes = anymap(widenconditional, given_argtypes) + cache_argtypes, overridden_by_const = matching_cache_argtypes(linfo, nothing, va_override) + given_argtypes = Vector{Any}(undef, length(argtypes)) + local condargs = nothing + for i in 1:length(argtypes) + argtype = argtypes[i] + # forward `Conditional` if it conveys a constraint on any other argument + if isa(argtype, Conditional) && fargs !== nothing + cnd = argtype + slotid = find_constrained_arg(cnd, fargs) + if slotid !== nothing + # using union-split signature, we may be able to narrow down `Conditional` + sigt = widenconst(slotid > nargs ? argtypes[slotid] : cache_argtypes[slotid]) + vtype = tmeet(cnd.vtype, sigt) + elsetype = tmeet(cnd.elsetype, sigt) + if vtype === Bottom && elsetype === Bottom + # we accidentally proved this method match is impossible + # TODO bail out here immediately rather than just propagating Bottom ? + given_argtypes[i] = Bottom + else + if condargs === nothing + condargs = Tuple{Int,Int}[] + end + push!(condargs, (slotid, i)) + given_argtypes[i] = Conditional(SlotNumber(slotid), vtype, elsetype) + end + continue + end + end + given_argtypes[i] = widenconditional(argtype) + end isva = va_override || linfo.def.isva if isva || isvarargtype(given_argtypes[end]) isva_given_argtypes = Vector{Any}(undef, nargs) @@ -30,15 +67,22 @@ function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector, last = nargs end isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[last:end]) + # invalidate `Conditional` imposed on varargs + if condargs !== nothing + for (slotid, i) in condargs + if slotid ≥ last + isva_given_argtypes[i] = widenconditional(isva_given_argtypes[i]) + end + end + end end given_argtypes = isva_given_argtypes end @assert length(given_argtypes) == nargs - cache_argtypes, overridden_by_const = matching_cache_argtypes(linfo, nothing, va_override) for i in 1:nargs given_argtype = given_argtypes[i] cache_argtype = cache_argtypes[i] - if !is_argtype_match(given_argtype, cache_argtype, overridden_by_const[i]) + if !is_argtype_match(given_argtype, cache_argtype, false) # prefer the argtype we were given over the one computed from `linfo` cache_argtypes[i] = given_argtype overridden_by_const[i] = true diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 1d25e9fa3795ae..2fe6366f68b965 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -964,7 +964,7 @@ function abstract_modifyfield!(interp::AbstractInterpreter, argtypes::Vector{Any v = unwrapva(argtypes[5]) TF = getfield_tfunc(o, f) push!(sv.ssavalue_uses[sv.currpc], sv.currpc) # temporarily disable `call_result_unused` check for this call - callinfo = abstract_call(interp, nothing, Any[op, TF, v], sv, #=max_methods=# 1) + callinfo = abstract_call(interp, ArgInfo(nothing, Any[op, TF, v]), sv, #=max_methods=# 1) pop!(sv.ssavalue_uses[sv.currpc], sv.currpc) TF2 = tmeet(callinfo.rt, widenconst(TF)) if TF2 === Bottom @@ -1747,7 +1747,7 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s if contains_is(argtypes_vec, Union{}) return CallMeta(Const(Union{}), false) end - call = abstract_call(interp, nothing, argtypes_vec, sv, -1) + call = abstract_call(interp, ArgInfo(nothing, argtypes_vec), sv, -1) info = verbose_stmt_info(interp) ? ReturnTypeCallInfo(call.info) : false rt = widenconditional(call.rt) if isa(rt, Const) diff --git a/base/compiler/types.jl b/base/compiler/types.jl index f9d88fc2ac2609..728b0d8206ea7d 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -17,6 +17,11 @@ If `interp` is an `AbstractInterpreter`, it is expected to provide at least the """ abstract type AbstractInterpreter end +struct ArgInfo + fargs::Union{Nothing,Vector{Any}} + argtypes::Vector{Any} +end + """ InferenceResult @@ -29,8 +34,10 @@ mutable struct InferenceResult result # ::Type, or InferenceState if WIP src #::Union{CodeInfo, OptimizationState, Nothing} # if inferred copy is available valid_worlds::WorldRange # if inference and optimization is finished - function InferenceResult(linfo::MethodInstance, given_argtypes = nothing, va_override=false) - argtypes, overridden_by_const = matching_cache_argtypes(linfo, given_argtypes, va_override) + function InferenceResult(linfo::MethodInstance, + arginfo::Union{Nothing,ArgInfo} = nothing, + va_override::Bool = false) + argtypes, overridden_by_const = matching_cache_argtypes(linfo, arginfo, va_override) return new(linfo, argtypes, overridden_by_const, Any, nothing, WorldRange()) end end diff --git a/base/compiler/typeutils.jl b/base/compiler/typeutils.jl index 62c146a80d9695..52b4938d26e144 100644 --- a/base/compiler/typeutils.jl +++ b/base/compiler/typeutils.jl @@ -259,15 +259,6 @@ unioncomplexity(u::UnionAll) = max(unioncomplexity(u.body)::Int, unioncomplexity unioncomplexity(t::TypeofVararg) = isdefined(t, :T) ? unioncomplexity(t.T)::Int : 0 unioncomplexity(@nospecialize(x)) = 0 -function improvable_via_constant_propagation(@nospecialize(t)) - if isconcretetype(t) && t <: Tuple - for p in t.parameters - p === DataType && return true - end - end - return false -end - # convert a Union of Tuple types to a Tuple of Unions function unswitchtupleunion(u::Union) ts = uniontypes(u) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 2a7f54f9832e22..72791bfaa0a28b 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -2000,6 +2000,61 @@ function _g_ifelse_isa_() end @test Base.return_types(_g_ifelse_isa_, ()) == [Int] +@testset "Conditional forwarding" begin + # forward `Conditional` if it conveys a constraint on any other argument + ifelselike(cnd, x, y) = cnd ? x : y + + @test Base.return_types((Any,Int,)) do x, y + ifelselike(isa(x, Int), x, y) + end |> only == Int + + # should work nicely with union-split + @test Base.return_types((Union{Int,Nothing},)) do x + ifelselike(isa(x, Int), x, 0) + end |> only == Int + + @test Base.return_types((Any,Int)) do x, y + ifelselike(!isa(x, Int), y, x) + end |> only == Int + + @test Base.return_types((Any,Int)) do x, y + a = ifelselike(x === 0, x, 0) # ::Const(0) + if a == 0 + return y + else + return nothing # dead branch + end + end |> only == Int + + # pick up the first if there are multiple constrained arguments + @test Base.return_types((Any,)) do x + ifelselike(isa(x, Int), x, x) + end |> only == Any + + # just propagate multiple constraints + ifelselike2(cnd1, cnd2, x, y, z) = cnd1 ? x : cnd2 ? y : z + @test Base.return_types((Any,Any)) do x, y + ifelselike2(isa(x, Int), isa(y, Int), x, y, 0) + end |> only == Int + + # work with `invoke` + @test Base.return_types((Any,Any)) do x, y + Base.@invoke ifelselike(isa(x, Int), x, y::Int) + end |> only == Int + + # don't be confused with vararg method + vacond(cnd, va...) = cnd ? va : 0 + @test Base.return_types((Any,)) do x + # at runtime we will see `va::Tuple{Tuple{Int,Int}, Tuple{Int,Int}}` + vacond(isa(x, Tuple{Int,Int}), x, x) + end |> only == Union{Int,Tuple{Any,Any}} + + # demonstrate extra constraint propagation for Base.ifelse + @test Base.return_types((Any,Int,)) do x, y + ifelse(isa(x, Int), x, y) + end |> only == Int +end + # Equivalence of Const(T.instance) and T for singleton types @test Const(nothing) ⊑ Nothing && Nothing ⊑ Const(nothing)