Skip to content

Commit

Permalink
optimizer: inline abstract union-split callsite
Browse files Browse the repository at this point in the history
Currently the optimizer handles abstract callsites only when there is a
single dispatch candidate (in most cases), and so inlining and static dispatch
are prohibited when the callsite is union-split (in other words, union-splitting
happens only when all the dispatch candidates are concrete).

However, there are certain patterns of code (most notably our Julia-level compiler code)
that inherently need to deal with abstract callsites.
The following example is taken from a `Core.Compiler` utility:
```julia
julia> @inline isType(@nospecialize t) = isa(t, DataType) && t.name === Type.body.name
isType (generic function with 1 method)

julia> code_typed((Any,)) do x # abstract, but no union-split, successful inlining
           isType(x)
       end |> only
CodeInfo(
1 ─ %1 = (x isa Main.DataType)::Bool
└──      goto #3 if not %1
2 ─ %3 = π (x, DataType)
│   %4 = Base.getfield(%3, :name)::Core.TypeName
│   %5 = Base.getfield(Type{T}, :name)::Core.TypeName
│   %6 = (%4 === %5)::Bool
└──      goto #4
3 ─      goto #4
4 ┄ %9 = φ (#2 => %6, #3 => false)::Bool
└──      return %9
) => Bool

julia> code_typed((Union{Type,Nothing},)) do x # abstract, union-split, unsuccessful inlining
           isType(x)
       end |> only
CodeInfo(
1 ─ %1 = (isa)(x, Nothing)::Bool
└──      goto #3 if not %1
2 ─      goto #4
3 ─ %4 = Main.isType(x)::Bool
└──      goto #4
4 ┄ %6 = φ (#2 => false, #3 => %4)::Bool
└──      return %6
) => Bool
```
(note that this is a limitation of the inlining algorithm, and so
user-provided hints such as a callsite inlining annotation don't help here)

This commit enables inlining and static dispatch for abstract union-split callsites.
The core idea here is that we can simulate our dispatch semantics by
generating `isa` checks in order of the specificity of the dispatch candidates:
```julia
julia> code_typed((Union{Type,Nothing},)) do x # abstract, union-split, successful inlining
                  isType(x)
              end |> only
CodeInfo(
1 ─ %1  = (isa)(x, Nothing)::Bool
└──       goto #3 if not %1
2 ─       goto #9
3 ─ %4  = (isa)(x, Type)::Bool
└──       goto #8 if not %4
4 ─ %6  = π (x, Type)
│   %7  = (%6 isa Main.DataType)::Bool
└──       goto #6 if not %7
5 ─ %9  = π (%6, DataType)
│   %10 = Base.getfield(%9, :name)::Core.TypeName
│   %11 = Base.getfield(Type{T}, :name)::Core.TypeName
│   %12 = (%10 === %11)::Bool
└──       goto #7
6 ─       goto #7
7 ┄ %15 = φ (#5 => %12, #6 => false)::Bool
└──       goto #9
8 ─       Core.throw(ErrorException("fatal error in type inference (type bound)"))::Union{}
└──       unreachable
9 ┄ %19 = φ (#2 => false, #7 => %15)::Bool
└──       return %19
) => Bool
```

Inlining/static dispatch of abstract union-split callsites will improve
the performance in such situations (and so this commit will improve the
latency of our JIT compilation). In particular, this commit helps us avoid
excessive specializations of `Core.Compiler` code by statically resolving
`@nospecialize`d callsites, and as a result, the number of precompiled
statements is reduced from `1956` ([`master`](dc45d77)) to `1901` (this commit).

As a side effect, the implementation of our inlining algorithm
becomes much simpler, since we no longer need the previous special
handling for abstract callsites.

One possible drawback is increased code size.
This change does increase the size of the sysimage,
but I think these numbers are in an acceptable range:
> [`master`](dc45d77)
```
❯ du -sh usr/lib/julia/*
 17M    usr/lib/julia/corecompiler.ji
188M    usr/lib/julia/sys-o.a
164M    usr/lib/julia/sys.dylib
 23M    usr/lib/julia/sys.dylib.dSYM
101M    usr/lib/julia/sys.ji
```

> this commit
```
❯ du -sh usr/lib/julia/*
 17M    usr/lib/julia/corecompiler.ji
190M    usr/lib/julia/sys-o.a
166M    usr/lib/julia/sys.dylib
 23M    usr/lib/julia/sys.dylib.dSYM
102M    usr/lib/julia/sys.ji
```
  • Loading branch information
aviatesk committed Mar 15, 2022
1 parent b2890d5 commit 5cf29ed
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 90 deletions.
134 changes: 49 additions & 85 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int,
push!(from_bbs, length(state.new_cfg_blocks))
# TODO: Right now we unconditionally generate a fallback block
# in case of subtyping errors - This is probably unnecessary.
if i != length(cases) || (!fully_covered || (!params.trust_inference && isdispatchtuple(cases[i].sig)))
if i != length(cases) || (!fully_covered || (!params.trust_inference))
# This block will have the next condition or the final else case
push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx)))
push!(state.new_cfg_blocks[cond_bb].succs, length(state.new_cfg_blocks))
Expand Down Expand Up @@ -313,7 +313,6 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
spec = item.spec::ResolvedInliningSpec
sparam_vals = item.mi.sparam_vals
def = item.mi.def::Method
inline_cfg = spec.ir.cfg
linetable_offset::Int32 = length(linetable)
# Append the linetable of the inlined function to our line table
inlined_at = Int(compact.result[idx][:line])
Expand Down Expand Up @@ -471,17 +470,17 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
join_bb = bbs[end]
pn = PhiNode()
local bb = compact.active_result_bb
@assert length(bbs) >= length(cases)
for i in 1:length(cases)
ncases = length(cases)
@assert length(bbs) >= ncases
for i = 1:ncases
ithcase = cases[i]
mtype = ithcase.sig::DataType # checked within `handle_cases!`
case = ithcase.item
next_cond_bb = bbs[i]
cond = true
nparams = fieldcount(atype)
@assert nparams == fieldcount(mtype)
if i != length(cases) || !fully_covered ||
(!params.trust_inference && isdispatchtuple(cases[i].sig))
if i != ncases || !fully_covered || !params.trust_inference
for i = 1:nparams
a, m = fieldtype(atype, i), fieldtype(mtype, i)
# If this is always true, we don't need to check for it
Expand Down Expand Up @@ -538,7 +537,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
bb += 1
# We're now in the fall through block, decide what to do
if fully_covered
if !params.trust_inference && isdispatchtuple(cases[end].sig)
if !params.trust_inference
e = Expr(:call, GlobalRef(Core, :throw), FATAL_TYPE_BOUND_ERROR)
insert_node_here!(compact, NewInstruction(e, Union{}, line))
insert_node_here!(compact, NewInstruction(ReturnNode(), Union{}, line))
Expand All @@ -561,7 +560,7 @@ function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, linetable::Vect
state = CFGInliningState(ir)
for (idx, item) in todo
if isa(item, UnionSplit)
cfg_inline_unionsplit!(ir, idx, item::UnionSplit, state, params)
cfg_inline_unionsplit!(ir, idx, item, state, params)
else
item = item::InliningTodo
spec = item.spec::ResolvedInliningSpec
Expand Down Expand Up @@ -1175,12 +1174,8 @@ function analyze_single_call!(
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
argtypes = sig.argtypes
cases = InliningCase[]
local only_method = nothing # keep track of whether there is one matching method
local meth::MethodLookupResult
local any_fully_covered = false
local handled_all_cases = true
local any_covers_full = false
local revisit_idx = nothing

for i in 1:length(infos)
meth = infos[i].results
if meth.ambig
Expand All @@ -1191,66 +1186,20 @@ function analyze_single_call!(
# No applicable methods; try next union split
handled_all_cases = false
continue
else
if length(meth) == 1 && only_method !== false
if only_method === nothing
only_method = meth[1].method
elseif only_method !== meth[1].method
only_method = false
end
else
only_method = false
end
end
for (j, match) in enumerate(meth)
any_covers_full |= match.fully_covers
if !isdispatchtuple(match.spec_types)
if !match.fully_covers
handled_all_cases = false
continue
end
if revisit_idx === nothing
revisit_idx = (i, j)
else
handled_all_cases = false
revisit_idx = nothing
end
else
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases)
end
for match in meth
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true)
any_fully_covered |= match.fully_covers
end
end

atype = argtypes_to_type(argtypes)
if handled_all_cases && revisit_idx !== nothing
# If there's only one case that's not a dispatchtuple, we can
# still unionsplit by visiting all the other cases first.
# This is useful for code like:
# foo(x::Int) = 1
# foo(@nospecialize(x::Any)) = 2
# where we where only a small number of specific dispatchable
# cases are split off from an ::Any typed fallback.
(i, j) = revisit_idx
match = infos[i].results[j]
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true)
elseif length(cases) == 0 && only_method isa Method
# if the signature is fully covered and there is only one applicable method,
# we can try to inline it even if the signature is not a dispatch tuple.
# -- But don't try it if we already tried to handle the match in the revisit_idx
# case, because that'll (necessarily) be the same method.
if length(infos) > 1
(metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any),
atype, only_method.sig)::SimpleVector
match = MethodMatch(metharg, methsp::SimpleVector, only_method, true)
else
@assert length(meth) == 1
match = meth[1]
end
handle_match!(match, argtypes, flag, state, cases, true) || return nothing
any_covers_full = handled_all_cases = match.fully_covers
if !handled_all_cases
# if we've not seen all candidates, union split is valid only for dispatch tuples
filter!(case::InliningCase->isdispatchtuple(case.sig), cases)
end

handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params)
handle_cases!(ir, idx, stmt, argtypes_to_type(argtypes), cases,
handled_all_cases & any_fully_covered, todo, state.params)
end

# similar to `analyze_single_call!`, but with constant results
Expand All @@ -1261,8 +1210,8 @@ function handle_const_call!(
(; call, results) = cinfo
infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches
cases = InliningCase[]
local any_fully_covered = false
local handled_all_cases = true
local any_covers_full = false
local j = 0
for i in 1:length(infos)
meth = infos[i].results
Expand All @@ -1278,42 +1227,39 @@ function handle_const_call!(
for match in meth
j += 1
result = results[j]
any_covers_full |= match.fully_covers
any_fully_covered |= match.fully_covers
if isa(result, ConstResult)
case = const_result_item(result, state)
push!(cases, InliningCase(result.mi.specTypes, case))
elseif isa(result, InferenceResult)
handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases)
handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases, true)
else
@assert result === nothing
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases)
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true)
end
end
end

# if the signature is fully covered and there is only one applicable method,
# we can try to inline it even if the signature is not a dispatch tuple
atype = argtypes_to_type(argtypes)
if length(cases) == 0
length(results) == 1 || return nothing
result = results[1]
isa(result, InferenceResult) || return nothing
handle_inf_result!(result, argtypes, flag, state, cases, true) || return nothing
spec_types = cases[1].sig
any_covers_full = handled_all_cases = atype <: spec_types
if !handled_all_cases
# if we've not seen all candidates, union split is valid only for dispatch tuples
filter!(case::InliningCase->isdispatchtuple(case.sig), cases)
end

handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params)
handle_cases!(ir, idx, stmt, argtypes_to_type(argtypes), cases,
handled_all_cases & any_fully_covered, todo, state.params)
end

function handle_match!(
match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
cases::Vector{InliningCase}, allow_abstract::Bool = false)
spec_types = match.spec_types
allow_abstract || isdispatchtuple(spec_types) || return false
# we may see duplicated dispatch signatures here when a signature gets widened
# during abstract interpretation: for the purpose of inlining, we can just skip
# processing this dispatch candidate
_any(case->case.sig === spec_types, cases) && return true
item = analyze_method!(match, argtypes, flag, state)
item === nothing && return false
_any(case->case.sig === spec_types, cases) && return true
push!(cases, InliningCase(spec_types, item))
return true
end
Expand Down Expand Up @@ -1349,7 +1295,24 @@ function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, @nospecialize(atype),
handle_single_case!(ir, idx, stmt, cases[1].item, todo, params)
elseif length(cases) > 0
isa(atype, DataType) || return nothing
all(case::InliningCase->isa(case.sig, DataType), cases) || return nothing
# `ir_inline_unionsplit!` is going to generate `isa` checks corresponding to the
# signatures of union-split dispatch candidates in order to simulate the dispatch
# semantics, and inline their bodies into each `isa`-conditional block -- and since
# we may deal with abstract union-split callsites here, these dispatch candidates
# need to be sorted in order of their signature specificity.
# Fortunately, ml_matches already sorted them in that way, so we can just process
# them in order, as far as we haven't changed their order somewhere up to this point.
ncases = length(cases)
for i = 1:ncases
sigᵢ = cases[i].sig
isa(sigᵢ, DataType) || return nothing
for j = i+1:ncases
sigⱼ = cases[j].sig
# since we already bail out from ambiguous case, we can use `morespecific` as
# a strict total order of specificity (in a case when they don't have a type intersection)
@assert !hasintersect(sigᵢ, sigⱼ) || morespecific(sigᵢ, sigⱼ) "invalid order of dispatch candidate"
end
end
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
end
return nothing
Expand Down Expand Up @@ -1445,7 +1408,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)

analyze_single_call!(ir, idx, stmt, infos, flag, sig, state, todo)
end
todo

return todo
end

function linear_inline_eligible(ir::IRCode)
Expand Down
2 changes: 1 addition & 1 deletion base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module Sort
import ..@__MODULE__, ..parentmodule
const Base = parentmodule(@__MODULE__)
using .Base.Order
using .Base: copymutable, LinearIndices, length, (:),
using .Base: copymutable, LinearIndices, length, (:), iterate,
eachindex, axes, first, last, similar, zip, OrdinalRange,
AbstractVector, @inbounds, AbstractRange, @eval, @inline, Vector, @noinline,
AbstractMatrix, AbstractUnitRange, isless, identity, eltype, >, <, <=, >=, |, +, -, *, !,
Expand Down
78 changes: 74 additions & 4 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,76 @@ let
@test invoke(Any[10]) === false
end

# test union-split, non-dispatchtuple callsite inlining

# Two non-inlinable methods of `abstract_unionsplit`; the `x::Any` method accepts every
# argument, so the callsites below always have a matching candidate.
@constprop :none @noinline abstract_unionsplit(@nospecialize x::Any) = Base.inferencebarrier(:Any)
@constprop :none @noinline abstract_unionsplit(@nospecialize x::Number) = Base.inferencebarrier(:Number)
# Argument typed as `Any`: expect two statements matching `isinvoke(:abstract_unionsplit)`
# (one per method) and no remaining generic call to `abstract_unionsplit`.
let src = code_typed1((Any,)) do x
abstract_unionsplit(x)
end
@test count(isinvoke(:abstract_unionsplit), src.code) == 2
@test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch
end
# Argument typed as `Union{Type,Number}`: same expectations as for `Any` above.
let src = code_typed1((Union{Type,Number},)) do x
abstract_unionsplit(x)
end
@test count(isinvoke(:abstract_unionsplit), src.code) == 2
@test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch
end

# Two non-inlinable methods of `abstract_unionsplit_fallback`, accepting only `Type` and
# `Number` arguments; there is no method accepting every argument.
@constprop :none @noinline abstract_unionsplit_fallback(@nospecialize x::Type) = Base.inferencebarrier(:Any)
@constprop :none @noinline abstract_unionsplit_fallback(@nospecialize x::Number) = Base.inferencebarrier(:Number)
# Argument typed as `Any` is not covered by the two methods: expect two statements matching
# `isinvoke(:abstract_unionsplit_fallback)` plus exactly one remaining generic call.
let src = code_typed1((Any,)) do x
abstract_unionsplit_fallback(x)
end
@test count(isinvoke(:abstract_unionsplit_fallback), src.code) == 2
@test count(iscall((src, abstract_unionsplit_fallback)), src.code) == 1 # fallback dispatch
end
# Argument typed as `Union{Type,Number}` is covered by the `Type` and `Number` methods of
# `abstract_unionsplit_fallback`: expect one static dispatch per method and no remaining
# generic call to `abstract_unionsplit_fallback`.
let src = code_typed1((Union{Type,Number},)) do x
        abstract_unionsplit_fallback(x)
    end
    @test count(isinvoke(:abstract_unionsplit_fallback), src.code) == 2
    # Count calls to the function under test; counting calls to the unrelated
    # `abstract_unionsplit` would pass regardless of what the optimizer does here.
    @test count(iscall((src, abstract_unionsplit_fallback)), src.code) == 0 # no fallback dispatch
end

# Two-argument, inlinable methods of `abstract_unionsplit`; the `println` call only runs
# when `c` is true, and the callsites below pass the constant `false`.
@constprop :aggressive @inline abstract_unionsplit(c, @nospecialize x::Any) = (c && println("erase me"); typeof(x))
@constprop :aggressive @inline abstract_unionsplit(c, @nospecialize x::Number) = (c && println("erase me"); typeof(x))
# Argument typed as `Any`: expect two `typeof` calls (one per method body), no `println`
# in either invoke or call form, and no remaining generic call to `abstract_unionsplit`.
let src = code_typed1((Any,)) do x
abstract_unionsplit(false, x)
end
@test count(iscall((src, typeof)), src.code) == 2
@test count(isinvoke(:println), src.code) == 0
@test count(iscall((src, println)), src.code) == 0
@test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch
end
# Argument typed as `Union{Type,Number}`: same expectations as for `Any` above.
let src = code_typed1((Union{Type,Number},)) do x
abstract_unionsplit(false, x)
end
@test count(iscall((src, typeof)), src.code) == 2
@test count(isinvoke(:println), src.code) == 0
@test count(iscall((src, println)), src.code) == 0
@test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch
end

# Two-argument, inlinable methods of `abstract_unionsplit_fallback`, accepting only `Type`
# and `Number` for `x`; the callsite below passes the constant `false` for `c`.
@constprop :aggressive @inline abstract_unionsplit_fallback(c, @nospecialize x::Type) = (c && println("erase me"); typeof(x))
@constprop :aggressive @inline abstract_unionsplit_fallback(c, @nospecialize x::Number) = (c && println("erase me"); typeof(x))
# Argument typed as `Any` is not covered by the two methods: expect two `typeof` calls,
# no `println`, and exactly one remaining generic call to `abstract_unionsplit_fallback`.
let src = code_typed1((Any,)) do x
abstract_unionsplit_fallback(false, x)
end
@test count(iscall((src, typeof)), src.code) == 2
@test count(isinvoke(:println), src.code) == 0
@test count(iscall((src, println)), src.code) == 0
@test count(iscall((src, abstract_unionsplit_fallback)), src.code) == 1 # fallback dispatch
end
# Argument typed as `Union{Type,Number}` is covered by the `Type` and `Number` methods of
# the two-argument `abstract_unionsplit_fallback`: expect two `typeof` calls (one per
# method body), no `println`, and no remaining generic call.
let src = code_typed1((Union{Type,Number},)) do x
        abstract_unionsplit_fallback(false, x)
    end
    @test count(iscall((src, typeof)), src.code) == 2
    @test count(isinvoke(:println), src.code) == 0
    @test count(iscall((src, println)), src.code) == 0
    # Count calls to the function under test; counting calls to the unrelated
    # `abstract_unionsplit` would pass regardless of what the optimizer does here.
    @test count(iscall((src, abstract_unionsplit_fallback)), src.code) == 0 # no fallback dispatch
end

# issue 43104

@inline isGoodType(@nospecialize x::Type) =
Expand Down Expand Up @@ -1097,11 +1167,11 @@ end

global x44200::Int = 0
function f44200()
global x = 0
while x < 10
x += 1
global x44200 = 0
while x44200 < 10
x44200 += 1
end
x
x44200
end
let src = code_typed1(f44200)
@test count(x -> isa(x, Core.PiNode), src.code) == 0
Expand Down

0 comments on commit 5cf29ed

Please sign in to comment.