Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: [NewOptimizer] The one SROA pass to rule them all #26778

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions base/compiler/ssair/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ include("compiler/ssair/passes.jl")
include("compiler/ssair/inlining2.jl")
include("compiler/ssair/verify.jl")
include("compiler/ssair/legacy.jl")
@isdefined(Base) && include("compiler/ssair/show.jl")
#@isdefined(Base) && include("compiler/ssair/show.jl")

function normalize_expr(stmt::Expr)
if stmt.head === :gotoifnot
Expand Down Expand Up @@ -165,8 +165,9 @@ function run_passes(ci::CodeInfo, nargs::Int, linetable::Vector{LineInfoNode}, s
@timeit "Inlining" ir = ssa_inlining_pass!(ir, linetable, sv)
#@timeit "verify 2" verify_ir(ir)
@timeit "domtree 2" domtree = construct_domtree(ir.cfg)
ir = compact!(ir)
@timeit "SROA" ir = getfield_elim_pass!(ir, domtree)
@timeit "compact 2" ir = compact!(ir)
ir = adce_pass!(ir)
@timeit "type lift" ir = type_lift_pass!(ir)
@timeit "compact 3" ir = compact!(ir)
#@Base.show ir
Expand Down
246 changes: 204 additions & 42 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,12 @@ mutable struct IncrementalCompact
# This could be Stateful, but bootstrapping doesn't like that
perm::Vector{Int}
new_nodes_idx::Int
# This supports insertion while compacting
new_new_nodes::Vector{NewNode} # New nodes that were before the compaction point at insertion time
# TODO: Switch these two to a min-heap of some sort
pending_nodes::Vector{NewNode} # New nodes that were after the compaction point at insertion time
pending_perm::Vector{Int}
# State
idx::Int
result_idx::Int
active_result_bb::Int
Expand All @@ -427,7 +433,12 @@ mutable struct IncrementalCompact
used_ssas = fill(0, new_len)
ssa_rename = Any[SSAValue(i) for i = 1:new_len]
late_fixup = Vector{Int}()
return new(code, result, result_types, result_lines, result_flags, code.cfg.blocks, ssa_rename, used_ssas, late_fixup, perm, 1, 1, 1, 1)
new_new_nodes = NewNode[]
pending_nodes = NewNode[]
pending_perm = Int[]
return new(code, result, result_types, result_lines, result_flags, code.cfg.blocks, ssa_rename, used_ssas, late_fixup, perm, 1,
new_new_nodes, pending_nodes, pending_perm,
1, 1, 1)
end

# For inlining
Expand All @@ -437,8 +448,14 @@ mutable struct IncrementalCompact
ssa_rename = Any[SSAValue(i) for i = 1:new_len]
used_ssas = fill(0, new_len)
late_fixup = Vector{Int}()
return new(code, parent.result, parent.result_types, parent.result_lines, parent.result_flags, parent.result_bbs,
ssa_rename, parent.used_ssas, late_fixup, perm, 1, 1, result_offset, parent.active_result_bb)
new_new_nodes = NewNode[]
pending_nodes = NewNode[]
pending_perm = Int[]
return new(code, parent.result, parent.result_types, parent.result_lines, parent.result_flags,
parent.result_bbs, ssa_rename, parent.used_ssas,
late_fixup, perm, 1,
new_new_nodes, pending_nodes, pending_perm,
1, result_offset, parent.active_result_bb)
end
end

Expand All @@ -455,17 +472,88 @@ function getindex(compact::IncrementalCompact, idx::Int)
end
end

function getindex(compact::IncrementalCompact, ssa::SSAValue)
@assert ssa.id < compact.result_idx
return compact.result[ssa.id]
end

function getindex(compact::IncrementalCompact, ssa::OldSSAValue)
id = ssa.id
if id <= length(compact.ir.stmts)
return compact.ir.stmts[id]
end
id -= length(compact.ir.stmts)
if id <= length(compact.ir.new_nodes)
return compact.ir.new_nodes[id].node
end
id -= length(compact.ir.new_nodes)
return compact.pending_nodes[id].node
end

function getindex(compact::IncrementalCompact, ssa::NewSSAValue)
return compact.new_new_nodes[ssa.id].node
end

function count_added_node!(compact::IncrementalCompact, @nospecialize(v))
needs_late_fixup = isa(v, NewSSAValue)
if isa(v, SSAValue)
compact.used_ssas[v.id] += 1
else
for ops in userefs(v)
val = ops[]
if isa(val, SSAValue)
compact.used_ssas[val.id] += 1
elseif isa(val, NewSSAValue)
needs_late_fixup = true
end
end
end
needs_late_fixup
end

function resort_pending!(compact)
sort!(compact.pending_perm, DEFAULT_STABLE, Order.By(x->compact.pending_nodes[x].pos))
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this returning a value?


function insert_node!(compact::IncrementalCompact, before, @nospecialize(typ), @nospecialize(val), reverse_affinity::Bool=false)
Copy link
Member

@vtjnash vtjnash May 7, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what type is before - looks like this should written to use dispatch?

if isa(before, SSAValue)
if before.id < compact.result_idx
count_added_node!(compact, val)
line = compact.result_lines[before.id]
push!(compact.new_new_nodes, NewNode(before.id, reverse_affinity, typ, val, line))
return NewSSAValue(length(compact.new_new_nodes))
else
line = compact.ir.lines[before.id]
push!(compact.pending_nodes, NewNode(before.id, reverse_affinity, typ, val, line))
push!(compact.pending_perm, length(compact.pending_nodes))
resort_pending!(compact)
os = OldSSAValue(length(compact.ir.stmts) + length(compact.ir.new_nodes) + length(compact.pending_nodes))
push!(compact.ssa_rename, os)
push!(compact.used_ssas, 0)
return os
end
elseif isa(before, OldSSAValue)
pos = before.id
if pos > length(compact.ir.stmts)
@assert reverse_affinity
entry = compact.pending_nodes[pos - length(compact.ir.stmts) - length(compact.ir.new_nodes)]
pos, reverse_affinity = entry.pos, entry.reverse_affinity
end
line = 0 #compact.ir.lines[before.id]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO?

push!(compact.pending_nodes, NewNode(pos, reverse_affinity, typ, val, line))
push!(compact.pending_perm, length(compact.pending_nodes))
resort_pending!(compact)
os = OldSSAValue(length(compact.ir.stmts) + length(compact.ir.new_nodes) + length(compact.pending_nodes))
push!(compact.ssa_rename, os)
push!(compact.used_ssas, 0)
return os
elseif isa(before, NewSSAValue)
before_entry = compact.new_new_nodes[before.id]
push!(compact.new_new_nodes, NewNode(before_entry.pos, reverse_affinity, typ, val, before_entry.line))
return NewSSAValue(length(compact.new_new_nodes))
else
error("Unsupported")
end
end

function insert_node_here!(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typ), ltable_idx::Int)
Expand All @@ -484,23 +572,39 @@ function insert_node_here!(compact::IncrementalCompact, @nospecialize(val), @nos
end

function getindex(view::TypesView, v::OldSSAValue)
return view.ir.ir.types[v.id]
id = v.id
if id <= length(view.ir.ir.types)
return view.ir.ir.types[id]
end
id -= length(view.ir.ir.types)
if id <= length(view.ir.ir.new_nodes)
return view.ir.ir.new_nodes[id].typ
end
id -= length(view.ir.ir.new_nodes)
return view.ir.pending_nodes[id].typ
end

function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::SSAValue)
@assert idx.id < compact.result_idx
(compact.result[idx.id] === v) && return
# Kill count for current uses
for ops in userefs(compact.result[idx.id])
val = ops[]
if isa(val, SSAValue)
@assert compact.used_ssas[val.id] >= 1
compact.used_ssas[val.id] -= 1
end
end
compact.result[idx.id] = v
# Add count for new use
if count_added_node!(compact, v)
push!(compact.late_fixup, idx.id)
end
end

function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::Int)
if idx < compact.result_idx
(compact.result[idx] === v) && return
# Kill count for current uses
for ops in userefs(compact.result[idx])
val = ops[]
if isa(val, SSAValue)
@assert compact.used_ssas[val.id] >= 1
compact.used_ssas[val.id] -= 1
end
end
compact.result[idx] = v
# Add count for new use
count_added_node!(compact, v)
compact[SSAValue(idx)] = v
else
compact.ir.stmts[idx] = v
end
Expand All @@ -509,10 +613,14 @@ end

function getindex(view::TypesView, idx)
isa(idx, SSAValue) && (idx = idx.id)
ir = view.ir
if isa(ir, IncrementalCompact)
if idx < ir.result_idx
return ir.result_types[idx]
if isa(view.ir, IncrementalCompact) && idx < view.ir.result_idx
return view.ir.result_types[idx]
else
ir = isa(view.ir, IncrementalCompact) ? view.ir.ir : view.ir
if idx <= length(ir.types)
return ir.types[idx]
else
return ir.new_nodes[idx - length(ir.types)].typ
end
ir = ir.ir
end
Expand All @@ -523,6 +631,12 @@ function getindex(view::TypesView, idx)
end
end

function getindex(view::TypesView, idx::NewSSAValue)
@assert isa(view.ir, IncrementalCompact)
compact = view.ir
compact.new_new_nodes[idx.id].typ
end

start(compact::IncrementalCompact) = (compact.idx, 1)
function done(compact::IncrementalCompact, (idx, _a)::Tuple{Int, Int})
return idx > length(compact.ir.stmts) && (compact.new_nodes_idx > length(compact.perm))
Expand Down Expand Up @@ -554,6 +668,8 @@ function process_node!(result::Vector{Any}, result_idx::Int, ssa_rename::Vector{
ssa_rename[idx] = SSAValue(result_idx)
if stmt === nothing
ssa_rename[idx] = stmt
elseif isa(stmt, OldSSAValue)
ssa_rename[idx] = ssa_rename[stmt.id]
elseif isa(stmt, GotoNode) || isa(stmt, GlobalRef)
result[result_idx] = stmt
result_idx += 1
Expand Down Expand Up @@ -652,6 +768,13 @@ function next(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int})
new_node_entry = compact.ir.new_nodes[new_idx]
new_idx += length(compact.ir.stmts)
return process_newnode!(compact, new_idx, new_node_entry, idx, active_bb)
elseif !isempty(compact.pending_perm) &&
(entry = compact.pending_nodes[compact.pending_perm[1]];
entry.attach_after ? entry.pos == idx - 1 : entry.pos == idx)
new_idx = popfirst!(compact.pending_perm)
new_node_entry = compact.pending_nodes[new_idx]
new_idx += length(compact.ir.stmts) + length(compact.ir.new_nodes)
return process_newnode!(compact, new_idx, new_node_entry, idx, active_bb)
end
# This will get overwritten in future iterations if
# result_idx is not, incremented, but that's ok and expected
Expand All @@ -673,10 +796,12 @@ function next(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int})
return Pair{Int, Any}(old_result_idx, compact.result[old_result_idx]), (compact.idx, active_bb)
end

function maybe_erase_unused!(extra_worklist, compact, idx)
effect_free = stmt_effect_free(compact.result[idx], compact, compact.ir.mod)
function maybe_erase_unused!(extra_worklist, compact, idx, callback = x->nothing)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type signature?

stmt = compact.result[idx]
stmt === nothing && return false
effect_free = stmt_effect_free(stmt, compact, compact.ir.mod)
if effect_free
for ops in userefs(compact.result[idx])
for ops in userefs(stmt)
val = ops[]
if isa(val, SSAValue)
if compact.used_ssas[val.id] == 1
Expand All @@ -685,10 +810,13 @@ function maybe_erase_unused!(extra_worklist, compact, idx)
end
end
compact.used_ssas[val.id] -= 1
callback(val)
end
end
compact.result[idx] = nothing
return true
end
return false
end

function fixup_phinode_values!(compact, old_values)
Expand All @@ -701,47 +829,81 @@ function fixup_phinode_values!(compact, old_values)
if isa(val, SSAValue)
compact.used_ssas[val.id] += 1
end
elseif isa(val, NewSSAValue)
val = SSAValue(length(compact.result) + val.id)
end
values[i] = val
end
values
end

function fixup_node(compact, @nospecialize(stmt))
if isa(stmt, PhiNode)
return PhiNode(stmt.edges, fixup_phinode_values!(compact, stmt.values))
elseif isa(stmt, PhiCNode)
return PhiCNode(fixup_phinode_values!(compact, stmt.values))
elseif isa(stmt, NewSSAValue)
return SSAValue(length(compact.result) + stmt.id)
else
urs = userefs(stmt)
urs === () && return stmt
for ur in urs
val = ur[]
if isa(val, NewSSAValue)
ur[] = SSAValue(length(compact.result) + val.id)
end
end
return urs[]
end
end

function just_fixup!(compact)
for idx in compact.late_fixup
stmt = compact.result[idx]
if isa(stmt, PhiNode)
compact.result[idx] = PhiNode(stmt.edges, fixup_phinode_values!(compact, stmt.values))
else
stmt = stmt::PhiCNode
compact.result[idx] = PhiCNode(fixup_phinode_values!(compact, stmt.values))
end
new_stmt = fixup_node(compact, stmt)
(stmt !== new_stmt) && (compact.result[idx] = new_stmt)
end
for idx in 1:length(compact.new_new_nodes)
node = compact.new_new_nodes[idx]
new_stmt = fixup_node(compact, node.node)
(node.node !== new_stmt) && (compact.new_new_nodes[idx] = NewNode(node, node=new_stmt))
end
end

function finish(compact::IncrementalCompact)
just_fixup!(compact)
# Record this somewhere?
result_idx = compact.result_idx
resize!(compact.result, result_idx-1)
resize!(compact.result_types, result_idx-1)
resize!(compact.result_lines, result_idx-1)
resize!(compact.result_flags, result_idx-1)
bb = compact.result_bbs[end]
compact.result_bbs[end] = BasicBlock(bb,
StmtRange(first(bb.stmts), result_idx-1))
function simple_dce!(compact)
# Perform simple DCE for unused values
extra_worklist = Int[]
for (idx, nused) in Iterators.enumerate(compact.used_ssas)
idx >= result_idx && break
idx >= compact.result_idx && break
nused == 0 || continue
maybe_erase_unused!(extra_worklist, compact, idx)
end
while !isempty(extra_worklist)
maybe_erase_unused!(extra_worklist, compact, pop!(extra_worklist))
end
end

function non_dce_finish!(compact::IncrementalCompact)
result_idx = compact.result_idx
resize!(compact.result, result_idx-1)
resize!(compact.result_types, result_idx-1)
resize!(compact.result_lines, result_idx-1)
resize!(compact.result_flags, result_idx-1)
just_fixup!(compact)
bb = compact.result_bbs[end]
compact.result_bbs[end] = BasicBlock(bb,
StmtRange(first(bb.stmts), result_idx-1))
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the BasicBlock return value intended to be meaningful?


function finish(compact::IncrementalCompact)
non_dce_finish!(compact)
simple_dce!(compact)
complete(compact)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this return value intended to be meaningful?

end

function complete(compact)
cfg = CFG(compact.result_bbs, Int[first(bb.stmts) for bb in compact.result_bbs[2:end]])
return IRCode(compact.ir, compact.result, compact.result_types, compact.result_lines, compact.result_flags, cfg, NewNode[])
return IRCode(compact.ir, compact.result, compact.result_types, compact.result_lines, compact.result_flags, cfg, compact.new_new_nodes)
end

function compact!(code::IRCode)
Expand Down
Loading