Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inlining: add missing late special handling for UnionAll method call #43479

Merged
merged 2 commits into from
Dec 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 45 additions & 29 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,29 +332,27 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
if coverage && spec.ir.stmts[1][:line] + linetable_offset != topline
insert_node_here!(compact, NewInstruction(Expr(:code_coverage_effect), Nothing, topline))
end
nargs_def = def.nargs::Int32
isva = nargs_def > 0 && def.isva
sig = def.sig
if isva
vararg = mk_tuplecall!(compact, argexprs[nargs_def:end], topline)
argexprs = Any[argexprs[1:(nargs_def - 1)]..., vararg]
if def.isva
nargs_def = Int(def.nargs::Int32)
if nargs_def > 0
argexprs = fix_va_argexprs!(compact, argexprs, nargs_def, topline)
end
end
if def.is_for_opaque_closure
# Replace the first argument by a load of the capture environment
argexprs[1] = insert_node_here!(compact,
NewInstruction(Expr(:call, GlobalRef(Core, :getfield), argexprs[1], QuoteNode(:captures)),
spec.ir.argtypes[1], topline))
end
flag = compact.result[idx][:flag]
boundscheck_idx = boundscheck
if boundscheck_idx === :default || boundscheck_idx === :propagate
if (flag & IR_FLAG_INBOUNDS) != 0
boundscheck_idx = :off
if boundscheck === :default || boundscheck === :propagate
if (compact.result[idx][:flag] & IR_FLAG_INBOUNDS) != 0
boundscheck = :off
end
end
# If the iterator already moved on to the next basic block,
# temporarily re-open in again.
local return_value
sig = def.sig
# Special case inlining that maintains the current basic block if there's only one BB in the target
if spec.linear_inline_eligible
#compact[idx] = nothing
Expand All @@ -364,7 +362,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
# face of rename_arguments! mutating in place - should figure out
# something better eventually.
inline_compact[idx′] = nothing
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck_idx, compact)
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact)
if isa(stmt′, ReturnNode)
val = stmt′.val
isa(val, SSAValue) && (compact.used_ssas[val.id] += 1)
Expand All @@ -390,7 +388,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx)
for ((_, idx′), stmt′) in inline_compact
inline_compact[idx′] = nothing
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck_idx, compact)
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact)
if isa(stmt′, ReturnNode)
if isdefined(stmt′, :val)
val = stmt′.val
Expand Down Expand Up @@ -441,6 +439,24 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
return_value
end

function fix_va_argexprs!(compact::IncrementalCompact,
argexprs::Vector{Any}, nargs_def::Int, line_idx::Int32)
newargexprs = Any[]
for i in 1:(nargs_def-1)
push!(newargexprs, argexprs[i])
end
Comment on lines +444 to +447
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why use such an awkward way to slice this array?

Suggested change
newargexprs = Any[]
for i in 1:(nargs_def-1)
push!(newargexprs, argexprs[i])
end
newargexprs = argexprs[1:(nargs_def-1)]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tuple_call = Expr(:call, TOP_TUPLE)
tuple_typs = Any[]
for i in nargs_def:length(argexprs)
arg = argexprs[i]
push!(tuple_call.args, arg)
push!(tuple_typs, argextype(arg, compact))
end
tuple_typ = tuple_tfunc(tuple_typs)
push!(newargexprs, insert_node_here!(compact, NewInstruction(tuple_call, tuple_typ, line_idx)))
return newargexprs
end

const FATAL_TYPE_BOUND_ERROR = ErrorException("fatal error in type inference (type bound)")

function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
Expand Down Expand Up @@ -482,11 +498,12 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
if !isa(case, ConstantCase)
argexprs′ = copy(argexprs)
for i = 1:length(mparams)
argex = argexprs[i]
(isa(argex, SSAValue) || isa(argex, Argument)) || continue
a, m = aparams[i], mparams[i]
(isa(argexprs[i], SSAValue) || isa(argexprs[i], Argument)) || continue
if !(a <: m)
argexprs′[i] = insert_node_here!(compact,
NewInstruction(PiNode(argexprs′[i], m), m, line))
NewInstruction(PiNode(argex, m), m, line))
end
end
end
Expand Down Expand Up @@ -1055,18 +1072,17 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto
check_effect_free!(ir, idx, stmt, rt)
return nothing
end
if stmt.head !== :call
if stmt.head === :splatnew
head = stmt.head
if head !== :call
if head === :splatnew
inline_splatnew!(ir, idx, stmt, rt)
elseif stmt.head === :new_opaque_closure
elseif head === :new_opaque_closure
narrow_opaque_closure!(ir, stmt, ir.stmts[idx][:info], state)
end
check_effect_free!(ir, idx, stmt, rt)
return nothing
end

stmt.head === :call || return nothing

sig = call_sig(ir, stmt)
sig === nothing && return nothing

Expand Down Expand Up @@ -1286,8 +1302,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
end
ir.stmts[idx][:flag] |= IR_FLAG_EFFECT_FREE
info = info.info
end
if info === false
elseif info === false
# Inference determined this couldn't be analyzed. Don't question it.
continue
end
Expand Down Expand Up @@ -1330,20 +1345,14 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
elseif isa(info, UnionSplitInfo)
infos = info.matches
else
continue
continue # isa(info, ReturnTypeCallInfo), etc.
end

analyze_single_call!(ir, idx, stmt, infos, flag, sig, state, todo)
end
todo
end

function mk_tuplecall!(compact::IncrementalCompact, args::Vector{Any}, line_idx::Int32)
e = Expr(:call, TOP_TUPLE, args...)
etyp = tuple_tfunc(Any[argextype(args[i], compact) for i in 1:length(args)])
return insert_node_here!(compact, NewInstruction(e, etyp, line_idx))
end

function linear_inline_eligible(ir::IRCode)
length(ir.cfg.blocks) == 1 || return false
terminator = ir[SSAValue(last(ir.cfg.blocks[1].stmts))]
Expand Down Expand Up @@ -1395,6 +1404,9 @@ function early_inline_special_case(
return nothing
end

# special-case some regular method calls whose results are not folded within `abstract_call_known`
# (and thus `early_inline_special_case` doesn't handle them yet)
# NOTE we manually inline the method bodies, and so the logic here needs to precisely sync with their definitions
function late_inline_special_case!(
ir::IRCode, idx::Int, stmt::Expr, @nospecialize(type), sig::Signature,
params::OptimizationParams)
Expand Down Expand Up @@ -1423,6 +1435,10 @@ function late_inline_special_case!(
length(stmt.args) < 4 ? Bottom : stmt.args[3],
length(stmt.args) == 2 ? Any : stmt.args[end])
return SomeCase(typevar_call)
elseif isinlining && f === UnionAll && length(argtypes) == 3 && (argtypes[2] ⊑ TypeVar)
unionall_call = Expr(:foreigncall, QuoteNode(:jl_type_unionall), Any, svec(Any, Any),
0, QuoteNode(:ccall), stmt.args[2], stmt.args[3])
return SomeCase(unionall_call)
elseif is_return_type(f)
if isconstType(type)
return SomeCase(quoted(type.parameters[1]))
Expand Down
18 changes: 16 additions & 2 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -388,12 +388,12 @@ get_code(args...; kwargs...) = code_typed1(args...; kwargs...).code

# check if `x` is a dynamic call of a given function
iscall(y) = @nospecialize(x) -> iscall(y, x)
function iscall((src, f)::Tuple{Core.CodeInfo,Function}, @nospecialize(x))
function iscall((src, f)::Tuple{Core.CodeInfo,Base.Callable}, @nospecialize(x))
return iscall(x) do @nospecialize x
singleton_type(argextype(x, src, EMPTY_SPTYPES)) === f
end
end
iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1])
iscall(pred::Base.Callable, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1])

# check if `x` is a statically-resolved call of a function whose name is `sym`
isinvoke(y) = @nospecialize(x) -> isinvoke(y, x)
Expand Down Expand Up @@ -860,3 +860,17 @@ let # aggressive static dispatch of single, abstract method match
# `checkBadType!(y::Any)` isn't fully covered, thus a runtime type check and fallback dynamic dispatch should be inserted
@test count(iscall((src,checkBadType!)), src.code) == 1
end

@testset "late_inline_special_case!" begin
let src = code_typed((Symbol,Any,Any)) do a, b, c
TypeVar(a, b, c)
end |> only |> first
@test count(iscall((src,TypeVar)), src.code) == 0
@test count(iscall((src,Core._typevar)), src.code) == 1
end
let src = code_typed((TypeVar,Any)) do a, b
UnionAll(a, b)
end |> only |> first
@test count(iscall((src,UnionAll)), src.code) == 0
end
end