From 408c140746133f2c22b154aa813968452f55d231 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Wed, 27 Oct 2021 02:17:26 +0900 Subject: [PATCH] optimizer: inline abstract union-split callsite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently the optimizer handles abstract callsite only when there is a single dispatch candidate (in most cases), and so inlining and static-dispatch are prohibited when the callsite is union-split (in other word, union-split happens only when all the dispatch candidates are concrete). However, there are certain patterns of code (most notably our Julia-level compiler code) that inherently need to deal with abstract callsite. The following example is taken from `Core.Compiler` utility: ```julia julia> @inline isType(@nospecialize t) = isa(t, DataType) && t.name === Type.body.name isType (generic function with 1 method) julia> code_typed((Any,)) do x # abstract, but no union-split, successful inlining isType(x) end |> only CodeInfo( 1 ─ %1 = (x isa Main.DataType)::Bool └── goto #3 if not %1 2 ─ %3 = π (x, DataType) │ %4 = Base.getfield(%3, :name)::Core.TypeName │ %5 = Base.getfield(Type{T}, :name)::Core.TypeName │ %6 = (%4 === %5)::Bool └── goto #4 3 ─ goto #4 4 ┄ %9 = φ (#2 => %6, #3 => false)::Bool └── return %9 ) => Bool julia> code_typed((Union{Type,Nothing},)) do x # abstract, union-split, unsuccessful inlining isType(x) end |> only CodeInfo( 1 ─ %1 = (isa)(x, Nothing)::Bool └── goto #3 if not %1 2 ─ goto #4 3 ─ %4 = Main.isType(x)::Bool └── goto #4 4 ┄ %6 = φ (#2 => false, #3 => %4)::Bool └── return %6 ) => Bool ``` (note that this is a limitation of the inlining algorithm, and so any user-provided hints like callsite inlining annotation doesn't help here) This commit enables inlining and static dispatch for abstract union-split callsite. The core idea here is that we can simulate our dispatch semantics by generating `isa` checks in order of the specialities of dispatch candidates: ```julia julia> code_typed((Union{Type,Nothing},)) do x # union-split, unsuccessful inlining isType(x) end |> only CodeInfo( 1 ─ %1 = (isa)(x, Nothing)::Bool └── goto #3 if not %1 2 ─ goto #9 3 ─ %4 = (isa)(x, Type)::Bool └── goto #8 if not %4 4 ─ %6 = π (x, Type) │ %7 = (%6 isa Main.DataType)::Bool └── goto #6 if not %7 5 ─ %9 = π (%6, DataType) │ %10 = Base.getfield(%9, :name)::Core.TypeName │ %11 = Base.getfield(Type{T}, :name)::Core.TypeName │ %12 = (%10 === %11)::Bool └── goto #7 6 ─ goto #7 7 ┄ %15 = φ (#5 => %12, #6 => false)::Bool └── goto #9 8 ─ Core.throw(ErrorException("fatal error in type inference (type bound)"))::Union{} └── unreachable 9 ┄ %19 = φ (#2 => false, #7 => %15)::Bool └── return %19 ) => Bool ``` Inlining/static-dispatch of abstract union-split callsite will improve the performance in such situations (and so this commit will improve the latency of our JIT compilation). Especially, this commit helps us avoid excessive specializations of `Core.Compiler` code by statically-resolving `@nospecialize`d callsites, and as the result, the # of precompiled statements is now reduced from `1956` ([`master`](dc45d776a900ef17581a842952c51297065afa3a)) to `1901` (this commit). And also, as a side effect, the implementation of our inlining algorithm gets much simplified now since we no longer need the previous special handlings for abstract callsites. One possible drawback would be increased code size. This change seems to certainly increase the size of sysimage, but I think these numbers are in an acceptable range: > [`master`](dc45d776a900ef17581a842952c51297065afa3a) ``` ❯ du -sh usr/lib/julia/* 17M usr/lib/julia/corecompiler.ji 188M usr/lib/julia/sys-o.a 164M usr/lib/julia/sys.dylib 23M usr/lib/julia/sys.dylib.dSYM 101M usr/lib/julia/sys.ji ``` > this commit ``` ❯ du -sh usr/lib/julia/* 17M usr/lib/julia/corecompiler.ji 190M usr/lib/julia/sys-o.a 166M usr/lib/julia/sys.dylib 23M usr/lib/julia/sys.dylib.dSYM 102M usr/lib/julia/sys.ji ``` --- base/compiler/ssair/inlining.jl | 179 +++++++++++++++++--------------- base/sort.jl | 2 +- test/compiler/inline.jl | 105 ++++++++++++++++++- 3 files changed, 196 insertions(+), 90 deletions(-) diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index c06ddfbbcca8d..73d373f15b9c3 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -241,7 +241,7 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int, push!(from_bbs, length(state.new_cfg_blocks)) # TODO: Right now we unconditionally generate a fallback block # in case of subtyping errors - This is probably unnecessary. - if i != length(cases) || (!fully_covered || (!params.trust_inference && isdispatchtuple(cases[i].sig))) + if i != length(cases) || (!fully_covered || (!params.trust_inference)) # This block will have the next condition or the final else case push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx))) push!(state.new_cfg_blocks[cond_bb].succs, length(state.new_cfg_blocks)) @@ -313,7 +313,6 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector spec = item.spec::ResolvedInliningSpec sparam_vals = item.mi.sparam_vals def = item.mi.def::Method - inline_cfg = spec.ir.cfg linetable_offset::Int32 = length(linetable) # Append the linetable of the inlined function to our line table inlined_at = Int(compact.result[idx][:line]) @@ -459,6 +458,66 @@ end const FATAL_TYPE_BOUND_ERROR = ErrorException("fatal error in type inference (type bound)") +""" + ir_inline_unionsplit! + +The core idea of this function is to simulate the dispatch semantics by generating +(flat) `isa`-checks corresponding to the signatures of union-split dispatch candidates, +and then inline their bodies into each `isa`-conditional block. +This `isa`-based virtual dispatch requires few pre-conditions to hold in order to simulate +the actual semantics correctly. + +The first one is that these dispatch candidates need to be processed in order of their specificity, +and the corresponding `isa`-checks should reflect the method specificities, since now their +signatures are not necessarily concrete. +For example, given the following definitions: + + f(x::Int) = ... + f(x::Number) = ... + f(x::Any) = ... + +and a callsite: + + f(x::Any) + +then a correct `isa`-based virtual dispatch would be: + + if isa(x, Int) + [inlined/resolved f(x::Int)] + elseif isa(x, Number) + [inlined/resolved f(x::Number)] + else # implies `isa(x, Any)`, which fully covers this call signature, + # otherwise we need to insert a fallback dynamic dispatch case also + [inlined/resolved f(x::Any)] + end + +Fortunately, `ml_matches` should already sorted them in that way, except cases when there is +any ambiguity, from which we already bail out at this point. + +Another consideration is type equality constraint from type variables: the `isa`-checks are +not enough to simulate the dispatch semantics in cases like: +Given a definition: + + g(x::T, y::T) where T<:Integer = ... + +transform a callsite: + + g(x::Any, y::Any) + +into the optimized form: + + if isa(x, Integer) && isa(y, Integer) + [inlined/resolved g(x::Integer, y::Integer)] + else + g(x, y) # fallback dynamic dispatch + end + +But again, we should already bail out from such cases at this point, essentially by +excluding cases where `case.sig::UnionAll`. + +In short, here we can process the dispatch candidates in order, assuming we haven't changed +their order somehow somewhere up to this point. +""" function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::Vector{Any}, linetable::Vector{LineInfoNode}, (; fully_covered, atype, cases, bbs)::UnionSplit, @@ -468,8 +527,9 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, join_bb = bbs[end] pn = PhiNode() local bb = compact.active_result_bb - @assert length(bbs) >= length(cases) - for i in 1:length(cases) + ncases = length(cases) + @assert length(bbs) >= ncases + for i = 1:ncases ithcase = cases[i] mtype = ithcase.sig::DataType # checked within `handle_cases!` case = ithcase.item @@ -477,8 +537,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, cond = true nparams = fieldcount(atype) @assert nparams == fieldcount(mtype) - if i != length(cases) || !fully_covered || - (!params.trust_inference && isdispatchtuple(cases[i].sig)) + if i != ncases || !fully_covered || !params.trust_inference for i = 1:nparams a, m = fieldtype(atype, i), fieldtype(mtype, i) # If this is always true, we don't need to check for it @@ -535,7 +594,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, bb += 1 # We're now in the fall through block, decide what to do if fully_covered - if !params.trust_inference && isdispatchtuple(cases[end].sig) + if !params.trust_inference e = Expr(:call, GlobalRef(Core, :throw), FATAL_TYPE_BOUND_ERROR) insert_node_here!(compact, NewInstruction(e, Union{}, line)) insert_node_here!(compact, NewInstruction(ReturnNode(), Union{}, line)) @@ -558,7 +617,7 @@ function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, linetable::Vect state = CFGInliningState(ir) for (idx, item) in todo if isa(item, UnionSplit) - cfg_inline_unionsplit!(ir, idx, item::UnionSplit, state, params) + cfg_inline_unionsplit!(ir, idx, item, state, params) else item = item::InliningTodo spec = item.spec::ResolvedInliningSpec @@ -1172,12 +1231,8 @@ function analyze_single_call!( sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}}) argtypes = sig.argtypes cases = InliningCase[] - local only_method = nothing # keep track of whether there is one matching method - local meth::MethodLookupResult + local any_fully_covered = false local handled_all_cases = true - local any_covers_full = false - local revisit_idx = nothing - for i in 1:length(infos) meth = infos[i].results if meth.ambig @@ -1188,66 +1243,20 @@ function analyze_single_call!( # No applicable methods; try next union split handled_all_cases = false continue - else - if length(meth) == 1 && only_method !== false - if only_method === nothing - only_method = meth[1].method - elseif only_method !== meth[1].method - only_method = false - end - else - only_method = false - end end - for (j, match) in enumerate(meth) - any_covers_full |= match.fully_covers - if !isdispatchtuple(match.spec_types) - if !match.fully_covers - handled_all_cases = false - continue - end - if revisit_idx === nothing - revisit_idx = (i, j) - else - handled_all_cases = false - revisit_idx = nothing - end - else - handled_all_cases &= handle_match!(match, argtypes, flag, state, cases) - end + for match in meth + handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true) + any_fully_covered |= match.fully_covers end end - atype = argtypes_to_type(argtypes) - if handled_all_cases && revisit_idx !== nothing - # If there's only one case that's not a dispatchtuple, we can - # still unionsplit by visiting all the other cases first. - # This is useful for code like: - # foo(x::Int) = 1 - # foo(@nospecialize(x::Any)) = 2 - # where we where only a small number of specific dispatchable - # cases are split off from an ::Any typed fallback. - (i, j) = revisit_idx - match = infos[i].results[j] - handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true) - elseif length(cases) == 0 && only_method isa Method - # if the signature is fully covered and there is only one applicable method, - # we can try to inline it even if the signature is not a dispatch tuple. - # -- But don't try it if we already tried to handle the match in the revisit_idx - # case, because that'll (necessarily) be the same method. - if length(infos) > 1 - (metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), - atype, only_method.sig)::SimpleVector - match = MethodMatch(metharg, methsp::SimpleVector, only_method, true) - else - @assert length(meth) == 1 - match = meth[1] - end - handle_match!(match, argtypes, flag, state, cases, true) || return nothing - any_covers_full = handled_all_cases = match.fully_covers + if !handled_all_cases + # if we've not seen all candidates, union split is valid only for dispatch tuples + filter!(case::InliningCase->isdispatchtuple(case.sig), cases) end - handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params) + handle_cases!(ir, idx, stmt, argtypes_to_type(argtypes), cases, + handled_all_cases & any_fully_covered, todo, state.params) end # similar to `analyze_single_call!`, but with constant results @@ -1258,8 +1267,8 @@ function handle_const_call!( (; call, results) = cinfo infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches cases = InliningCase[] + local any_fully_covered = false local handled_all_cases = true - local any_covers_full = false local j = 0 for i in 1:length(infos) meth = infos[i].results @@ -1275,32 +1284,26 @@ function handle_const_call!( for match in meth j += 1 result = results[j] - any_covers_full |= match.fully_covers + any_fully_covered |= match.fully_covers if isa(result, ConstResult) case = const_result_item(result, state) push!(cases, InliningCase(result.mi.specTypes, case)) elseif isa(result, InferenceResult) - handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases) + handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases, true) else @assert result === nothing - handled_all_cases &= handle_match!(match, argtypes, flag, state, cases) + handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true) end end end - # if the signature is fully covered and there is only one applicable method, - # we can try to inline it even if the signature is not a dispatch tuple - atype = argtypes_to_type(argtypes) - if length(cases) == 0 - length(results) == 1 || return nothing - result = results[1] - isa(result, InferenceResult) || return nothing - handle_inf_result!(result, argtypes, flag, state, cases, true) || return nothing - spec_types = cases[1].sig - any_covers_full = handled_all_cases = atype <: spec_types + if !handled_all_cases + # if we've not seen all candidates, union split is valid only for dispatch tuples + filter!(case::InliningCase->isdispatchtuple(case.sig), cases) end - handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params) + handle_cases!(ir, idx, stmt, argtypes_to_type(argtypes), cases, + handled_all_cases & any_fully_covered, todo, state.params) end function handle_match!( @@ -1308,9 +1311,12 @@ function handle_match!( cases::Vector{InliningCase}, allow_abstract::Bool = false) spec_types = match.spec_types allow_abstract || isdispatchtuple(spec_types) || return false + # we may see duplicated dispatch signatures here when a signature gets widened + # during abstract interpretation: for the purpose of inlining, we can just skip + # processing this dispatch candidate + _any(case->case.sig === spec_types, cases) && return true item = analyze_method!(match, argtypes, flag, state) item === nothing && return false - _any(case->case.sig === spec_types, cases) && return true push!(cases, InliningCase(spec_types, item)) return true end @@ -1346,7 +1352,9 @@ function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, @nospecialize(atype), handle_single_case!(ir, idx, stmt, cases[1].item, todo, params) elseif length(cases) > 0 isa(atype, DataType) || return nothing - all(case::InliningCase->isa(case.sig, DataType), cases) || return nothing + for case in cases + isa(case.sig, DataType) || return nothing + end push!(todo, idx=>UnionSplit(fully_covered, atype, cases)) end return nothing @@ -1442,7 +1450,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) analyze_single_call!(ir, idx, stmt, infos, flag, sig, state, todo) end - todo + + return todo end function linear_inline_eligible(ir::IRCode) diff --git a/base/sort.jl b/base/sort.jl index d26e9a4b09332..981eea35d96ab 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -5,7 +5,7 @@ module Sort import ..@__MODULE__, ..parentmodule const Base = parentmodule(@__MODULE__) using .Base.Order -using .Base: copymutable, LinearIndices, length, (:), +using .Base: copymutable, LinearIndices, length, (:), iterate, eachindex, axes, first, last, similar, zip, OrdinalRange, AbstractVector, @inbounds, AbstractRange, @eval, @inline, Vector, @noinline, AbstractMatrix, AbstractUnitRange, isless, identity, eltype, >, <, <=, >=, |, +, -, *, !, diff --git a/test/compiler/inline.jl b/test/compiler/inline.jl index cc94ace0026df..f7f7b5e0e6c53 100644 --- a/test/compiler/inline.jl +++ b/test/compiler/inline.jl @@ -810,6 +810,103 @@ let @test invoke(Any[10]) === false end +# test union-split, non-dispatchtuple callsite inlining + +@constprop :none @noinline abstract_unionsplit(@nospecialize x::Any) = Base.inferencebarrier(:Any) +@constprop :none @noinline abstract_unionsplit(@nospecialize x::Number) = Base.inferencebarrier(:Number) +let src = code_typed1((Any,)) do x + abstract_unionsplit(x) + end + @test count(isinvoke(:abstract_unionsplit), src.code) == 2 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end +let src = code_typed1((Union{Type,Number},)) do x + abstract_unionsplit(x) + end + @test count(isinvoke(:abstract_unionsplit), src.code) == 2 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end + +@constprop :none @noinline abstract_unionsplit_fallback(@nospecialize x::Type) = Base.inferencebarrier(:Any) +@constprop :none @noinline abstract_unionsplit_fallback(@nospecialize x::Number) = Base.inferencebarrier(:Number) +let src = code_typed1((Any,)) do x + abstract_unionsplit_fallback(x) + end + @test count(isinvoke(:abstract_unionsplit_fallback), src.code) == 2 + @test count(iscall((src, abstract_unionsplit_fallback)), src.code) == 1 # fallback dispatch +end +let src = code_typed1((Union{Type,Number},)) do x + abstract_unionsplit_fallback(x) + end + @test count(isinvoke(:abstract_unionsplit_fallback), src.code) == 2 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end + +@constprop :aggressive @inline abstract_unionsplit(c, @nospecialize x::Any) = (c && println("erase me"); typeof(x)) +@constprop :aggressive @inline abstract_unionsplit(c, @nospecialize x::Number) = (c && println("erase me"); typeof(x)) +let src = code_typed1((Any,)) do x + abstract_unionsplit(false, x) + end + @test count(iscall((src, typeof)), src.code) == 2 + @test count(isinvoke(:println), src.code) == 0 + @test count(iscall((src, println)), src.code) == 0 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end +let src = code_typed1((Union{Type,Number},)) do x + abstract_unionsplit(false, x) + end + @test count(iscall((src, typeof)), src.code) == 2 + @test count(isinvoke(:println), src.code) == 0 + @test count(iscall((src, println)), src.code) == 0 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end + +@constprop :aggressive @inline abstract_unionsplit_fallback(c, @nospecialize x::Type) = (c && println("erase me"); typeof(x)) +@constprop :aggressive @inline abstract_unionsplit_fallback(c, @nospecialize x::Number) = (c && println("erase me"); typeof(x)) +let src = code_typed1((Any,)) do x + abstract_unionsplit_fallback(false, x) + end + @test count(iscall((src, typeof)), src.code) == 2 + @test count(isinvoke(:println), src.code) == 0 + @test count(iscall((src, println)), src.code) == 0 + @test count(iscall((src, abstract_unionsplit_fallback)), src.code) == 1 # fallback dispatch +end +let src = code_typed1((Union{Type,Number},)) do x + abstract_unionsplit_fallback(false, x) + end + @test count(iscall((src, typeof)), src.code) == 2 + @test count(isinvoke(:println), src.code) == 0 + @test count(iscall((src, println)), src.code) == 0 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end + +abstract_diagonal_dispatch(x::Int, y::Int) = 1 +abstract_diagonal_dispatch(x::Real, y::Int) = 2 +abstract_diagonal_dispatch(x::Int, y::Real) = 3 +function test_abstract_diagonal_dispatch(xs) + @test abstract_diagonal_dispatch(xs[1], xs[2]) == 1 + @test abstract_diagonal_dispatch(xs[3], xs[4]) == 3 + @test abstract_diagonal_dispatch(xs[5], xs[6]) == 2 + @test_throws MethodError abstract_diagonal_dispatch(xs[7], xs[8]) +end +test_abstract_diagonal_dispatch(Any[ + 1, 1, # => 1 + 1, 1.0, # => 3 + 1.0, 1, # => 2 + 1.0, 1.0 # => MethodError +]) + +constrained_dispatch(x::T, y::T) where T<:Real = 0 +let src = code_typed1((Real,Real,)) do x, y + constrained_dispatch(x, y) + end + @test any(iscall((src, constrained_dispatch)), src.code) # should account for MethodError +end +@test_throws MethodError let + x, y = 1.0, 1 + constrained_dispatch(x, y) +end + # issue 43104 @inline isGoodType(@nospecialize x::Type) = @@ -1119,11 +1216,11 @@ end global x44200::Int = 0 function f44200() - global x = 0 - while x < 10 - x += 1 + global x44200 = 0 + while x44200 < 10 + x44200 += 1 end - x + x44200 end let src = code_typed1(f44200) @test count(x -> isa(x, Core.PiNode), src.code) == 0