diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 1e95bcfb11f44b..b7c822ae1164b3 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -29,7 +29,7 @@ function is_improvable(@nospecialize(rtype)) end function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), - fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, @nospecialize(atype), + (; fargs, argtypes)::ArgInfo, @nospecialize(atype), sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS) if sv.params.unoptimize_throw_blocks && is_stmt_throw_block(get_curr_ssaflag(sv)) add_remark!(interp, sv, "Skipped call in throw block") @@ -85,7 +85,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), push!(edges, edge) end this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i] - const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false) + arginfo = ArgInfo(fargs, this_argtypes) + const_result = abstract_call_method_with_const_args(interp, result, f, arginfo, match, sv, false) if const_result !== nothing const_rt, const_result = const_result if const_rt !== rt && const_rt ⊑ rt @@ -110,7 +111,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), # try constant propagation with argtypes for this match # this is in preparation for inlining, or improving the return result this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i] - const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false) + arginfo = ArgInfo(fargs, this_argtypes) + const_result = abstract_call_method_with_const_args(interp, result, f, arginfo, match, sv, false) if const_result !== nothing const_this_rt, const_result = const_result if const_this_rt !== this_rt && const_this_rt ⊑ this_rt @@ -523,13 +525,13 @@ struct MethodCallResult end function abstract_call_method_with_const_args(interp::AbstractInterpreter, result::MethodCallResult, - @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, + @nospecialize(f), arginfo::ArgInfo, match::MethodMatch, sv::InferenceState, va_override::Bool) - mi = maybe_get_const_prop_profitable(interp, result, f, argtypes, match, sv) + mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv) mi === nothing && return nothing # try constant prop' inf_cache = get_inference_cache(interp) - inf_result = cache_lookup(mi, argtypes, inf_cache) + inf_result = cache_lookup(mi, arginfo.argtypes, inf_cache) if inf_result === nothing # if there might be a cycle, check to make sure we don't end up # calling ourselves here. @@ -545,7 +547,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul return nothing end end - inf_result = InferenceResult(mi, argtypes, va_override) + inf_result = InferenceResult(mi; arginfo, va_override) if !any(inf_result.overridden_by_const) add_remark!(interp, sv, "[constprop] Could not handle constant info in matching_cache_argtypes") return nothing @@ -565,7 +567,7 @@ end # if there's a possibility we could get a better result (hopefully without doing too much work) # returns `MethodInstance` with constant arguments, returns nothing otherwise function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::MethodCallResult, - @nospecialize(f), argtypes::Vector{Any}, match::MethodMatch, + @nospecialize(f), arginfo::ArgInfo, match::MethodMatch, sv::InferenceState) if !InferenceParams(interp).ipo_constant_propagation add_remark!(interp, sv, "[constprop] Disabled by parameter") @@ -580,14 +582,14 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me force || const_prop_entry_heuristic(interp, result, sv) || return nothing nargs::Int = method.nargs method.isva && (nargs -= 1) - length(argtypes) < nargs && return nothing - if !(const_prop_argument_heuristic(interp, argtypes) || const_prop_rettype_heuristic(interp, result.rt)) + length(arginfo.argtypes) < nargs && return nothing + if !(const_prop_argument_heuristic(interp, arginfo) || const_prop_rettype_heuristic(interp, result.rt)) add_remark!(interp, sv, "[constprop] Disabled by argument and rettype heuristics") return nothing end - allconst = is_allconst(argtypes) + allconst = is_allconst(arginfo) if !force - if !const_prop_function_heuristic(interp, f, argtypes, nargs, allconst) + if !const_prop_function_heuristic(interp, f, arginfo, nargs, allconst) add_remark!(interp, sv, "[constprop] Disabled by function heuristic") return nothing end @@ -599,7 +601,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me return nothing end mi = mi::MethodInstance - if !force && !const_prop_methodinstance_heuristic(interp, match, mi, argtypes, sv) + if !force && !const_prop_methodinstance_heuristic(interp, match, mi, arginfo, sv) add_remark!(interp, sv, "[constprop] Disabled by method instance heuristic") return nothing end @@ -617,8 +619,11 @@ function const_prop_entry_heuristic(interp::AbstractInterpreter, result::MethodC end # see if propagating constants may be worthwhile -function const_prop_argument_heuristic(interp::AbstractInterpreter, argtypes::Vector{Any}) +function const_prop_argument_heuristic(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo) for a in argtypes + if isa(a, Conditional) && fargs !== nothing + return is_const_prop_profitable_conditional(a, fargs) + end a = widenconditional(a) if has_nontrivial_const_info(a) && is_const_prop_profitable_arg(a) return true @@ -642,13 +647,34 @@ function is_const_prop_profitable_arg(@nospecialize(arg)) return isa(val, Symbol) || isa(val, Type) || (!isa(val, String) && !ismutable(val)) end +function is_const_prop_profitable_conditional(cnd::Conditional, fargs::Vector{Any}) + slotid = find_constrained_arg(cnd, fargs) + if slotid !== nothing + return true + end + return is_const_prop_profitable_arg(widenconditional(cnd)) +end + +function find_constrained_arg(cnd::Conditional, fargs::Vector{Any}) + slot = cnd.var + return findfirst(fargs) do @nospecialize(x) + x === slot + end +end + function const_prop_rettype_heuristic(interp::AbstractInterpreter, @nospecialize(rettype)) return improvable_via_constant_propagation(rettype) end -function is_allconst(argtypes::Vector{Any}) +function is_allconst((; fargs, argtypes)::ArgInfo) for a in argtypes + if isa(a, Conditional) && fargs !== nothing + if is_const_prop_profitable_conditional(a, fargs) + continue + end + end a = widenconditional(a) + # TODO unify these condition with `has_nontrivial_const_info` if !isa(a, Const) && !isconstType(a) && !isa(a, PartialStruct) && !isa(a, PartialOpaque) return false end @@ -663,7 +689,9 @@ function force_const_prop(interp::AbstractInterpreter, @nospecialize(f), method: istopfunction(f, :setproperty!) end -function const_prop_function_heuristic(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, nargs::Int, allconst::Bool) +function const_prop_function_heuristic( + interp::AbstractInterpreter, @nospecialize(f), (; argtypes)::ArgInfo, + nargs::Int, allconst::Bool) if nargs > 1 if istopfunction(f, :getindex) || istopfunction(f, :setindex!) arrty = argtypes[2] @@ -705,7 +733,7 @@ end # result anyway. function const_prop_methodinstance_heuristic( interp::AbstractInterpreter, match::MethodMatch, mi::MethodInstance, - argtypes::Vector{Any}, sv::InferenceState) + (; argtypes)::ArgInfo, sv::InferenceState) method = match.method if method.is_for_opaque_closure # Not inlining an opaque closure can be very expensive, so be generous @@ -835,7 +863,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n return Any[Vararg{Any}], nothing end @assert !isvarargtype(itertype) - call = abstract_call_known(interp, iteratef, nothing, Any[itft, itertype], sv) + call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[itft, itertype]), sv) stateordonet = call.rt info = call.info # Return Bottom if this is not an iterator. @@ -869,7 +897,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n valtype = getfield_tfunc(stateordonet, Const(1)) push!(ret, valtype) statetype = nstatetype - call = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv) + call = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), sv) stateordonet = call.rt stateordonet_widened = widenconst(stateordonet) push!(calls, call) @@ -904,7 +932,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n end valtype = tmerge(valtype, nounion.parameters[1]) statetype = tmerge(statetype, nounion.parameters[2]) - stateordonet = abstract_call_known(interp, iteratef, nothing, Any[Const(iteratef), itertype, statetype], sv).rt + stateordonet = abstract_call_known(interp, iteratef, ArgInfo(nothing, Any[Const(iteratef), itertype, statetype]), sv).rt stateordonet_widened = widenconst(stateordonet) end if valtype !== Union{} @@ -993,7 +1021,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, sv:: break end end - call = abstract_call(interp, nothing, ct, sv, max_methods) + call = abstract_call(interp, ArgInfo(nothing, ct), sv, max_methods) push!(retinfos, ApplyCallInfo(call.info, arginfo)) res = tmerge(res, call.rt) if bail_out_apply(interp, res, sv) @@ -1057,8 +1085,8 @@ function argtype_tail(argtypes::Vector{Any}, i::Int) return argtypes[i:n] end -function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::Union{Nothing,Vector{Any}}, - argtypes::Vector{Any}, sv::InferenceState, max_methods::Int) +function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs, argtypes)::ArgInfo, + sv::InferenceState, max_methods::Int) @nospecialize f la = length(argtypes) if f === ifelse && fargs isa Vector{Any} && la == 4 @@ -1190,7 +1218,7 @@ function abstract_call_unionall(argtypes::Vector{Any}) return Any end -function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState) +function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, sv::InferenceState) ft′ = argtype_by_index(argtypes, 2) ft = widenconst(ft′) ft === Bottom && return CallMeta(Bottom, false) @@ -1218,14 +1246,17 @@ function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv: # since some checks within `abstract_call_method_with_const_args` seem a bit costly const_prop_entry_heuristic(interp, result, sv) || return CallMeta(rt, InvokeCallInfo(match, nothing)) argtypes′ = argtypes[4:end] - const_prop_argument_heuristic(interp, argtypes′) || const_prop_rettype_heuristic(interp, rt) || return CallMeta(rt, InvokeCallInfo(match, nothing)) pushfirst!(argtypes′, ft) + fargs′ = fargs[4:end] + pushfirst!(fargs′, fargs[1]) + arginfo = ArgInfo(fargs′, argtypes′) + const_prop_argument_heuristic(interp, arginfo) || const_prop_rettype_heuristic(interp, rt) || return CallMeta(rt, InvokeCallInfo(match, nothing)) # # typeintersect might have narrowed signature, but the accuracy gain doesn't seem worth the cost involved with the lattice comparisons # for i in 1:length(argtypes′) # t, a = ti.parameters[i], argtypes′[i] # argtypes′[i] = t ⊑ a ? t : a # end - const_result = abstract_call_method_with_const_args(interp, result, singleton_type(ft′), argtypes′, match, sv, false) + const_result = abstract_call_method_with_const_args(interp, result, singleton_type(ft′), arginfo, match, sv, false) if const_result !== nothing const_rt, const_result = const_result if const_rt !== rt && const_rt ⊑ rt @@ -1237,21 +1268,20 @@ end # call where the function is known exactly function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f), - fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, - sv::InferenceState, + arginfo::ArgInfo, sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS) - + (; fargs, argtypes) = arginfo la = length(argtypes) if isa(f, Builtin) if f === _apply_iterate return abstract_apply(interp, argtypes, sv, max_methods) elseif f === invoke - return abstract_invoke(interp, argtypes, sv) + return abstract_invoke(interp, arginfo, sv) elseif f === modifyfield! return abstract_modifyfield!(interp, argtypes, sv) end - return CallMeta(abstract_call_builtin(interp, f, fargs, argtypes, sv, max_methods), false) + return CallMeta(abstract_call_builtin(interp, f, arginfo, sv, max_methods), false) elseif f === Core.kwfunc if la == 2 ft = widenconst(argtypes[2]) @@ -1284,12 +1314,12 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f), # handle Conditional propagation through !Bool aty = argtypes[2] if isa(aty, Conditional) - call = abstract_call_gf_by_type(interp, f, fargs, Any[Const(f), Bool], Tuple{typeof(f), Bool}, sv) # make sure we've inferred `!(::Bool)` + call = abstract_call_gf_by_type(interp, f, ArgInfo(fargs, Any[Const(f), Bool]), Tuple{typeof(f), Bool}, sv) # make sure we've inferred `!(::Bool)` return CallMeta(Conditional(aty.var, aty.elsetype, aty.vtype), call.info) end elseif la == 3 && istopfunction(f, :!==) # mark !== as exactly a negated call to === - rty = abstract_call_known(interp, (===), fargs, argtypes, sv).rt + rty = abstract_call_known(interp, (===), arginfo, sv).rt if isa(rty, Conditional) return CallMeta(Conditional(rty.var, rty.elsetype, rty.vtype), false) # swap if-else elseif isa(rty, Const) @@ -1305,7 +1335,7 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f), fargs = nothing end argtypes = Any[typeof(<:), argtypes[3], argtypes[2]] - return CallMeta(abstract_call_known(interp, <:, fargs, argtypes, sv).rt, false) + return CallMeta(abstract_call_known(interp, <:, ArgInfo(fargs, argtypes), sv).rt, false) elseif la == 2 && (a2 = argtypes[2]; isa(a2, Const)) && (svecval = a2.val; isa(svecval, SimpleVector)) && istopfunction(f, :length) @@ -1328,7 +1358,7 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f), return CallMeta(val === false ? Type : val, MethodResultPure()) end atype = argtypes_to_type(argtypes) - return abstract_call_gf_by_type(interp, f, fargs, argtypes, atype, sv, max_methods) + return abstract_call_gf_by_type(interp, f, arginfo, atype, sv, max_methods) end function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, argtypes::Vector{Any}, sv::InferenceState) @@ -1341,8 +1371,8 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::Part match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt)) info = OpaqueClosureCallInfo(match) if !result.edgecycle - const_result = abstract_call_method_with_const_args(interp, result, closure, argtypes, - match, sv, closure.isva) + const_result = abstract_call_method_with_const_args(interp, result, closure, + ArgInfo(nothing, argtypes), match, sv, closure.isva) if const_result !== nothing const_rettype, const_result = const_result if const_rettype ⊑ rt @@ -1365,9 +1395,9 @@ function most_general_argtypes(closure::PartialOpaque) end # call where the function is any lattice element -function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, +function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS) - #print("call ", e.args[1], argtypes, "\n\n") + argtypes = arginfo.argtypes ft = argtypes[1] f = singleton_type(ft) if isa(ft, PartialOpaque) @@ -1381,9 +1411,9 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{ add_remark!(interp, sv, "Could not identify method table for call") return CallMeta(Any, false) end - return abstract_call_gf_by_type(interp, nothing, fargs, argtypes, argtypes_to_type(argtypes), sv, max_methods) + return abstract_call_gf_by_type(interp, nothing, arginfo, argtypes_to_type(argtypes), sv, max_methods) end - return abstract_call_known(interp, f, fargs, argtypes, sv, max_methods) + return abstract_call_known(interp, f, arginfo, sv, max_methods) end function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool) @@ -1433,7 +1463,7 @@ function abstract_eval_cfunction(interp::AbstractInterpreter, e::Expr, vtypes::V # this may be the wrong world for the call, # but some of the result is likely to be valid anyways # and that may help generate better codegen - abstract_call(interp, nothing, at, sv) + abstract_call(interp, ArgInfo(nothing, at), sv) nothing end @@ -1507,7 +1537,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), if argtypes === nothing t = Bottom else - callinfo = abstract_call(interp, ea, argtypes, sv) + callinfo = abstract_call(interp, ArgInfo(ea, argtypes), sv) sv.stmt_info[sv.currpc] = callinfo.info t = callinfo.rt end diff --git a/base/compiler/inferenceresult.jl b/base/compiler/inferenceresult.jl index cbeeb84b464bb6..8f934c5133a6ff 100644 --- a/base/compiler/inferenceresult.jl +++ b/base/compiler/inferenceresult.jl @@ -3,7 +3,8 @@ function is_argtype_match(@nospecialize(given_argtype), @nospecialize(cache_argtype), overridden_by_const::Bool) - if isa(given_argtype, Const) || isa(given_argtype, PartialStruct) || isa(given_argtype, PartialOpaque) + if isa(given_argtype, Const) || isa(given_argtype, PartialStruct) || + isa(given_argtype, PartialOpaque) || isa(given_argtype, Conditional) return is_lattice_equal(given_argtype, cache_argtype) end return !overridden_by_const @@ -13,10 +14,28 @@ end # for the provided `linfo` and `given_argtypes`. The purpose of this function is # to return a valid value for `cache_lookup(linfo, argtypes, cache).argtypes`, # so that we can construct cache-correct `InferenceResult`s in the first place. -function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector, va_override::Bool) +function matching_cache_argtypes( + linfo::MethodInstance, (; fargs, argtypes)::ArgInfo, va_override::Bool) @assert isa(linfo.def, Method) # ensure the next line works nargs::Int = linfo.def.nargs - given_argtypes = anymap(widenconditional, given_argtypes) + given_argtypes = Vector{Any}(undef, length(argtypes)) + local condargs = nothing + for i in 1:length(argtypes) + argtype = argtypes[i] + # forward `Conditional` if it conveys a constraint on any other argument + if isa(argtype, Conditional) && fargs !== nothing + slotid = find_constrained_arg(argtype, fargs) + if slotid !== nothing + if condargs === nothing + condargs = Tuple{Int,Int}[] + end + push!(condargs, (slotid, i)) + given_argtypes[i] = Conditional(SlotNumber(slotid), argtype.vtype, argtype.elsetype) + continue + end + end + given_argtypes[i] = widenconditional(argtype) + end isva = va_override || linfo.def.isva if isva || isvarargtype(given_argtypes[end]) isva_given_argtypes = Vector{Any}(undef, nargs) @@ -30,6 +49,14 @@ function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector, last = nargs end isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[last:end]) + # invalidate `Conditional` imposed on varargs + if condargs !== nothing + for (slotid, i) in condargs + if slotid ≥ last + isva_given_argtypes[i] = widenconditional(isva_given_argtypes[i]) + end + end + end end given_argtypes = isva_given_argtypes end diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 67294503fed7a6..9cdf708952c289 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -962,7 +962,7 @@ function abstract_modifyfield!(interp::AbstractInterpreter, argtypes::Vector{Any v = unwrapva(argtypes[5]) TF = getfield_tfunc(o, f) push!(sv.ssavalue_uses[sv.currpc], sv.currpc) # temporarily disable `call_result_unused` check for this call - callinfo = abstract_call(interp, nothing, Any[op, TF, v], sv, #=max_methods=# 1) + callinfo = abstract_call(interp, ArgInfo(nothing, Any[op, TF, v]), sv, #=max_methods=# 1) pop!(sv.ssavalue_uses[sv.currpc], sv.currpc) TF2 = tmeet(callinfo.rt, widenconst(TF)) if TF2 === Bottom @@ -1744,7 +1744,7 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s if contains_is(argtypes_vec, Union{}) return CallMeta(Const(Union{}), false) end - call = abstract_call(interp, nothing, argtypes_vec, sv, -1) + call = abstract_call(interp, ArgInfo(nothing, argtypes_vec), sv, -1) info = verbose_stmt_info(interp) ? ReturnTypeCallInfo(call.info) : false rt = widenconditional(call.rt) if isa(rt, Const) diff --git a/base/compiler/types.jl b/base/compiler/types.jl index f9d88fc2ac2609..5102ef132badae 100644 --- a/base/compiler/types.jl +++ b/base/compiler/types.jl @@ -17,6 +17,11 @@ If `interp` is an `AbstractInterpreter`, it is expected to provide at least the """ abstract type AbstractInterpreter end +struct ArgInfo + fargs::Union{Nothing,Vector{Any}} + argtypes::Vector{Any} +end + """ InferenceResult @@ -29,8 +34,10 @@ mutable struct InferenceResult result # ::Type, or InferenceState if WIP src #::Union{CodeInfo, OptimizationState, Nothing} # if inferred copy is available valid_worlds::WorldRange # if inference and optimization is finished - function InferenceResult(linfo::MethodInstance, given_argtypes = nothing, va_override=false) - argtypes, overridden_by_const = matching_cache_argtypes(linfo, given_argtypes, va_override) + function InferenceResult(linfo::MethodInstance; + arginfo::Union{Nothing,ArgInfo} = nothing, + va_override::Bool = false) + argtypes, overridden_by_const = matching_cache_argtypes(linfo, arginfo, va_override) return new(linfo, argtypes, overridden_by_const, Any, nothing, WorldRange()) end end diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index ba42785cb1d4fd..87c0e4e773c2bc 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -2000,6 +2000,55 @@ function _g_ifelse_isa_() end @test Base.return_types(_g_ifelse_isa_, ()) == [Int] +@testset "Conditional forwarding" begin + # forward `Conditional` if it conveys a constraint on any other argument + ifelselike(cnd, x, y) = cnd ? x : y + @test Base.return_types((Any,Int,)) do x, y + ifelselike(isa(x, Int), x, y) + end |> only == Int + + # demonstrate extra constraint propagation for Base.ifelse + @test Base.return_types((Any,Int,)) do x, y + ifelse(isa(x, Int), x, y) + end |> only == Int + + @test Base.return_types((Any,Int)) do x, y + ifelse(!isa(x, Int), y, x) + end |> only == Int + + @test Base.return_types((Any,Int)) do x, y + a = ifelse(x === 0, x, 0) # ::Const(0) + if a == 0 + return y + else + return nothing # dead branch + end + end |> only == Int + + # pick up the first if there are multiple constrained arguments + @test Base.return_types((Any,)) do x + ifelse(isa(x, Int), x, x) + end |> only == Any + + # just propagate multiple constraints + ifelselike2(cnd1, cnd2, x, y, z) = cnd1 ? x : cnd2 ? y : z + @test Base.return_types((Any,Any)) do x, y + ifelselike2(isa(x, Int), isa(y, Int), x, y, 0) + end |> only == Int + + # work with `invoke` + @test Base.return_types((Any,Any)) do x, y + Base.@invoke ifelselike(isa(x, Int), x, y::Int) + end |> only == Int + + # don't be confused with vararg method + vacond(cnd, va...) = cnd ? va : 0 + @test Base.return_types((Any,)) do x + # at runtime we will see `va::Tuple{Tuple{Int,Int}, Tuple{Int,Int}}` + vacond(isa(x, Tuple{Int,Int}), x, x) + end |> only == Union{Int,Tuple{Any,Any}} +end + # Equivalence of Const(T.instance) and T for singleton types @test Const(nothing) ⊑ Nothing && Nothing ⊑ Const(nothing)