Skip to content

Commit

Permalink
Fix CFG corruption in CFG simplify (JuliaLang#48962)
Browse files Browse the repository at this point in the history
IncrementalCompact ordinarily takes ownership of the CFG in order to to its
transform. cfg_simplify! separate constructs the CFG transform structures
ahead of time and was assuming this meant that the original CFG remained
untouched (since it was using it for lookup operations). Unfortunately,
the IncrementalCompact constructor was already doing some CFG manipulation
cuasing the CFG to be corrupted and cfg_simplify! to create invalid IR.
Fix that by refactoring the IncrementalCompact constructor to allow
passing in the CFG transformation state explicitly, rather than poking
it into the fields afterwards.
  • Loading branch information
Keno authored and Xnartharax committed Apr 13, 2023
1 parent ef09cb8 commit 20aff5e
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 89 deletions.
7 changes: 3 additions & 4 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
case = case::ConstantCase
val = case.val
end
if !isempty(compact.result_bbs[bb].preds)
if !isempty(compact.cfg_transform.result_bbs[bb].preds)
push!(pn.edges, bb)
push!(pn.values, val)
insert_node_here!(compact,
Expand Down Expand Up @@ -648,8 +648,7 @@ function batch_inline!(ir::IRCode, todo::Vector{Pair{Int,Any}}, propagate_inboun
boundscheck = :propagate
end

let compact = IncrementalCompact(ir, false)
compact.result_bbs = state.new_cfg_blocks
let compact = IncrementalCompact(ir, CFGTransformState!(state.new_cfg_blocks, false))
# This needs to be a minimum and is more of a size hint
nn = 0
for (_, item) in todo
Expand All @@ -670,7 +669,7 @@ function batch_inline!(ir::IRCode, todo::Vector{Pair{Int,Any}}, propagate_inboun
argexprs = copy(stmt.args)
end
refinish = false
if compact.result_idx == first(compact.result_bbs[compact.active_result_bb].stmts)
if compact.result_idx == first(compact.cfg_transform.result_bbs[compact.active_result_bb].stmts)
compact.active_result_bb -= 1
refinish = true
end
Expand Down
174 changes: 96 additions & 78 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -562,14 +562,60 @@ end
insert_node!(ir::IRCode, pos::Int, newinst::NewInstruction, attach_after::Bool=false) =
insert_node!(ir, SSAValue(pos), newinst, attach_after)

struct CFGTransformState
cfg_transforms_enabled::Bool
fold_constant_branches::Bool
result_bbs::Vector{BasicBlock}
bb_rename_pred::Vector{Int}
bb_rename_succ::Vector{Int}
end

# N.B.: Takes ownership of the CFG array
function CFGTransformState!(blocks::Vector{BasicBlock}, allow_cfg_transforms::Bool=false)
if allow_cfg_transforms
bb_rename = Vector{Int}(undef, length(blocks))
cur_bb = 1
domtree = construct_domtree(blocks)
for i = 1:length(bb_rename)
if bb_unreachable(domtree, i)
bb_rename[i] = -1
else
bb_rename[i] = cur_bb
cur_bb += 1
end
end
for i = 1:length(bb_rename)
bb_rename[i] == -1 && continue
preds, succs = blocks[i].preds, blocks[i].succs
# Rename preds
for j = 1:length(preds)
if preds[j] != 0
preds[j] = bb_rename[preds[j]]
end
end
# Dead blocks get removed from the predecessor list
filter!(x->x !== -1, preds)
# Rename succs
for j = 1:length(succs)
succs[j] = bb_rename[succs[j]]
end
end
let blocks = blocks, bb_rename = bb_rename
result_bbs = BasicBlock[blocks[i] for i = 1:length(blocks) if bb_rename[i] != -1]
end
else
bb_rename = Vector{Int}()
result_bbs = blocks
end
return CFGTransformState(allow_cfg_transforms, allow_cfg_transforms, result_bbs, bb_rename, bb_rename)
end

mutable struct IncrementalCompact
ir::IRCode
result::InstructionStream
result_bbs::Vector{BasicBlock}

cfg_transform::CFGTransformState
ssa_rename::Vector{Any}
bb_rename_pred::Vector{Int}
bb_rename_succ::Vector{Int}

used_ssas::Vector{Int}
late_fixup::Vector{Int}
Expand All @@ -587,10 +633,8 @@ mutable struct IncrementalCompact
active_bb::Int
active_result_bb::Int
renamed_new_nodes::Bool
cfg_transforms_enabled::Bool
fold_constant_branches::Bool

function IncrementalCompact(code::IRCode, allow_cfg_transforms::Bool=false)
function IncrementalCompact(code::IRCode, cfg_transform::CFGTransformState)
# Sort by position with attach after nodes after regular ones
info = code.new_nodes.info
perm = sort!(collect(eachindex(info)); by=i->(2info[i].pos+info[i].attach_after, i))
Expand All @@ -599,49 +643,14 @@ mutable struct IncrementalCompact
used_ssas = fill(0, new_len)
new_new_used_ssas = Vector{Int}()
blocks = code.cfg.blocks
if allow_cfg_transforms
bb_rename = Vector{Int}(undef, length(blocks))
cur_bb = 1
domtree = construct_domtree(blocks)
for i = 1:length(bb_rename)
if bb_unreachable(domtree, i)
bb_rename[i] = -1
else
bb_rename[i] = cur_bb
cur_bb += 1
end
end
for i = 1:length(bb_rename)
bb_rename[i] == -1 && continue
preds, succs = blocks[i].preds, blocks[i].succs
# Rename preds
for j = 1:length(preds)
if preds[j] != 0
preds[j] = bb_rename[preds[j]]
end
end
# Dead blocks get removed from the predecessor list
filter!(x->x !== -1, preds)
# Rename succs
for j = 1:length(succs)
succs[j] = bb_rename[succs[j]]
end
end
let blocks = blocks, bb_rename = bb_rename
result_bbs = BasicBlock[blocks[i] for i = 1:length(blocks) if bb_rename[i] != -1]
end
else
bb_rename = Vector{Int}()
result_bbs = code.cfg.blocks
end
ssa_rename = Any[SSAValue(i) for i = 1:new_len]
late_fixup = Vector{Int}()
new_new_nodes = NewNodeStream()
pending_nodes = NewNodeStream()
pending_perm = Int[]
return new(code, result, result_bbs, ssa_rename, bb_rename, bb_rename, used_ssas, late_fixup, perm, 1,
return new(code, result, cfg_transform, ssa_rename, used_ssas, late_fixup, perm, 1,
new_new_nodes, new_new_used_ssas, pending_nodes, pending_perm,
1, 1, 1, 1, false, allow_cfg_transforms, allow_cfg_transforms)
1, 1, 1, 1, false)
end

# For inlining
Expand All @@ -653,14 +662,18 @@ mutable struct IncrementalCompact
bb_rename = Vector{Int}()
pending_nodes = NewNodeStream()
pending_perm = Int[]
return new(code, parent.result,
parent.result_bbs, ssa_rename, bb_rename, bb_rename, parent.used_ssas,
return new(code, parent.result, CFGTransformState(false, false, parent.cfg_transform.result_bbs, bb_rename, bb_rename),
ssa_rename, parent.used_ssas,
parent.late_fixup, perm, 1,
parent.new_new_nodes, parent.new_new_used_ssas, pending_nodes, pending_perm,
1, result_offset, 1, parent.active_result_bb, false, false, false)
1, result_offset, 1, parent.active_result_bb, false)
end
end

function IncrementalCompact(code::IRCode, allow_cfg_transforms::Bool=false)
return IncrementalCompact(code, CFGTransformState!(code.cfg.blocks, allow_cfg_transforms))
end

struct TypesView{T}
ir::T # ::Union{IRCode, IncrementalCompact}
end
Expand Down Expand Up @@ -698,7 +711,7 @@ end
function block_for_inst(compact::IncrementalCompact, idx::SSAValue)
id = idx.id
if id < compact.result_idx # if ssa within result
return searchsortedfirst(compact.result_bbs, BasicBlock(StmtRange(id, id)),
return searchsortedfirst(compact.cfg_transform.result_bbs, BasicBlock(StmtRange(id, id)),
1, compact.active_result_bb, bb_ordering())-1
else
return block_for_inst(compact.ir.cfg, id)
Expand Down Expand Up @@ -883,9 +896,10 @@ function insert_node_here!(compact::IncrementalCompact, newinst::NewInstruction,
newline = newinst.line::Int32
refinish = false
result_idx = compact.result_idx
result_bbs = compact.cfg_transform.result_bbs
if reverse_affinity &&
((compact.active_result_bb == length(compact.result_bbs) + 1) ||
result_idx == first(compact.result_bbs[compact.active_result_bb].stmts))
((compact.active_result_bb == length(result_bbs) + 1) ||
result_idx == first(result_bbs[compact.active_result_bb].stmts))
compact.active_result_bb -= 1
refinish = true
end
Expand Down Expand Up @@ -1173,18 +1187,19 @@ function kill_edge!(compact::IncrementalCompact, active_bb::Int, from::Int, to::
# Note: We recursively kill as many edges as are obviously dead. However, this
# may leave dead loops in the IR. We kill these later in a CFG cleanup pass (or
# worstcase during codegen).
preds = compact.result_bbs[compact.bb_rename_succ[to]].preds
succs = compact.result_bbs[compact.bb_rename_pred[from]].succs
deleteat!(preds, findfirst(x->x === compact.bb_rename_pred[from], preds)::Int)
deleteat!(succs, findfirst(x->x === compact.bb_rename_succ[to], succs)::Int)
(; bb_rename_pred, bb_rename_succ, result_bbs) = compact.cfg_transform
preds = result_bbs[bb_rename_succ[to]].preds
succs = result_bbs[bb_rename_pred[from]].succs
deleteat!(preds, findfirst(x->x === bb_rename_pred[from], preds)::Int)
deleteat!(succs, findfirst(x->x === bb_rename_succ[to], succs)::Int)
# Check if the block is now dead
if length(preds) == 0
for succ in copy(compact.result_bbs[compact.bb_rename_succ[to]].succs)
kill_edge!(compact, active_bb, to, findfirst(x->x === succ, compact.bb_rename_pred)::Int)
for succ in copy(result_bbs[bb_rename_succ[to]].succs)
kill_edge!(compact, active_bb, to, findfirst(x->x === succ, bb_rename_pred)::Int)
end
if to < active_bb
# Kill all statements in the block
stmts = compact.result_bbs[compact.bb_rename_succ[to]].stmts
stmts = result_bbs[bb_rename_succ[to]].stmts
for stmt in stmts
compact.result[stmt][:inst] = nothing
end
Expand All @@ -1194,20 +1209,20 @@ function kill_edge!(compact::IncrementalCompact, active_bb::Int, from::Int, to::
# indicates that the block is not to be scheduled, but there should
# still be an (unreachable) BB inserted into the final IR to avoid
# disturbing the BB numbering.
compact.bb_rename_succ[to] = -2
bb_rename_succ[to] = -2
end
else
# Remove this edge from all phi nodes in `to` block
# NOTE: It is possible for `to` to contain only `nothing` statements,
# so we must be careful to stop at its last statement
if to < active_bb
stmts = compact.result_bbs[compact.bb_rename_succ[to]].stmts
stmts = result_bbs[bb_rename_succ[to]].stmts
idx = first(stmts)
while idx <= last(stmts)
stmt = compact.result[idx][:inst]
stmt === nothing && continue
isa(stmt, PhiNode) || break
i = findfirst(x-> x == compact.bb_rename_pred[from], stmt.edges)
i = findfirst(x-> x == bb_rename_pred[from], stmt.edges)
if i !== nothing
deleteat!(stmt.edges, i)
deleteat!(stmt.values, i)
Expand All @@ -1232,14 +1247,15 @@ end

function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instruction, idx::Int, processed_idx::Int, active_bb::Int, do_rename_ssa::Bool)
stmt = inst[:inst]
(; result, ssa_rename, late_fixup, used_ssas, new_new_used_ssas, cfg_transforms_enabled, fold_constant_branches) = compact
(; result, ssa_rename, late_fixup, used_ssas, new_new_used_ssas) = compact
(; cfg_transforms_enabled, fold_constant_branches, bb_rename_succ, bb_rename_pred, result_bbs) = compact.cfg_transform
ssa_rename[idx] = SSAValue(result_idx)
if stmt === nothing
ssa_rename[idx] = stmt
elseif isa(stmt, OldSSAValue)
ssa_rename[idx] = ssa_rename[stmt.id]
elseif isa(stmt, GotoNode) && cfg_transforms_enabled
label = compact.bb_rename_succ[stmt.label]
label = bb_rename_succ[stmt.label]
@assert label > 0
result[result_idx][:inst] = GotoNode(label)
result_idx += 1
Expand Down Expand Up @@ -1272,23 +1288,23 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
kill_edge!(compact, active_bb, active_bb, stmt.dest)
# Don't increment result_idx => Drop this statement
else
label = compact.bb_rename_succ[stmt.dest]
label = bb_rename_succ[stmt.dest]
@assert label > 0
result[result_idx][:inst] = GotoNode(label)
kill_edge!(compact, active_bb, active_bb, active_bb+1)
result_idx += 1
end
else
@label bail
label = compact.bb_rename_succ[stmt.dest]
label = bb_rename_succ[stmt.dest]
@assert label > 0
result[result_idx][:inst] = GotoIfNot(cond, label)
result_idx += 1
end
elseif isa(stmt, Expr)
stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, new_new_used_ssas, late_fixup, result_idx, do_rename_ssa)::Expr
if cfg_transforms_enabled && isexpr(stmt, :enter)
label = compact.bb_rename_succ[stmt.args[1]::Int]
label = bb_rename_succ[stmt.args[1]::Int]
@assert label > 0
stmt.args[1] = label
elseif isexpr(stmt, :throw_undef_if_not)
Expand Down Expand Up @@ -1333,7 +1349,7 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
elseif isa(stmt, PhiNode)
if cfg_transforms_enabled
# Rename phi node edges
map!(i -> compact.bb_rename_pred[i], stmt.edges, stmt.edges)
map!(i -> bb_rename_pred[i], stmt.edges, stmt.edges)

# Remove edges and values associated with dead blocks. Entries in
# `values` can be undefined when the phi node refers to something
Expand Down Expand Up @@ -1375,7 +1391,7 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
before_def = isassigned(values, 1) && (v = values[1]; isa(v, OldSSAValue)) && idx < v.id
if length(edges) == 1 && isassigned(values, 1) && !before_def &&
length(cfg_transforms_enabled ?
compact.result_bbs[compact.bb_rename_succ[active_bb]].preds :
result_bbs[bb_rename_succ[active_bb]].preds :
compact.ir.cfg.blocks[active_bb].preds) == 1
# There's only one predecessor left - just replace it
v = values[1]
Expand Down Expand Up @@ -1417,15 +1433,16 @@ function resize!(compact::IncrementalCompact, nnewnodes)
end

function finish_current_bb!(compact::IncrementalCompact, active_bb, old_result_idx=compact.result_idx, unreachable=false)
if compact.active_result_bb > length(compact.result_bbs)
(;result_bbs, cfg_transforms_enabled, bb_rename_succ) = compact.cfg_transform
if compact.active_result_bb > length(result_bbs)
#@assert compact.bb_rename[active_bb] == -1
return true
end
bb = compact.result_bbs[compact.active_result_bb]
bb = result_bbs[compact.active_result_bb]
# If this was the last statement in the BB and we decided to skip it, insert a
# dummy `nothing` node, to prevent changing the structure of the CFG
skipped = false
if !compact.cfg_transforms_enabled || active_bb == 0 || active_bb > length(compact.bb_rename_succ) || compact.bb_rename_succ[active_bb] != -1
if !cfg_transforms_enabled || active_bb == 0 || active_bb > length(bb_rename_succ) || bb_rename_succ[active_bb] != -1
if compact.result_idx == first(bb.stmts)
length(compact.result) < old_result_idx && resize!(compact, old_result_idx)
node = compact.result[old_result_idx]
Expand All @@ -1435,17 +1452,17 @@ function finish_current_bb!(compact::IncrementalCompact, active_bb, old_result_i
node[:inst], node[:type], node[:line] = nothing, Nothing, 0
end
compact.result_idx = old_result_idx + 1
elseif compact.cfg_transforms_enabled && compact.result_idx - 1 == first(bb.stmts)
elseif cfg_transforms_enabled && compact.result_idx - 1 == first(bb.stmts)
# Optimization: If this BB consists of only a branch, eliminate this bb
end
compact.result_bbs[compact.active_result_bb] = BasicBlock(bb, StmtRange(first(bb.stmts), compact.result_idx-1))
result_bbs[compact.active_result_bb] = BasicBlock(bb, StmtRange(first(bb.stmts), compact.result_idx-1))
compact.active_result_bb += 1
else
skipped = true
end
if compact.active_result_bb <= length(compact.result_bbs)
new_bb = compact.result_bbs[compact.active_result_bb]
compact.result_bbs[compact.active_result_bb] = BasicBlock(new_bb,
if compact.active_result_bb <= length(result_bbs)
new_bb = result_bbs[compact.active_result_bb]
result_bbs[compact.active_result_bb] = BasicBlock(new_bb,
StmtRange(compact.result_idx, last(new_bb.stmts)))
end
return skipped
Expand Down Expand Up @@ -1537,7 +1554,8 @@ function iterate_compact(compact::IncrementalCompact)
resize!(compact, old_result_idx)
end
bb = compact.ir.cfg.blocks[active_bb]
if compact.cfg_transforms_enabled && active_bb > 1 && active_bb <= length(compact.bb_rename_succ) && compact.bb_rename_succ[active_bb] <= -1
(; cfg_transforms_enabled, bb_rename_succ) = compact.cfg_transform
if cfg_transforms_enabled && active_bb > 1 && active_bb <= length(bb_rename_succ) && bb_rename_succ[active_bb] <= -1
# Dead block, so kill the entire block.
compact.idx = last(bb.stmts)
# Pop any remaining insertion nodes
Expand Down Expand Up @@ -1739,8 +1757,8 @@ function non_dce_finish!(compact::IncrementalCompact)
result_idx = compact.result_idx
resize!(compact.result, result_idx - 1)
just_fixup!(compact)
bb = compact.result_bbs[end]
compact.result_bbs[end] = BasicBlock(bb,
bb = compact.cfg_transform.result_bbs[end]
compact.cfg_transform.result_bbs[end] = BasicBlock(bb,
StmtRange(first(bb.stmts), result_idx-1))
compact.renamed_new_nodes = true
nothing
Expand All @@ -1753,7 +1771,7 @@ function finish(compact::IncrementalCompact)
end

function complete(compact::IncrementalCompact)
result_bbs = resize!(compact.result_bbs, compact.active_result_bb-1)
result_bbs = resize!(compact.cfg_transform.result_bbs, compact.active_result_bb-1)
cfg = CFG(result_bbs, Int[first(result_bbs[i].stmts) for i in 2:length(result_bbs)])
if should_check_ssa_counts()
oracle_check(compact)
Expand Down
Loading

0 comments on commit 20aff5e

Please sign in to comment.