From d959b4373adf3843483b89808475b3b14ec4a708 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Sun, 19 Dec 2021 17:26:17 +0900 Subject: [PATCH 1/2] inlining: add missing late special handling for `UnionAll` method call MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Looking at the result of , I found that currently the inliner sometimes fails to handle `UnionAll` calls (e.g. runtime dispatch detected: Core.UnionAll(%28::TypeVar, %29::Any)). This commit adds a missing late special handling for `UnionAll` calls: > before ```julia julia> code_typed((TypeVar,)) do tv UnionAll(tv, Type{tv}) end 1-element Vector{Any}: CodeInfo( 1 ─ %1 = Core.apply_type(Main.Type, tv)::Type{Type{_A}} where _A │ %2 = Main.UnionAll(tv, %1)::Any └── return %2 ) => Any ``` > after ```julia julia> code_typed((TypeVar,)) do tv UnionAll(tv, Type{tv}) end 1-element Vector{Any}: CodeInfo( 1 ─ %1 = Core.apply_type(Main.Type, tv)::Type{Type{_A}} where _A │ %2 = $(Expr(:foreigncall, :(:jl_type_unionall), Any, svec(Any, Any), 0, :(:ccall), Core.Argument(2), :(%1)))::Any └── return %2 ) => Any ``` --- base/compiler/ssair/inlining.jl | 21 +++++++++++++-------- test/compiler/inline.jl | 18 ++++++++++++++++-- 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 26dd2aa958ca6..3971d9d01bbfb 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -1055,18 +1055,17 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto check_effect_free!(ir, idx, stmt, rt) return nothing end - if stmt.head !== :call - if stmt.head === :splatnew + head = stmt.head + if head !== :call + if head === :splatnew inline_splatnew!(ir, idx, stmt, rt) - elseif stmt.head === :new_opaque_closure + elseif head === :new_opaque_closure narrow_opaque_closure!(ir, stmt, ir.stmts[idx][:info], state) end check_effect_free!(ir, idx, stmt, rt) return nothing 
end - stmt.head === :call || return nothing - sig = call_sig(ir, stmt) sig === nothing && return nothing @@ -1286,8 +1285,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) end ir.stmts[idx][:flag] |= IR_FLAG_EFFECT_FREE info = info.info - end - if info === false + elseif info === false # Inference determined this couldn't be analyzed. Don't question it. continue end @@ -1330,7 +1328,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) elseif isa(info, UnionSplitInfo) infos = info.matches else - continue + continue # isa(info, ReturnTypeCallInfo), etc. end analyze_single_call!(ir, idx, stmt, infos, flag, sig, state, todo) @@ -1395,6 +1393,9 @@ function early_inline_special_case( return nothing end +# special-case some regular method calls whose results are not folded within `abstract_call_known` +# (and thus `early_inline_special_case` doesn't handle them yet) +# NOTE we manually inline the method bodies, and so the logic here needs to precisely sync with their definitions function late_inline_special_case!( ir::IRCode, idx::Int, stmt::Expr, @nospecialize(type), sig::Signature, params::OptimizationParams) @@ -1423,6 +1424,10 @@ function late_inline_special_case!( length(stmt.args) < 4 ? Bottom : stmt.args[3], length(stmt.args) == 2 ? Any : stmt.args[end]) return SomeCase(typevar_call) + elseif isinlining && f === UnionAll && length(argtypes) == 3 && (argtypes[2] ⊑ TypeVar) + unionall_call = Expr(:foreigncall, QuoteNode(:jl_type_unionall), Any, svec(Any, Any), + 0, QuoteNode(:ccall), stmt.args[2], stmt.args[3]) + return SomeCase(unionall_call) elseif is_return_type(f) if isconstType(type) return SomeCase(quoted(type.parameters[1])) diff --git a/test/compiler/inline.jl b/test/compiler/inline.jl index 83780ca8b1ac5..111ced17fd54a 100644 --- a/test/compiler/inline.jl +++ b/test/compiler/inline.jl @@ -388,12 +388,12 @@ get_code(args...; kwargs...) 
= code_typed1(args...; kwargs...).code # check if `x` is a dynamic call of a given function iscall(y) = @nospecialize(x) -> iscall(y, x) -function iscall((src, f)::Tuple{Core.CodeInfo,Function}, @nospecialize(x)) +function iscall((src, f)::Tuple{Core.CodeInfo,Base.Callable}, @nospecialize(x)) return iscall(x) do @nospecialize x singleton_type(argextype(x, src, EMPTY_SPTYPES)) === f end end -iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1]) +iscall(pred::Base.Callable, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1]) # check if `x` is a statically-resolved call of a function whose name is `sym` isinvoke(y) = @nospecialize(x) -> isinvoke(y, x) @@ -860,3 +860,17 @@ let # aggressive static dispatch of single, abstract method match # `checkBadType!(y::Any)` isn't fully covered, thus a runtime type check and fallback dynamic dispatch should be inserted @test count(iscall((src,checkBadType!)), src.code) == 1 end + +@testset "late_inline_special_case!" 
begin + let src = code_typed((Symbol,Any,Any)) do a, b, c + TypeVar(a, b, c) + end |> only |> first + @test count(iscall((src,TypeVar)), src.code) == 0 + @test count(iscall((src,Core._typevar)), src.code) == 1 + end + let src = code_typed((TypeVar,Any)) do a, b + UnionAll(a, b) + end |> only |> first + @test count(iscall((src,UnionAll)), src.code) == 0 + end +end From e2810a0fcd8be066c8102937656825557767b8b1 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Sun, 19 Dec 2021 23:38:58 +0900 Subject: [PATCH 2/2] inlining: minor optimizations --- base/compiler/ssair/inlining.jl | 53 ++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 3971d9d01bbfb..5ad897410d822 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -332,12 +332,11 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector if coverage && spec.ir.stmts[1][:line] + linetable_offset != topline insert_node_here!(compact, NewInstruction(Expr(:code_coverage_effect), Nothing, topline)) end - nargs_def = def.nargs::Int32 - isva = nargs_def > 0 && def.isva - sig = def.sig - if isva - vararg = mk_tuplecall!(compact, argexprs[nargs_def:end], topline) - argexprs = Any[argexprs[1:(nargs_def - 1)]..., vararg] + if def.isva + nargs_def = Int(def.nargs::Int32) + if nargs_def > 0 + argexprs = fix_va_argexprs!(compact, argexprs, nargs_def, topline) + end end if def.is_for_opaque_closure # Replace the first argument by a load of the capture environment @@ -345,16 +344,15 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector NewInstruction(Expr(:call, GlobalRef(Core, :getfield), argexprs[1], QuoteNode(:captures)), spec.ir.argtypes[1], topline)) end - flag = compact.result[idx][:flag] - boundscheck_idx = boundscheck - if boundscheck_idx === :default || boundscheck_idx === :propagate - if (flag & IR_FLAG_INBOUNDS) != 0 - 
boundscheck_idx = :off + if boundscheck === :default || boundscheck === :propagate + if (compact.result[idx][:flag] & IR_FLAG_INBOUNDS) != 0 + boundscheck = :off end end # If the iterator already moved on to the next basic block, # temporarily re-open in again. local return_value + sig = def.sig # Special case inlining that maintains the current basic block if there's only one BB in the target if spec.linear_inline_eligible #compact[idx] = nothing @@ -364,7 +362,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector # face of rename_arguments! mutating in place - should figure out # something better eventually. inline_compact[idx′] = nothing - stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck_idx, compact) + stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact) if isa(stmt′, ReturnNode) val = stmt′.val isa(val, SSAValue) && (compact.used_ssas[val.id] += 1) @@ -390,7 +388,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx) for ((_, idx′), stmt′) in inline_compact inline_compact[idx′] = nothing - stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck_idx, compact) + stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck, compact) if isa(stmt′, ReturnNode) if isdefined(stmt′, :val) val = stmt′.val @@ -441,6 +439,24 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector return_value end +function fix_va_argexprs!(compact::IncrementalCompact, + argexprs::Vector{Any}, nargs_def::Int, line_idx::Int32) + newargexprs = Any[] + for i in 1:(nargs_def-1) + push!(newargexprs, argexprs[i]) + end + tuple_call = Expr(:call, TOP_TUPLE) + tuple_typs = Any[] + for i in nargs_def:length(argexprs) + arg = argexprs[i] + push!(tuple_call.args, arg) + 
push!(tuple_typs, argextype(arg, compact)) + end + tuple_typ = tuple_tfunc(tuple_typs) + push!(newargexprs, insert_node_here!(compact, NewInstruction(tuple_call, tuple_typ, line_idx))) + return newargexprs +end + const FATAL_TYPE_BOUND_ERROR = ErrorException("fatal error in type inference (type bound)") function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, @@ -482,11 +498,12 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, if !isa(case, ConstantCase) argexprs′ = copy(argexprs) for i = 1:length(mparams) + argex = argexprs[i] + (isa(argex, SSAValue) || isa(argex, Argument)) || continue a, m = aparams[i], mparams[i] - (isa(argexprs[i], SSAValue) || isa(argexprs[i], Argument)) || continue if !(a <: m) argexprs′[i] = insert_node_here!(compact, - NewInstruction(PiNode(argexprs′[i], m), m, line)) + NewInstruction(PiNode(argex, m), m, line)) end end end @@ -1336,12 +1353,6 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) todo end -function mk_tuplecall!(compact::IncrementalCompact, args::Vector{Any}, line_idx::Int32) - e = Expr(:call, TOP_TUPLE, args...) - etyp = tuple_tfunc(Any[argextype(args[i], compact) for i in 1:length(args)]) - return insert_node_here!(compact, NewInstruction(e, etyp, line_idx)) -end - function linear_inline_eligible(ir::IRCode) length(ir.cfg.blocks) == 1 || return false terminator = ir[SSAValue(last(ir.cfg.blocks[1].stmts))]