Skip to content

Commit

Permalink
optimizer: enhance SROA, handle partially-initialized allocations
Browse files Browse the repository at this point in the history
During adding more test cases for our SROA pass, I found our SROA doesn't
handle allocation sites with uninitialized fields at all.
This commit is based on #42833 and tries to handle such "unsafe" allocations,
if there are safe `setfield!` definitions.

For example, this commit allows the allocation `r = Ref{Int}()` to be
eliminated in the following example (adapted from <https://hackmd.io/bZz8k6SHQQuNUW-Vs7rqfw?view>):
```julia
julia> code_typed() do
           r = Ref{Int}()
           r[] = 42
           b = sin(r[])
           return b
       end |> only
```

This commit comes with a plenty of basic test cases for our SROA pass also.
  • Loading branch information
aviatesk committed Oct 28, 2021
1 parent 68f71be commit a9ae9f2
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 40 deletions.
2 changes: 1 addition & 1 deletion base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ function run_passes(ci::CodeInfo, sv::OptimizationState)
@timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
# @timeit "verify 2" verify_ir(ir)
@timeit "compact 2" ir = compact!(ir)
@timeit "SROA" ir = getfield_elim_pass!(ir)
@timeit "SROA" ir = sroa_pass!(ir)
@timeit "ADCE" ir = adce_pass!(ir)
@timeit "type lift" ir = type_lift_pass!(ir)
@timeit "compact 3" ir = compact!(ir)
Expand Down
80 changes: 45 additions & 35 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,22 @@ function compute_value_for_block(ir::IRCode, domtree::DomTree, allblocks::Vector
end

function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use_idx::Int)
# Find the first dominating def
def, stmtblock, curblock = find_def_for_use(ir, domtree, allblocks, du, use_idx)
if def == 0
if !haskey(phinodes, curblock)
# If this happens, we need to search the predecessors for defs. Which
# one doesn't matter - if it did, we'd have had a phinode
return compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, first(ir.cfg.blocks[stmtblock].preds))
end
# The use is the phinode
return phinodes[curblock]
else
return val_for_def_expr(ir, def, fidx)
end
end

# find the first dominating def for the given use
function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, use_idx::Int)
stmtblock = block_for_inst(ir.cfg, use_idx)
curblock = find_curblock(domtree, allblocks, stmtblock)
local def = 0
Expand All @@ -90,17 +105,7 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I
end
end
end
if def == 0
if !haskey(phinodes, curblock)
# If this happens, we need to search the predecessors for defs. Which
# one doesn't matter - if it did, we'd have had a phinode
return compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, first(ir.cfg.blocks[stmtblock].preds))
end
# The use is the phinode
return phinodes[curblock]
else
return val_for_def_expr(ir, def, fidx)
end
return def, stmtblock, curblock
end

function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
Expand Down Expand Up @@ -538,7 +543,7 @@ function perform_lifting!(compact::IncrementalCompact,
end

"""
getfield_elim_pass!(ir::IRCode) -> newir::IRCode
sroa_pass!(ir::IRCode) -> newir::IRCode
`getfield` elimination pass, a.k.a. Scalar Replacements of Aggregates optimization.
Expand All @@ -555,7 +560,7 @@ its argument).
In a case when all usages are fully eliminated, `struct` allocation may also be erased as
a result of dead code elimination.
"""
function getfield_elim_pass!(ir::IRCode)
function sroa_pass!(ir::IRCode)
compact = IncrementalCompact(ir)
defuses = IdDict{Int, Tuple{IdSet{Int}, SSADefUse}}()
lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}()
Expand Down Expand Up @@ -784,7 +789,6 @@ function getfield_elim_pass!(ir::IRCode)
typ = typ::DataType
# Partition defuses by field
fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)]
ok = true
for use in defuse.uses
stmt = ir[SSAValue(use)]
# We may have discovered above that this use is dead
Expand All @@ -793,47 +797,52 @@ function getfield_elim_pass!(ir::IRCode)
# the use in that case.
stmt === nothing && continue
field = try_compute_fieldidx_stmt(compact, stmt::Expr, typ)
field === nothing && (ok = false; break)
field === nothing && @goto skip
push!(fielddefuse[field].uses, use)
end
ok || continue
for use in defuse.defs
field = try_compute_fieldidx_stmt(compact, ir[SSAValue(use)]::Expr, typ)
field === nothing && (ok = false; break)
field === nothing && @goto skip
push!(fielddefuse[field].defs, use)
end
ok || continue
# Check that the defexpr has defined values for all the fields
# we're accessing. In the future, we may want to relax this,
# but we should come up with semantics for well defined semantics
# for uninitialized fields first.
for (fidx, du) in pairs(fielddefuse)
ndefuse = length(fielddefuse)
blocks = Vector{Tuple{#=phiblocks=# Vector{Int}, #=allblocks=# Vector{Int}}}(undef, ndefuse)
for fidx in 1:ndefuse
du = fielddefuse[fidx]
isempty(du.uses) && continue
push!(du.defs, idx)
ldu = compute_live_ins(ir.cfg, du)
phiblocks = Int[]
if !isempty(ldu.live_in_bbs)
phiblocks = idf(ir.cfg, ldu, domtree)
end
allblocks = sort(vcat(phiblocks, ldu.def_bbs))
blocks[fidx] = phiblocks, allblocks
if fidx + 1 > length(defexpr.args)
ok = false
break
for use in du.uses
def = find_def_for_use(ir, domtree, allblocks, du, use)[1]
(def == 0 || def == idx) && @goto skip
end
end
end
ok || continue
preserve_uses = IdDict{Int, Vector{Any}}((idx=>Any[] for idx in IdSet{Int}(defuse.ccall_preserve_uses)))
# Everything accounted for. Go field by field and perform idf
for (fidx, du) in pairs(fielddefuse)
for fidx in 1:ndefuse
du = fielddefuse[fidx]
ftyp = fieldtype(typ, fidx)
if !isempty(du.uses)
push!(du.defs, idx)
ldu = compute_live_ins(ir.cfg, du)
phiblocks = Int[]
if !isempty(ldu.live_in_bbs)
phiblocks = idf(ir.cfg, ldu, domtree)
end
phiblocks, allblocks = blocks[fidx]
phinodes = IdDict{Int, SSAValue}()
for b in phiblocks
n = PhiNode()
phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts),
NewInstruction(n, ftyp))
end
# Now go through all uses and rewrite them
allblocks = sort(vcat(phiblocks, ldu.def_bbs))
for stmt in du.uses
ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
end
Expand All @@ -855,7 +864,6 @@ function getfield_elim_pass!(ir::IRCode)
stmt == idx && continue
ir[SSAValue(stmt)] = nothing
end
continue
end
isempty(defuse.ccall_preserve_uses) && continue
push!(intermediaries, idx)
Expand All @@ -870,6 +878,8 @@ function getfield_elim_pass!(ir::IRCode)
old_preserves..., new_preserves...)
ir[SSAValue(use)] = new_expr
end

@label skip
end

return ir
Expand Down Expand Up @@ -919,14 +929,14 @@ In addition to a simple DCE for unused values and allocations,
this pass also nullifies `typeassert` calls that can be proved to be no-op,
in order to allow LLVM to emit simpler code down the road.
Note that this pass is more effective after SROA optimization (i.e. `getfield_elim_pass!`),
Note that this pass is more effective after SROA optimization (i.e. `sroa_pass!`),
since SROA often allows this pass to:
- eliminate allocation of object whose field references are all replaced with scalar values, and
- nullify `typeassert` call whose first operand has been replaced with a scalar value
(, which may have introduced new type information that inference did not understand)
Also note that currently this pass _needs_ to run after `getfield_elim_pass!`, because
the `typeassert` elimination depends on the transformation within `getfield_elim_pass!`
Also note that currently this pass _needs_ to run after `sroa_pass!`, because
the `typeassert` elimination depends on the transformation within `sroa_pass!`
which redirects references of `typeassert`ed value to the corresponding `PiNode`.
"""
function adce_pass!(ir::IRCode)
Expand Down
4 changes: 2 additions & 2 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ using Base.Experimental: @opaque
f_oc_getfield(x) = (@opaque ()->x)()
@test fully_eliminated(f_oc_getfield, Tuple{Int})

import Core.Compiler: argextype
import Core.Compiler: argextype, singleton_type
const EMPTY_SPTYPES = Core.Compiler.EMPTY_SLOTTYPES

code_typed1(args...; kwargs...) = first(only(code_typed(args...; kwargs...)))::Core.CodeInfo
Expand All @@ -389,7 +389,7 @@ get_code(args...; kwargs...) = code_typed1(args...; kwargs...).code
# check if `x` is a dynamic call of a given function
function iscall((src, f)::Tuple{Core.CodeInfo,Function}, @nospecialize(x))
return iscall(x) do @nospecialize x
argextype(x, src, EMPTY_SPTYPES) === typeof(f)
singleton_type(argextype(x, src, EMPTY_SPTYPES)) === f
end
end
iscall(pred, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1])
Expand Down
Loading

0 comments on commit a9ae9f2

Please sign in to comment.