Skip to content

Commit

Permalink
inlining: add missing late special handling for UnionAll method call (
Browse files Browse the repository at this point in the history
JuliaLang#43479)

Looking at the result of <JuliaLang#43452 (comment)>,
I found that currently the inlinear sometimes fails to handle `UnionAll`
call (e.g. runtime dispatch detected: Core.UnionAll(%28::TypeVar, %29::Any)).

This commit adds a missing late special handling for `UnionAll` calls:

> before
```julia
julia> code_typed((TypeVar,)) do tv
           UnionAll(tv, Type{tv})
       end
1-element Vector{Any}:
 CodeInfo(
1 ─ %1 = Core.apply_type(Main.Type, tv)::Type{Type{_A}} where _A
│   %2 = Main.UnionAll(tv, %1)::Any
└──      return %2
) => Any
```

> after
```julia
julia> code_typed((TypeVar,)) do tv
           UnionAll(tv, Type{tv})
       end
1-element Vector{Any}:
 CodeInfo(
1 ─ %1 = Core.apply_type(Main.Type, tv)::Type{Type{_A}} where _A
│   %2 = $(Expr(:foreigncall, :(:jl_type_unionall), Any, svec(Any, Any), 0, :(:ccall), Core.Argument(2), :(%1)))::Any
└──      return %2
) => Any
```
  • Loading branch information
aviatesk authored and LilithHafner committed Feb 22, 2022
1 parent 3fb337c commit 9eb8c16
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 31 deletions.
74 changes: 45 additions & 29 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,29 +332,27 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
if coverage && spec.ir.stmts[1][:line] + linetable_offset != topline
insert_node_here!(compact, NewInstruction(Expr(:code_coverage_effect), Nothing, topline))
end
nargs_def = def.nargs::Int32
isva = nargs_def > 0 && def.isva
sig = def.sig
if isva
vararg = mk_tuplecall!(compact, argexprs[nargs_def:end], topline)
argexprs = Any[argexprs[1:(nargs_def - 1)]..., vararg]
if def.isva
nargs_def = Int(def.nargs::Int32)
if nargs_def > 0
argexprs = fix_va_argexprs!(compact, argexprs, nargs_def, topline)
end
end
if def.is_for_opaque_closure
# Replace the first argument by a load of the capture environment
argexprs[1] = insert_node_here!(compact,
NewInstruction(Expr(:call, GlobalRef(Core, :getfield), argexprs[1], QuoteNode(:captures)),
spec.ir.argtypes[1], topline))
end
flag = compact.result[idx][:flag]
boundscheck_idx = boundscheck
if boundscheck_idx === :default || boundscheck_idx === :propagate
if (flag & IR_FLAG_INBOUNDS) != 0
boundscheck_idx = :off
if boundscheck === :default || boundscheck === :propagate
if (compact.result[idx][:flag] & IR_FLAG_INBOUNDS) != 0
boundscheck = :off
end
end
# If the iterator already moved on to the next basic block,
# temporarily re-open in again.
local return_value
sig = def.sig
# Special case inlining that maintains the current basic block if there's only one BB in the target
if spec.linear_inline_eligible
#compact[idx] = nothing
Expand All @@ -364,7 +362,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
# face of rename_arguments! mutating in place - should figure out
# something better eventually.
inline_compact[idx′] = nothing
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck_idx, compact)
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact)
if isa(stmt′, ReturnNode)
val = stmt′.val
isa(val, SSAValue) && (compact.used_ssas[val.id] += 1)
Expand All @@ -390,7 +388,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx)
for ((_, idx′), stmt′) in inline_compact
inline_compact[idx′] = nothing
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck_idx, compact)
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact)
if isa(stmt′, ReturnNode)
if isdefined(stmt′, :val)
val = stmt′.val
Expand Down Expand Up @@ -441,6 +439,24 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
return_value
end

function fix_va_argexprs!(compact::IncrementalCompact,
argexprs::Vector{Any}, nargs_def::Int, line_idx::Int32)
newargexprs = Any[]
for i in 1:(nargs_def-1)
push!(newargexprs, argexprs[i])
end
tuple_call = Expr(:call, TOP_TUPLE)
tuple_typs = Any[]
for i in nargs_def:length(argexprs)
arg = argexprs[i]
push!(tuple_call.args, arg)
push!(tuple_typs, argextype(arg, compact))
end
tuple_typ = tuple_tfunc(tuple_typs)
push!(newargexprs, insert_node_here!(compact, NewInstruction(tuple_call, tuple_typ, line_idx)))
return newargexprs
end

const FATAL_TYPE_BOUND_ERROR = ErrorException("fatal error in type inference (type bound)")

function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
Expand Down Expand Up @@ -482,11 +498,12 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
if !isa(case, ConstantCase)
argexprs′ = copy(argexprs)
for i = 1:length(mparams)
argex = argexprs[i]
(isa(argex, SSAValue) || isa(argex, Argument)) || continue
a, m = aparams[i], mparams[i]
(isa(argexprs[i], SSAValue) || isa(argexprs[i], Argument)) || continue
if !(a <: m)
argexprs′[i] = insert_node_here!(compact,
NewInstruction(PiNode(argexprs′[i], m), m, line))
NewInstruction(PiNode(argex, m), m, line))
end
end
end
Expand Down Expand Up @@ -1055,18 +1072,17 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto
check_effect_free!(ir, idx, stmt, rt)
return nothing
end
if stmt.head !== :call
if stmt.head === :splatnew
head = stmt.head
if head !== :call
if head === :splatnew
inline_splatnew!(ir, idx, stmt, rt)
elseif stmt.head === :new_opaque_closure
elseif head === :new_opaque_closure
narrow_opaque_closure!(ir, stmt, ir.stmts[idx][:info], state)
end
check_effect_free!(ir, idx, stmt, rt)
return nothing
end

stmt.head === :call || return nothing

sig = call_sig(ir, stmt)
sig === nothing && return nothing

Expand Down Expand Up @@ -1286,8 +1302,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
end
ir.stmts[idx][:flag] |= IR_FLAG_EFFECT_FREE
info = info.info
end
if info === false
elseif info === false
# Inference determined this couldn't be analyzed. Don't question it.
continue
end
Expand Down Expand Up @@ -1330,20 +1345,14 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
elseif isa(info, UnionSplitInfo)
infos = info.matches
else
continue
continue # isa(info, ReturnTypeCallInfo), etc.
end

analyze_single_call!(ir, idx, stmt, infos, flag, sig, state, todo)
end
todo
end

function mk_tuplecall!(compact::IncrementalCompact, args::Vector{Any}, line_idx::Int32)
e = Expr(:call, TOP_TUPLE, args...)
etyp = tuple_tfunc(Any[argextype(args[i], compact) for i in 1:length(args)])
return insert_node_here!(compact, NewInstruction(e, etyp, line_idx))
end

function linear_inline_eligible(ir::IRCode)
length(ir.cfg.blocks) == 1 || return false
terminator = ir[SSAValue(last(ir.cfg.blocks[1].stmts))]
Expand Down Expand Up @@ -1395,6 +1404,9 @@ function early_inline_special_case(
return nothing
end

# special-case some regular method calls whose results are not folded within `abstract_call_known`
# (and thus `early_inline_special_case` doesn't handle them yet)
# NOTE we manually inline the method bodies, and so the logic here needs to precisely sync with their definitions
function late_inline_special_case!(
ir::IRCode, idx::Int, stmt::Expr, @nospecialize(type), sig::Signature,
params::OptimizationParams)
Expand Down Expand Up @@ -1423,6 +1435,10 @@ function late_inline_special_case!(
length(stmt.args) < 4 ? Bottom : stmt.args[3],
length(stmt.args) == 2 ? Any : stmt.args[end])
return SomeCase(typevar_call)
elseif isinlining && f === UnionAll && length(argtypes) == 3 && (argtypes[2] TypeVar)
unionall_call = Expr(:foreigncall, QuoteNode(:jl_type_unionall), Any, svec(Any, Any),
0, QuoteNode(:ccall), stmt.args[2], stmt.args[3])
return SomeCase(unionall_call)
elseif is_return_type(f)
if isconstType(type)
return SomeCase(quoted(type.parameters[1]))
Expand Down
18 changes: 16 additions & 2 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -388,12 +388,12 @@ get_code(args...; kwargs...) = code_typed1(args...; kwargs...).code

# check if `x` is a dynamic call of a given function
iscall(y) = @nospecialize(x) -> iscall(y, x)
function iscall((src, f)::Tuple{Core.CodeInfo,Function}, @nospecialize(x))
function iscall((src, f)::Tuple{Core.CodeInfo,Base.Callable}, @nospecialize(x))
return iscall(x) do @nospecialize x
singleton_type(argextype(x, src, EMPTY_SPTYPES)) === f
end
end
iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1])
iscall(pred::Base.Callable, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1])

# check if `x` is a statically-resolved call of a function whose name is `sym`
isinvoke(y) = @nospecialize(x) -> isinvoke(y, x)
Expand Down Expand Up @@ -860,3 +860,17 @@ let # aggressive static dispatch of single, abstract method match
# `checkBadType!(y::Any)` isn't fully covered, thus a runtime type check and fallback dynamic dispatch should be inserted
@test count(iscall((src,checkBadType!)), src.code) == 1
end

@testset "late_inline_special_case!" begin
let src = code_typed((Symbol,Any,Any)) do a, b, c
TypeVar(a, b, c)
end |> only |> first
@test count(iscall((src,TypeVar)), src.code) == 0
@test count(iscall((src,Core._typevar)), src.code) == 1
end
let src = code_typed((TypeVar,Any)) do a, b
UnionAll(a, b)
end |> only |> first
@test count(iscall((src,UnionAll)), src.code) == 0
end
end

0 comments on commit 9eb8c16

Please sign in to comment.