Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Atol committed May 12, 2022
1 parent fa66298 commit 1c585d1
Showing 1 changed file with 48 additions and 68 deletions.
116 changes: 48 additions & 68 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,7 @@ function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
end
end

# the duplicated check might have been done already within `analyze_method!`, but still
# the duplicated check might have been done already within `analyze_method`, but still
# we need it here too since we may come here directly using a constant-prop' result
if !state.params.inlining || is_stmt_noinline(flag)
return compileable_specialization(et, match, effects)
Expand Down Expand Up @@ -891,7 +891,7 @@ function validate_sparams(sparams::SimpleVector)
return true
end

function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
function analyze_method(match::MethodMatch, argtypes::Vector{Any},
flag::UInt8, state::InliningState, check_sparams::Bool = false)
method = match.method
spec_types = match.spec_types
Expand Down Expand Up @@ -1128,7 +1128,7 @@ function inline_invoke!(
return nothing
end
end
item = analyze_method!(match, argtypes, flag, state)
item = analyze_method(match, argtypes, flag, state)
end
handle_single_case!(ir, idx, stmt, item, todo, state.params, true)
return nothing
Expand Down Expand Up @@ -1247,12 +1247,34 @@ function analyze_single_call!(
ir::IRCode, idx::Int, stmt::Expr, infos::Vector{MethodMatchInfo}, flag::UInt8,
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
argtypes = sig.argtypes
res = analyze_infos(infos, argtypes, flag, state, nothing)
res === nothing && return res
cases, fully_covered = res
handle_cases!(ir, idx, stmt, argtypes_to_type(argtypes), cases, fully_covered, todo, state.params)
end

# similar to `analyze_single_call!`, but with constant results
function handle_const_call!(
ir::IRCode, idx::Int, stmt::Expr, cinfo::ConstCallInfo, flag::UInt8,
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
argtypes = sig.argtypes
call = cinfo.call
infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches
res = analyze_infos(infos, argtypes, flag, state, cinfo)
res === nothing && return res
cases, fully_covered = res
handle_cases!(ir, idx, stmt, argtypes_to_type(argtypes), cases, fully_covered, todo, state.params)
end

function analyze_infos(infos, argtypes, flag::UInt8, state::InliningState, cinfo::Union{Nothing, ConstCallInfo})
const_results = cinfo !== nothing ? cinfo.results : nothing
cases = InliningCase[]
local any_fully_covered = false
local handled_all_cases = true
local revisit_idx = nothing
local only_method = nothing
local only_method = nothing # tri-valued: nothing if unknown, false if proven untrue, otherwise the method itself
local meth::MethodLookupResult
local k = 0 # indexes cinfo.results for the result corresponding to matches in each method
for i in 1:length(infos)
meth = infos[i].results
if meth.ambig
Expand All @@ -1275,19 +1297,21 @@ function analyze_single_call!(
end
end
for (j, match) in enumerate(meth)
k += 1
if !validate_sparams(match.sparams)
if !match.fully_covers
handled_all_cases = false
continue
end
if revisit_idx === nothing
revisit_idx = (i, j)
revisit_idx = (i, j, k) # `match` and `result` are info[i].results[j] and cinfo.results[k], respectively
else
handled_all_cases = false
revisit_idx = nothing
end
else
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true)
result = const_results !== nothing ? const_results[k] : nothing
handled_all_cases &= handle_match!(result, match, argtypes, flag, state, cases, true, true)
end
any_fully_covered |= match.fully_covers
end
Expand All @@ -1302,9 +1326,10 @@ function analyze_single_call!(
# foo(@nospecialize(x::Any)) = 2
# where we where only a small number of specific dispatchable
# cases are split off from an ::Any typed fallback.
(i, j) = revisit_idx
(i, j, k) = revisit_idx
match = infos[i].results[j]
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true, true)
result = const_results !== nothing ? const_results[k] : nothing
handled_all_cases &= handle_match!(result, match, argtypes, flag, state, cases, true, true)
elseif length(cases) == 0 && only_method isa Method
# if the signature is fully covered and there is only one applicable method,
# we can try to inline it even if the signature is not a dispatch tuple.
Expand All @@ -1325,66 +1350,21 @@ function analyze_single_call!(
filter!(case::InliningCase->isdispatchtuple(case.sig), cases)
end

handle_cases!(ir, idx, stmt, argtypes_to_type(argtypes), cases,
handled_all_cases & any_fully_covered, todo, state.params)
return cases, handled_all_cases & any_fully_covered
end

# similar to `analyze_single_call!`, but with constant results
function handle_const_call!(
ir::IRCode, idx::Int, stmt::Expr, cinfo::ConstCallInfo, flag::UInt8,
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
argtypes = sig.argtypes
(; call, results) = cinfo
infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches
cases = InliningCase[]
local any_fully_covered = false
local handled_all_cases = true
local j = 0
local only_method = nothing # tri-valued: nothing if unknown, false if proven untrue, otherwise the method itself
for i in 1:length(infos)
meth = infos[i].results
if meth.ambig
# Too many applicable methods
# Or there is a (partial?) ambiguity
return nothing
elseif length(meth) == 0
# No applicable methods; try next union split
handled_all_cases = false
continue
else
if length(meth) == 1 && only_method !== false
if only_method === nothing
only_method = meth[1].method
elseif only_method !== meth[1].method
only_method = false
end
else
only_method = false
end
end
for match in meth
j += 1
result = results[j]
any_fully_covered |= match.fully_covers
if isa(result, ConcreteResult)
case = concrete_result_item(result, state)
push!(cases, InliningCase(result.mi.specTypes, case))
elseif isa(result, ConstPropResult)
handled_all_cases &= handle_const_prop_result!(result, argtypes, flag, state, cases, true, true)
else
@assert result === nothing
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true, true)
end
end
end

if !handled_all_cases
# if we've not seen all candidates, union split is valid only for dispatch tuples
filter!(case::InliningCase->isdispatchtuple(case.sig), cases)
function handle_match!(const_result::Union{Nothing, ConstPropResult, ConcreteResult},
match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
cases::Vector{InliningCase}, allow_abstract::Bool = false, check_sparams::Bool = false)
if const_result === nothing
return handle_match!(match, argtypes, flag, state, cases, allow_abstract, check_sparams)
elseif isa(const_result, ConstPropResult)
return handle_const_prop_result!(const_result, argtypes, flag, state, cases, allow_abstract, check_sparams)
elseif isa(const_result, ConcreteResult)
case = concrete_result_item(const_result, state)
push!(cases, InliningCase(const_result.mi.specTypes, case))
return true
end

handle_cases!(ir, idx, stmt, argtypes_to_type(argtypes), cases,
handled_all_cases & any_fully_covered, todo, state.params)
end

function handle_match!(
Expand All @@ -1397,9 +1377,9 @@ function handle_match!(
# during abstract interpretation: for the purpose of inlining, we can just skip
# processing this dispatch candidate
_any(case->case.sig === spec_types, cases) && return true
item = analyze_method!(match, argtypes, flag, state, true)
item = analyze_method(match, argtypes, flag, state, true)
else
item = analyze_method!(match, argtypes, flag, state)
item = analyze_method(match, argtypes, flag, state)
end
item === nothing && return false
push!(cases, InliningCase(spec_types, item))
Expand Down Expand Up @@ -1498,7 +1478,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
if isa(result, ConcreteResult)
item = concrete_result_item(result, state)
else
item = analyze_method!(info.match, sig.argtypes, flag, state)
item = analyze_method(info.match, sig.argtypes, flag, state)
end
handle_single_case!(ir, idx, stmt, item, todo, state.params)
end
Expand Down

0 comments on commit 1c585d1

Please sign in to comment.