Skip to content

Commit

Permalink
Remove union penalties for inlining cost
Browse files Browse the repository at this point in the history
I added this code back in #27057, when I first made Union-full
signatures inlineable. The justification was to try to encourage
the union splitting to happen on the outside. However (and I believe
this changed since this code was introduced), these days inference
is in complete control of union splitting and we do not take
inlineability or non-inlineability of the non-unionsplit function
into account when deciding how to inline. As a result, the only
effect of the union split penalties was to prevent inlining of
functions that are not union-split eligible (e.g.
`+(::Vararg{Union{Int, Missing}, 3})`), but are nevertheless cheap
by our inlining metric. There is really no reason not to try to
inline such functions, so delete this logic.
  • Loading branch information
Keno committed Jul 7, 2023
1 parent d9ad6d2 commit eb729af
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 31 deletions.
36 changes: 10 additions & 26 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -406,18 +406,9 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
opt.ir = ir

# determine and cache inlineability
union_penalties = false
if !force_noinline
sig = unwrap_unionall(specTypes)
if isa(sig, DataType) && sig.name === Tuple.name
for P in sig.parameters
P = unwrap_unionall(P)
if isa(P, Union)
union_penalties = true
break
end
end
else
if !(isa(sig, DataType) && sig.name === Tuple.name)
force_noinline = true
end
if !is_declared_inline(src) && result === Bottom
Expand Down Expand Up @@ -448,7 +439,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
cost_threshold += 4*default
end
end
src.inlining_cost = inline_cost(ir, params, union_penalties, cost_threshold)
src.inlining_cost = inline_cost(ir, params, cost_threshold)
end
end
return nothing
Expand Down Expand Up @@ -645,7 +636,7 @@ plus_saturate(x::Int, y::Int) = max(x, y, x+y)
isknowntype(@nospecialize T) = (T === Union{}) || isa(T, Const) || isconcretetype(widenconst(T))

function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{VarState},
union_penalties::Bool, params::OptimizationParams, error_path::Bool = false)
params::OptimizationParams, error_path::Bool = false)
head = ex.head
if is_meta_expr_head(head)
return 0
Expand Down Expand Up @@ -683,13 +674,6 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
return isknowntype(atyp) ? 4 : error_path ? params.inline_error_path_cost : params.inline_nonleaf_penalty
elseif f === typeassert && isconstType(widenconst(argextype(ex.args[3], src, sptypes)))
return 1
elseif f === Core.isa
# If we're in a union context, we penalize type computations
# on union types. In such cases, it is usually better to perform
# union splitting on the outside.
if union_penalties && isa(argextype(ex.args[2], src, sptypes), Union)
return params.inline_nonleaf_penalty
end
end
fidx = find_tfunc(f)
if fidx === nothing
Expand Down Expand Up @@ -720,7 +704,7 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
end
a = ex.args[2]
if a isa Expr
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, union_penalties, params, error_path))
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, params, error_path))
end
return cost
elseif head === :copyast
Expand All @@ -736,11 +720,11 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
end

function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{VarState},
union_penalties::Bool, params::OptimizationParams)
params::OptimizationParams)
thiscost = 0
dst(tgt) = isa(src, IRCode) ? first(src.cfg.blocks[tgt].stmts) : tgt
if stmt isa Expr
thiscost = statement_cost(stmt, line, src, sptypes, union_penalties, params,
thiscost = statement_cost(stmt, line, src, sptypes, params,
is_stmt_throw_block(isa(src, IRCode) ? src.stmts.flag[line] : src.ssaflags[line]))::Int
elseif stmt isa GotoNode
# loops are generally always expensive
Expand All @@ -753,24 +737,24 @@ function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::Union{Cod
return thiscost
end

function inline_cost(ir::IRCode, params::OptimizationParams, union_penalties::Bool=false,
function inline_cost(ir::IRCode, params::OptimizationParams,
cost_threshold::Integer=params.inline_cost_threshold)::InlineCostType
bodycost::Int = 0
for line = 1:length(ir.stmts)
stmt = ir.stmts[line][:inst]
thiscost = statement_or_branch_cost(stmt, line, ir, ir.sptypes, union_penalties, params)
thiscost = statement_or_branch_cost(stmt, line, ir, ir.sptypes, params)
bodycost = plus_saturate(bodycost, thiscost)
bodycost > cost_threshold && return MAX_INLINE_COST
end
return inline_cost_clamp(bodycost)
end

function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::Union{CodeInfo, IRCode}, sptypes::Vector{VarState}, unionpenalties::Bool, params::OptimizationParams)
function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::Union{CodeInfo, IRCode}, sptypes::Vector{VarState}, params::OptimizationParams)
maxcost = 0
for line = 1:length(body)
stmt = body[line]
thiscost = statement_or_branch_cost(stmt, line, src, sptypes,
unionpenalties, params)
params)
cost[line] = thiscost
if thiscost > maxcost
maxcost = thiscost
Expand Down
2 changes: 1 addition & 1 deletion base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1669,7 +1669,7 @@ function print_statement_costs(io::IO, @nospecialize(tt::Type);
empty!(cst)
resize!(cst, length(code.code))
sptypes = Core.Compiler.VarState[Core.Compiler.VarState(sp, false) for sp in match.sparams]
maxcost = Core.Compiler.statement_costs!(cst, code.code, code, sptypes, false, params)
maxcost = Core.Compiler.statement_costs!(cst, code.code, code, sptypes, params)
nd = ndigits(maxcost)
irshow_config = IRShow.IRShowConfig() do io, linestart, idx
print(io, idx > 0 ? lpad(cst[idx], nd+1) : " "^(nd+1), " ")
Expand Down
8 changes: 4 additions & 4 deletions test/offsetarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -627,15 +627,15 @@ end
B = OffsetArray(reshape(1:24, 4, 3, 2), -5, 6, -7)
for R in (fill(0, -4:-1), fill(0, -4:-1, 7:7), fill(0, -4:-1, 7:7, -6:-6))
@test @inferred(maximum!(R, B)) == reshape(maximum(B, dims=(2,3)), axes(R)) == reshape(21:24, axes(R))
@test @allocated(maximum!(R, B)) <= 800
@test @allocated(maximum!(R, B)) <= 1300
@test @inferred(minimum!(R, B)) == reshape(minimum(B, dims=(2,3)), axes(R)) == reshape(1:4, axes(R))
@test @allocated(minimum!(R, B)) <= 800
@test @allocated(minimum!(R, B)) <= 1300
end
for R in (fill(0, -4:-4, 7:9), fill(0, -4:-4, 7:9, -6:-6))
@test @inferred(maximum!(R, B)) == reshape(maximum(B, dims=(1,3)), axes(R)) == reshape(16:4:24, axes(R))
@test @allocated(maximum!(R, B)) <= 800
@test @allocated(maximum!(R, B)) <= 1300
@test @inferred(minimum!(R, B)) == reshape(minimum(B, dims=(1,3)), axes(R)) == reshape(1:4:9, axes(R))
@test @allocated(minimum!(R, B)) <= 800
@test @allocated(minimum!(R, B)) <= 1300
end
@test_throws DimensionMismatch maximum!(fill(0, -4:-1, 7:7, -6:-6, 1:1), B)
@test_throws DimensionMismatch minimum!(fill(0, -4:-1, 7:7, -6:-6, 1:1), B)
Expand Down

0 comments on commit eb729af

Please sign in to comment.