From ce219118a8ba1ed1e84a8757acad80684c97cbcc Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Mon, 25 Oct 2021 01:35:12 +0900 Subject: [PATCH] optimizer: fix #42754, inline union-split const-prop'ed sources MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit complements #39754 and #39305: implements a logic to use constant-prop'ed results for inlining at union-split callsite. Currently it works only for cases when constant-prop' succeeded for all (union-split) signatures. > example ```julia julia> mutable struct X # NOTE in order to confuse `fieldtype_tfunc`, we need to have at least two fields with different types a::Union{Nothing, Int} b::Symbol end; julia> code_typed((X, Union{Nothing,Int})) do x, a # this `setproperty` call would be union-split and constant-prop will happen for # each signature: inlining would fail if we don't use constant-prop'ed source # since the approximated inlining cost of `convert(fieldtype(X, sym), a)` would # end up very high if we don't propagate `sym::Const(:a)` x.a = a x end |> only |> first ``` > before this commit ```julia CodeInfo( 1 ─ %1 = Base.setproperty!::typeof(setproperty!) │ %2 = (isa)(a, Nothing)::Bool └── goto #3 if not %2 2 ─ %4 = π (a, Nothing) │ invoke %1(_2::X, :a::Symbol, %4::Nothing)::Any └── goto #6 3 ─ %7 = (isa)(a, Int64)::Bool └── goto #5 if not %7 4 ─ %9 = π (a, Int64) │ invoke %1(_2::X, :a::Symbol, %9::Int64)::Any └── goto #6 5 ─ Core.throw(ErrorException("fatal error in type inference (type bound)"))::Union{} └── unreachable 6 ┄ return x ) ``` > after this commit ```julia CodeInfo( 1 ─ %1 = (isa)(a, Nothing)::Bool └── goto #3 if not %1 2 ─ Base.setfield!(x, :a, nothing)::Nothing └── goto #6 3 ─ %5 = (isa)(a, Int64)::Bool └── goto #5 if not %5 4 ─ %7 = π (a, Int64) │ Base.setfield!(x, :a, %7)::Int64 └── goto #6 5 ─ Core.throw(ErrorException("fatal error in type inference (type bound)"))::Union{} └── unreachable 6 ┄ return x ) ``` --- base/compiler/ssair/inlining.jl | 147 ++++++++++++++++++++++---------- test/compiler/inline.jl | 77 +++++++++++++++++ 2 files changed, 177 insertions(+), 47 deletions(-) diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 8cece4cf21657f..7c622e50482e5a 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -1147,9 +1147,10 @@ function process_simple!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, sta return sig end +# TODO inline non-`isdispatchtuple`, union-split callsites function analyze_single_call!( ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, @nospecialize(stmt), - sig::Signature, infos::Vector{MethodMatchInfo}, state::InliningState, flag::UInt8) + (; atypes, atype)::Signature, infos::Vector{MethodMatchInfo}, state::InliningState, flag::UInt8) cases = InliningCase[] local signature_union = Bottom local only_method = nothing # keep track of whether there is one matching method @@ -1181,7 +1182,7 @@ function analyze_single_call!( fully_covered = false continue end - item = analyze_method!(match, sig.atypes, state, flag) + item = analyze_method!(match, atypes, state, flag) if item === nothing fully_covered = false continue @@ -1192,25 +1193,25 @@ function analyze_single_call!( end end - signature_fully_covered = sig.atype <: signature_union - # If we're fully covered and there's only one applicable method, - # we inline, even if the signature is not a dispatch tuple - if signature_fully_covered && length(cases) == 0 && only_method isa Method - if length(infos) > 1 - (metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), - sig.atype, only_method.sig)::SimpleVector - match = MethodMatch(metharg, methsp, only_method, true) - else - meth = meth::MethodLookupResult - @assert length(meth) == 1 - match = meth[1] + # if the signature is fully covered and there is only one applicable method, + # we can try to inline it even if the signature is not a dispatch tuple + if atype <: signature_union + if length(cases) == 0 && only_method isa Method + if length(infos) > 1 + (metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), + atype, only_method.sig)::SimpleVector + match = MethodMatch(metharg, methsp, only_method, true) + else + meth = meth::MethodLookupResult + @assert length(meth) == 1 + match = meth[1] + end + item = analyze_method!(match, atypes, state, flag) + item === nothing && return + push!(cases, InliningCase(match.spec_types, item)) + fully_covered = true end - fully_covered = true - item = analyze_method!(match, sig.atypes, state, flag) - item === nothing && return - push!(cases, InliningCase(match.spec_types, item)) - end - if !signature_fully_covered + else fully_covered = false end @@ -1219,36 +1220,81 @@ function analyze_single_call!( # onto the todo list if fully_covered && length(cases) == 1 handle_single_case!(ir, stmt, idx, cases[1].item, false, todo) - return + elseif length(cases) > 0 + push!(todo, idx=>UnionSplit(fully_covered, atype, cases)) end - length(cases) == 0 && return - push!(todo, idx=>UnionSplit(fully_covered, sig.atype, cases)) return nothing end +# try to create `InliningCase`s using constant-prop'ed results +# currently it works only when constant-prop' succeeded for all (union-split) signatures +# TODO use any of constant-prop'ed results, and leave the other unhandled cases to later +# TODO this function contains a lot of duplications with `analyze_single_call!`, factor them out function maybe_handle_const_call!( - ir::IRCode, idx::Int, stmt::Expr, info::ConstCallInfo, sig::Signature, + ir::IRCode, idx::Int, stmt::Expr, (; results)::ConstCallInfo, (; atypes, atype)::Signature, state::InliningState, flag::UInt8, isinvoke::Bool, todo::Vector{Pair{Int, Any}}) - # when multiple matches are found, bail out and later inliner will union-split this signature - # TODO effectively use multiple constant analysis results here - length(info.results) == 1 || return false - result = info.results[1] - isa(result, InferenceResult) || return false - - (; mi) = item = InliningTodo(result, sig.atypes) - validate_sparams(mi.sparam_vals) || return true - state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) - if sig.atype <: mi.def.sig - handle_single_case!(ir, stmt, idx, item, isinvoke, todo) - return true + cases = InliningCase[] # TODO avoid this allocation for single cases ? + local fully_covered = true + local signature_union = Bottom + for result in results + isa(result, InferenceResult) || return false + (; mi) = item = InliningTodo(result, atypes) + spec_types = mi.specTypes + signature_union = Union{signature_union, spec_types} + if !isdispatchtuple(spec_types) + fully_covered = false + continue + end + if !validate_sparams(mi.sparam_vals) + fully_covered = false + continue + end + state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) + if item === nothing + fully_covered = false + continue + end + push!(cases, InliningCase(spec_types, item)) + end + + # if the signature is fully covered and there is only one applicable method, + # we can try to inline it even if the signature is not a dispatch tuple + if atype <: signature_union + if length(cases) == 0 && length(results) == 1 + (; mi) = item = InliningTodo(results[1]::InferenceResult, atypes) + state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) + validate_sparams(mi.sparam_vals) || return true + item === nothing && return true + push!(cases, InliningCase(mi.specTypes, item)) + fully_covered = true + end else - item === nothing && return true - # Union split out the error case - item = UnionSplit(false, sig.atype, InliningCase[InliningCase(mi.specTypes, item)]) + fully_covered = false + end + + # If we only have one case and that case is fully covered, we may either + # be able to do the inlining now (for constant cases), or push it directly + # onto the todo list + if fully_covered && length(cases) == 1 + handle_single_case!(ir, stmt, idx, cases[1].item, isinvoke, todo) + elseif length(cases) > 0 isinvoke && rewrite_invoke_exprargs!(stmt) - push!(todo, idx=>item) - return true + push!(todo, idx=>UnionSplit(fully_covered, atype, cases)) end + return true +end + +function handle_const_opaque_closure_call!( + ir::IRCode, idx::Int, stmt::Expr, (; results)::ConstCallInfo, + (; atypes)::Signature, state::InliningState, flag::UInt8, todo::Vector{Pair{Int, Any}}) + @assert length(results) == 1 + result = results[1]::InferenceResult + item = InliningTodo(result, atypes) + isdispatchtuple(item.mi.specTypes) || return + validate_sparams(item.mi.sparam_vals) || return + state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) + handle_single_case!(ir, stmt, idx, item, false, todo) + return nothing end function assemble_inline_todo!(ir::IRCode, state::InliningState) @@ -1283,18 +1329,25 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) # if inference arrived here with constant-prop'ed result(s), # we can perform a specialized analysis for just this case if isa(info, ConstCallInfo) - if !is_stmt_noinline(flag) && maybe_handle_const_call!( - ir, idx, stmt, info, sig, - state, flag, sig.f === Core.invoke, todo) - continue + if !is_stmt_noinline(flag) + if isa(info.call, OpaqueClosureCallInfo) + handle_const_opaque_closure_call!( + ir, idx, stmt, info, + sig, state, flag, todo) + continue + else + maybe_handle_const_call!( + ir, idx, stmt, info, sig, + state, flag, sig.f === Core.invoke, todo) && continue + end else info = info.call end end if isa(info, OpaqueClosureCallInfo) - result = analyze_method!(info.match, sig.atypes, state, flag) - handle_single_case!(ir, stmt, idx, result, false, todo) + item = analyze_method!(info.match, sig.atypes, state, flag) + handle_single_case!(ir, stmt, idx, item, false, todo) continue end diff --git a/test/compiler/inline.jl b/test/compiler/inline.jl index 6bdb71bf8f292c..a891937c729423 100644 --- a/test/compiler/inline.jl +++ b/test/compiler/inline.jl @@ -680,3 +680,80 @@ let f(x) = (x...,) # the the original apply call is not union-split, but the inserted `iterate` call is. @test code_typed(f, Tuple{Union{Int64, CartesianIndex{1}, CartesianIndex{3}}})[1][2] == Tuple{Int64} end + +# https://github.com/JuliaLang/julia/issues/42754 +# inline union-split constant-prop'ed sources +mutable struct X42754 + # NOTE in order to confuse `fieldtype_tfunc`, we need to have at least two fields with different types + a::Union{Nothing, Int} + b::Symbol +end +let code = code_typed1((X42754, Union{Nothing,Int})) do x, a + # this `setproperty` call would be union-split and constant-prop will happen for + # each signature: inlining would fail if we don't use constant-prop'ed source + # since the approximate inlining cost of `convert(fieldtype(X, sym), a)` would + # end up very high if we don't propagate `sym::Const(:a)` + x.a = a + x + end + @test all(code) do @nospecialize(x) + isinvoke(x, :setproperty!) && return false + if Meta.isexpr(x, :call) + f = x.args[1] + isa(f, GlobalRef) && f.name === :setproperty! && return false + end + return true + end +end + +import Base: @constprop + +# test single, non-dispatchtuple callsite inlining + +@constprop :none @inline test_single_nondispatchtuple(@nospecialize(t)) = + isa(t, DataType) && t.name === Type.body.name +let + code = code_typed1((Any,)) do x + test_single_nondispatchtuple(x) + end + @test all(code) do @nospecialize(x) + isinvoke(x, :test_single_nondispatchtuple) && return false + if Meta.isexpr(x, :call) + f = x.args[1] + isa(f, GlobalRef) && f.name === :test_single_nondispatchtuple && return false + end + return true + end +end + +@constprop :aggressive @inline test_single_nondispatchtuple(c, @nospecialize(t)) = + c && isa(t, DataType) && t.name === Type.body.name +let + code = code_typed1((Any,)) do x + test_single_nondispatchtuple(true, x) + end + @test all(code) do @nospecialize(x) + isinvoke(x, :test_single_nondispatchtuple) && return false + if Meta.isexpr(x, :call) + f = x.args[1] + isa(f, GlobalRef) && f.name === :test_single_nondispatchtuple && return false + end + return true + end +end + +# validate inlining processing + +@constprop :none @inline validate_unionsplit_inlining(@nospecialize(t)) = throw("invalid inlining processing detected") +@constprop :none @noinline validate_unionsplit_inlining(i::Integer) = (println(IOBuffer(), "prevent inlining"); false) +let + invoke(xs) = validate_unionsplit_inlining(xs[1]) + @test invoke(Any[10]) === false +end + +@constprop :aggressive @inline validate_unionsplit_inlining(c, @nospecialize(t)) = c && throw("invalid inlining processing detected") +@constprop :aggressive @noinline validate_unionsplit_inlining(c, i::Integer) = c && (println(IOBuffer(), "prevent inlining"); false) +let + invoke(xs) = validate_unionsplit_inlining(true, xs[1]) + @test invoke(Any[10]) === false +end