Skip to content

Commit

Permalink
optimizer: fix JuliaLang#42754, inline union-split const-prop'ed sources
Browse files Browse the repository at this point in the history
This commit complements JuliaLang#39754 and JuliaLang#39305: implements a logic to use
constant-prop'ed results for inlining at union-split callsite.
Currently it works only for cases when constant-prop' succeeded for all
(union-split) signatures.

> example
```julia
julia> mutable struct X
           # NOTE in order to confuse `fieldtype_tfunc`, we need to have at least two fields with different types
           a::Union{Nothing, Int}
           b::Symbol
       end;

julia> code_typed((X, Union{Nothing,Int})) do x, a
           # this `setproperty` call would be union-split and constant-prop will happen for
           # each signature: inlining would fail if we don't use constant-prop'ed source
           # since the approximated inlining cost of `convert(fieldtype(X, sym), a)` would
           # end up very high if we don't propagate `sym::Const(:a)`
           x.a = a
           x
       end |> only |> first
```

> before this commit
```julia
CodeInfo(
1 ─ %1 = Base.setproperty!::typeof(setproperty!)
│   %2 = (isa)(a, Nothing)::Bool
└──      goto #3 if not %2
2 ─ %4 = π (a, Nothing)
│        invoke %1(_2::X, 🅰️:Symbol, %4::Nothing)::Any
└──      goto #6
3 ─ %7 = (isa)(a, Int64)::Bool
└──      goto #5 if not %7
4 ─ %9 = π (a, Int64)
│        invoke %1(_2::X, 🅰️:Symbol, %9::Int64)::Any
└──      goto #6
5 ─      Core.throw(ErrorException("fatal error in type inference (type bound)"))::Union{}
└──      unreachable
6 ┄      return x
)
```

> after this commit
```julia
CodeInfo(
1 ─ %1 = (isa)(a, Nothing)::Bool
└──      goto #3 if not %1
2 ─      Base.setfield!(x, :a, nothing)::Nothing
└──      goto #6
3 ─ %5 = (isa)(a, Int64)::Bool
└──      goto #5 if not %5
4 ─ %7 = π (a, Int64)
│        Base.setfield!(x, :a, %7)::Int64
└──      goto #6
5 ─      Core.throw(ErrorException("fatal error in type inference (type bound)"))::Union{}
└──      unreachable
6 ┄      return x
)
```
  • Loading branch information
aviatesk authored and LilithHafner committed Feb 22, 2022
1 parent 22feae9 commit ce21911
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 47 deletions.
147 changes: 100 additions & 47 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1147,9 +1147,10 @@ function process_simple!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, sta
return sig
end

# TODO inline non-`isdispatchtuple`, union-split callsites
function analyze_single_call!(
ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, @nospecialize(stmt),
sig::Signature, infos::Vector{MethodMatchInfo}, state::InliningState, flag::UInt8)
(; atypes, atype)::Signature, infos::Vector{MethodMatchInfo}, state::InliningState, flag::UInt8)
cases = InliningCase[]
local signature_union = Bottom
local only_method = nothing # keep track of whether there is one matching method
Expand Down Expand Up @@ -1181,7 +1182,7 @@ function analyze_single_call!(
fully_covered = false
continue
end
item = analyze_method!(match, sig.atypes, state, flag)
item = analyze_method!(match, atypes, state, flag)
if item === nothing
fully_covered = false
continue
Expand All @@ -1192,25 +1193,25 @@ function analyze_single_call!(
end
end

signature_fully_covered = sig.atype <: signature_union
# If we're fully covered and there's only one applicable method,
# we inline, even if the signature is not a dispatch tuple
if signature_fully_covered && length(cases) == 0 && only_method isa Method
if length(infos) > 1
(metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any),
sig.atype, only_method.sig)::SimpleVector
match = MethodMatch(metharg, methsp, only_method, true)
else
meth = meth::MethodLookupResult
@assert length(meth) == 1
match = meth[1]
# if the signature is fully covered and there is only one applicable method,
# we can try to inline it even if the signature is not a dispatch tuple
if atype <: signature_union
if length(cases) == 0 && only_method isa Method
if length(infos) > 1
(metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any),
atype, only_method.sig)::SimpleVector
match = MethodMatch(metharg, methsp, only_method, true)
else
meth = meth::MethodLookupResult
@assert length(meth) == 1
match = meth[1]
end
item = analyze_method!(match, atypes, state, flag)
item === nothing && return
push!(cases, InliningCase(match.spec_types, item))
fully_covered = true
end
fully_covered = true
item = analyze_method!(match, sig.atypes, state, flag)
item === nothing && return
push!(cases, InliningCase(match.spec_types, item))
end
if !signature_fully_covered
else
fully_covered = false
end

Expand All @@ -1219,36 +1220,81 @@ function analyze_single_call!(
# onto the todo list
if fully_covered && length(cases) == 1
handle_single_case!(ir, stmt, idx, cases[1].item, false, todo)
return
elseif length(cases) > 0
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
end
length(cases) == 0 && return
push!(todo, idx=>UnionSplit(fully_covered, sig.atype, cases))
return nothing
end

# try to create `InliningCase`s using constant-prop'ed results
# currently it works only when constant-prop' succeeded for all (union-split) signatures
# TODO use any of constant-prop'ed results, and leave the other unhandled cases to later
# TODO this function contains a lot of duplications with `analyze_single_call!`, factor them out
function maybe_handle_const_call!(
ir::IRCode, idx::Int, stmt::Expr, info::ConstCallInfo, sig::Signature,
ir::IRCode, idx::Int, stmt::Expr, (; results)::ConstCallInfo, (; atypes, atype)::Signature,
state::InliningState, flag::UInt8, isinvoke::Bool, todo::Vector{Pair{Int, Any}})
# when multiple matches are found, bail out and later inliner will union-split this signature
# TODO effectively use multiple constant analysis results here
length(info.results) == 1 || return false
result = info.results[1]
isa(result, InferenceResult) || return false

(; mi) = item = InliningTodo(result, sig.atypes)
validate_sparams(mi.sparam_vals) || return true
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
if sig.atype <: mi.def.sig
handle_single_case!(ir, stmt, idx, item, isinvoke, todo)
return true
cases = InliningCase[] # TODO avoid this allocation for single cases ?
local fully_covered = true
local signature_union = Bottom
for result in results
isa(result, InferenceResult) || return false
(; mi) = item = InliningTodo(result, atypes)
spec_types = mi.specTypes
signature_union = Union{signature_union, spec_types}
if !isdispatchtuple(spec_types)
fully_covered = false
continue
end
if !validate_sparams(mi.sparam_vals)
fully_covered = false
continue
end
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
if item === nothing
fully_covered = false
continue
end
push!(cases, InliningCase(spec_types, item))
end

# if the signature is fully covered and there is only one applicable method,
# we can try to inline it even if the signature is not a dispatch tuple
if atype <: signature_union
if length(cases) == 0 && length(results) == 1
(; mi) = item = InliningTodo(results[1]::InferenceResult, atypes)
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
validate_sparams(mi.sparam_vals) || return true
item === nothing && return true
push!(cases, InliningCase(mi.specTypes, item))
fully_covered = true
end
else
item === nothing && return true
# Union split out the error case
item = UnionSplit(false, sig.atype, InliningCase[InliningCase(mi.specTypes, item)])
fully_covered = false
end

# If we only have one case and that case is fully covered, we may either
# be able to do the inlining now (for constant cases), or push it directly
# onto the todo list
if fully_covered && length(cases) == 1
handle_single_case!(ir, stmt, idx, cases[1].item, isinvoke, todo)
elseif length(cases) > 0
isinvoke && rewrite_invoke_exprargs!(stmt)
push!(todo, idx=>item)
return true
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
end
return true
end

function handle_const_opaque_closure_call!(
ir::IRCode, idx::Int, stmt::Expr, (; results)::ConstCallInfo,
(; atypes)::Signature, state::InliningState, flag::UInt8, todo::Vector{Pair{Int, Any}})
@assert length(results) == 1
result = results[1]::InferenceResult
item = InliningTodo(result, atypes)
isdispatchtuple(item.mi.specTypes) || return
validate_sparams(item.mi.sparam_vals) || return
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
handle_single_case!(ir, stmt, idx, item, false, todo)
return nothing
end

function assemble_inline_todo!(ir::IRCode, state::InliningState)
Expand Down Expand Up @@ -1283,18 +1329,25 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
# if inference arrived here with constant-prop'ed result(s),
# we can perform a specialized analysis for just this case
if isa(info, ConstCallInfo)
if !is_stmt_noinline(flag) && maybe_handle_const_call!(
ir, idx, stmt, info, sig,
state, flag, sig.f === Core.invoke, todo)
continue
if !is_stmt_noinline(flag)
if isa(info.call, OpaqueClosureCallInfo)
handle_const_opaque_closure_call!(
ir, idx, stmt, info,
sig, state, flag, todo)
continue
else
maybe_handle_const_call!(
ir, idx, stmt, info, sig,
state, flag, sig.f === Core.invoke, todo) && continue
end
else
info = info.call
end
end

if isa(info, OpaqueClosureCallInfo)
result = analyze_method!(info.match, sig.atypes, state, flag)
handle_single_case!(ir, stmt, idx, result, false, todo)
item = analyze_method!(info.match, sig.atypes, state, flag)
handle_single_case!(ir, stmt, idx, item, false, todo)
continue
end

Expand Down
77 changes: 77 additions & 0 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -680,3 +680,80 @@ let f(x) = (x...,)
# the the original apply call is not union-split, but the inserted `iterate` call is.
@test code_typed(f, Tuple{Union{Int64, CartesianIndex{1}, CartesianIndex{3}}})[1][2] == Tuple{Int64}
end

# https://github.com/JuliaLang/julia/issues/42754
# inline union-split constant-prop'ed sources
mutable struct X42754
# NOTE in order to confuse `fieldtype_tfunc`, we need to have at least two fields with different types
a::Union{Nothing, Int}
b::Symbol
end
let code = code_typed1((X42754, Union{Nothing,Int})) do x, a
# this `setproperty` call would be union-split and constant-prop will happen for
# each signature: inlining would fail if we don't use constant-prop'ed source
# since the approximate inlining cost of `convert(fieldtype(X, sym), a)` would
# end up very high if we don't propagate `sym::Const(:a)`
x.a = a
x
end
@test all(code) do @nospecialize(x)
isinvoke(x, :setproperty!) && return false
if Meta.isexpr(x, :call)
f = x.args[1]
isa(f, GlobalRef) && f.name === :setproperty! && return false
end
return true
end
end

import Base: @constprop

# test single, non-dispatchtuple callsite inlining

@constprop :none @inline test_single_nondispatchtuple(@nospecialize(t)) =
isa(t, DataType) && t.name === Type.body.name
let
code = code_typed1((Any,)) do x
test_single_nondispatchtuple(x)
end
@test all(code) do @nospecialize(x)
isinvoke(x, :test_single_nondispatchtuple) && return false
if Meta.isexpr(x, :call)
f = x.args[1]
isa(f, GlobalRef) && f.name === :test_single_nondispatchtuple && return false
end
return true
end
end

@constprop :aggressive @inline test_single_nondispatchtuple(c, @nospecialize(t)) =
c && isa(t, DataType) && t.name === Type.body.name
let
code = code_typed1((Any,)) do x
test_single_nondispatchtuple(true, x)
end
@test all(code) do @nospecialize(x)
isinvoke(x, :test_single_nondispatchtuple) && return false
if Meta.isexpr(x, :call)
f = x.args[1]
isa(f, GlobalRef) && f.name === :test_single_nondispatchtuple && return false
end
return true
end
end

# validate inlining processing

@constprop :none @inline validate_unionsplit_inlining(@nospecialize(t)) = throw("invalid inlining processing detected")
@constprop :none @noinline validate_unionsplit_inlining(i::Integer) = (println(IOBuffer(), "prevent inlining"); false)
let
invoke(xs) = validate_unionsplit_inlining(xs[1])
@test invoke(Any[10]) === false
end

@constprop :aggressive @inline validate_unionsplit_inlining(c, @nospecialize(t)) = c && throw("invalid inlining processing detected")
@constprop :aggressive @noinline validate_unionsplit_inlining(c, i::Integer) = c && (println(IOBuffer(), "prevent inlining"); false)
let
invoke(xs) = validate_unionsplit_inlining(true, xs[1])
@test invoke(Any[10]) === false
end

0 comments on commit ce21911

Please sign in to comment.