Skip to content

Commit

Permalink
slot2ssa: Remove type iteration
Browse files Browse the repository at this point in the history
Now that we have the `bb_vartables` from Inference, we can get a
converged (and frequently more precise) result for our inserted PhiNodes
directly, instead of trying to compute the fix point in the optimizer.

This change also improves the way that we are inserting PiNodes to
ensure that they are present in every block that requires them for
precision, not just where we are already inserting PhiNodes. This is
necessary due to the presence of `Conditional`, which means we may
have a refined type in a block even when dominated by a definition
whose type was precise outside of the Conditional.
  • Loading branch information
topolarity committed Aug 17, 2023
1 parent 1d4da7e commit b83d84f
Showing 1 changed file with 42 additions and 115 deletions.
157 changes: 42 additions & 115 deletions base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,11 @@ function fixemup!(@specialize(slot_filter), @specialize(rename_slot), ir::IRCode
end
op[] = x
elseif isa(val, GlobalRef) && !(isdefined(val.mod, val.name) && isconst(val.mod, val.name))
typ = typ_for_val(val, ci, ir.sptypes, idx, Any[])
typ = typ_for_val(val, ci, ir, idx, Any[])
new_inst = NewInstruction(val, typ)
op[] = NewSSAValue(insert_node!(ir, idx, new_inst).id - length(ir.stmts))
elseif isexpr(val, :static_parameter)
ty = typ_for_val(val, ci, ir.sptypes, idx, Any[])
ty = typ_for_val(val, ci, ir, idx, Any[])
if isa(ty, Const)
inst = NewInstruction(quoted(ty.val), ty)
else
Expand Down Expand Up @@ -208,26 +208,22 @@ function strip_trailing_junk!(ci::CodeInfo, cfg::CFG, code::Vector{Any}, info::V
nothing
end

struct DelayedTyp
phi::NewSSAValue
end

# maybe use expr_type?
function typ_for_val(@nospecialize(x), ci::CodeInfo, sptypes::Vector{VarState}, idx::Int, slottypes::Vector{Any})
function typ_for_val(@nospecialize(x), ci::CodeInfo, ir::IRCode, idx::Int, slottypes::Vector{Any})
if isa(x, Expr)
if x.head === :static_parameter
return sptypes[x.args[1]::Int].typ
return ir.sptypes[x.args[1]::Int].typ
elseif x.head === :boundscheck
return Bool
elseif x.head === :copyast
return typ_for_val(x.args[1], ci, sptypes, idx, slottypes)
return typ_for_val(x.args[1], ci, ir, idx, slottypes)
end
return (ci.ssavaluetypes::Vector{Any})[idx]
end
isa(x, GlobalRef) && return abstract_eval_globalref(x)
isa(x, SSAValue) && return (ci.ssavaluetypes::Vector{Any})[x.id]
isa(x, Argument) && return slottypes[x.n]
isa(x, NewSSAValue) && return DelayedTyp(x)
isa(x, NewSSAValue) && return types(ir)[new_to_regular(x, length(ir.stmts))]
isa(x, QuoteNode) && return Const(x.value)
isa(x, Union{Symbol, PiNode, PhiNode, SlotNumber}) && error("unexpected val type")
return Const(x)
Expand Down Expand Up @@ -566,39 +562,13 @@ function compute_live_ins(cfg::CFG, defs::Vector{Int}, uses::Vector{Int})
BlockLiveness(bb_defs, bb_uses)
end

function recompute_type(node::Union{PhiNode, PhiCNode, PiNode}, ci::CodeInfo, ir::IRCode,
sptypes::Vector{VarState}, slottypes::Vector{Any}, nstmts::Int, 𝕃ₒ::AbstractLattice)
if node isa PiNode
typ = typ_for_val(node.val, ci, sptypes, -1, slottypes)
while isa(typ, DelayedTyp)
typ = types(ir)[new_to_regular(typ.phi::NewSSAValue, nstmts)]
end
((typ isa Core.Const) || (typ isa Core.PartialStruct)) && return typ
return typeintersect(typ, widenconst(node.typ))
else
new_typ = Union{}
for i = 1:length(node.values)
if isa(node, PhiNode) && !isassigned(node.values, i)
continue
end
# TODO: I don't really understand this...
typ = typ_for_val(node.values[i], ci, sptypes, -1, slottypes)
while isa(typ, DelayedTyp)
typ = types(ir)[new_to_regular(typ.phi::NewSSAValue, nstmts)]
end
new_typ = tmerge(𝕃ₒ, new_typ, typ)
end
return new_typ
end
end

struct TryCatchRegion
enter_block::Int
leave_block::Int
end
struct NewSlotPhi{Phi}
ssaval::NewSSAValue
node::Union{Phi,PiNode}
node::Phi
undef_ssaval::Union{NewSSAValue, Nothing}
undef_node::Union{Phi, Nothing}
end
Expand Down Expand Up @@ -661,7 +631,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
fixup_uses!(ir, ci, code, slot.uses, idx, nothing)
else
val = code[slot.defs[]].args[2]
typ = typ_for_val(val, ci, ir.sptypes, slot.defs[], sv.slottypes)
typ = typ_for_val(val, ci, ir, slot.defs[], sv.slottypes)
ssaval = make_ssa!(ci, code, slot.defs[], typ)
fixup_uses!(ir, ci, code, slot.uses, idx, ssaval)
end
Expand Down Expand Up @@ -697,20 +667,12 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
for block in phiblocks
push!(phi_slots[block], idx)
node = PhiNode()
ssaval = NewSSAValue(insert_node!(ir,
first_insert_for_bb(code, cfg, block), NewInstruction(node, Union{})).id - length(ir.stmts))

# Insert PiNode for this incoming edge if the `bb_vartables`
# have extra information for us
typ = sv.slottypes[idx]
if sv.bb_vartables[block] !== nothing
typ = widenslotwrapper(ignorelimited(sv.bb_vartables[block][idx].typ))
if !βŠ‘(𝕃ₒ, sv.slottypes[idx], typ)
node = PiNode(ssaval, typ)
ssaval = NewSSAValue(insert_node!(ir,
first_insert_for_bb(code, cfg, block), NewInstruction(node, typ)).id - length(ir.stmts))
end
end

ssaval = NewSSAValue(insert_node!(ir,
first_insert_for_bb(code, cfg, block), NewInstruction(node, typ)).id - length(ir.stmts))
undef_node = undef_ssaval = nothing
if (ci.slotflags[idx] & SLOT_USEDUNDEF) != 0
undef_node = PhiNode()
Expand All @@ -732,7 +694,6 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
]
worklist = Tuple{Int, Int, Vector{Pair{Any, Any}}}[(1, 0, initial_incoming_vals)]
visited = BitSet()
type_refine_phi = BitSet()
new_nodes = ir.new_nodes
@timeit "SSA Rename" while !isempty(worklist)
(item::Int, pred, incoming_vals) = pop!(worklist)
Expand Down Expand Up @@ -761,11 +722,6 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
# Liveness analysis would probably have prevented us from inserting this phi node
continue
end
pi_node = pi_ssaval = nothing
if isa(node, PiNode)
pi_ssaval, pi_node = ssaval, node
ssaval, node = node.val, ir[new_to_regular(node.val::NewSSAValue, length(ir.stmts))][:stmt]::PhiNode
end
push!(node.edges, pred)
if incoming_val === UNDEF_TOKEN
resize!(node.values, length(node.values)+1)
Expand All @@ -778,31 +734,27 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
push!(undef_node.values, incoming_def)
outgoing_def = undef_ssaval
end
# TODO: Remove the next line, it shouldn't be necessary
push!(type_refine_phi, ssaval.id)
if isa(incoming_val, NewSSAValue)
push!(type_refine_phi, ssaval.id)
end
typ = incoming_val === UNDEF_TOKEN ? Union{} : typ_for_val(incoming_val, ci, ir.sptypes, -1, sv.slottypes)
old_entry = new_nodes.stmts[ssaval.id]
if isa(typ, DelayedTyp)
push!(type_refine_phi, ssaval.id)
end
new_typ = isa(typ, DelayedTyp) ? Union{} : tmerge(𝕃ₒ, old_entry[:type], typ)
old_entry[:type] = new_typ
old_entry[:stmt] = node

if pi_node !== nothing
if !(new_typ isa Core.Const) && !(new_typ isa Core.PartialStruct)
new_nodes.stmts[pi_ssaval.id][:type] = typeintersect(new_typ, widenconst(pi_node.typ))
incoming_vals[slot] = Pair{Any, Any}(ssaval, outgoing_def)
end
(item in visited) && continue
# Record Pi nodes if necessary
if sv.bb_vartables[item] !== nothing
for slot in 1:length(sv.slottypes)
(ival, idef) = incoming_vals[slot]
(ival === SSAValue(-1)) && continue
(ival === SSAValue(-2)) && continue

varstate = sv.bb_vartables[item][slot]
typ = widenslotwrapper(ignorelimited(varstate.typ))
if !βŠ‘(𝕃ₒ, sv.slottypes[slot], typ)
node = PiNode(ival, typ)
ival = NewSSAValue(insert_node!(ir,
first_insert_for_bb(code, cfg, item), NewInstruction(node, typ)).id - length(ir.stmts))
incoming_vals[slot] = Pair{Any, Any}(ival, idef)
end
push!(type_refine_phi, pi_ssaval.id)
incoming_vals[slot] = Pair{Any, Any}(pi_ssaval, outgoing_def)
else
incoming_vals[slot] = Pair{Any, Any}(ssaval, outgoing_def)
end
end
(item in visited) && continue
# Record phi_C nodes if necessary
if haskey(new_phic_nodes, item)
for (; slot, insert) in new_phic_nodes[item]
Expand All @@ -818,7 +770,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
(ival, idef) = incoming_vals[slot_id(slot)]
ivalundef = ival === UNDEF_TOKEN
Ξ₯ = NewInstruction(ivalundef ? UpsilonNode() : UpsilonNode(ival),
ivalundef ? Union{} : typ_for_val(ival, ci, ir.sptypes, -1, sv.slottypes))
ivalundef ? Union{} : typ_for_val(ival, ci, ir, -1, sv.slottypes))
insertpos = first_insert_for_bb(code, cfg, item)
# insert `UpsilonNode` immediately before the `:enter` expression
Ξ₯ssa = insert_node!(ir, insertpos, Ξ₯)
Expand Down Expand Up @@ -850,7 +802,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
if isa(arg1, SlotNumber)
id = slot_id(arg1)
val = stmt.args[2]
typ = typ_for_val(val, ci, ir.sptypes, idx, sv.slottypes)
typ = typ_for_val(val, ci, ir, idx, sv.slottypes)
# Having UNDEF_TOKEN appear on the RHS is possible if we're on a dead branch.
# Do something reasonable here, by marking the LHS as undef as well.
if val !== UNDEF_TOKEN
Expand Down Expand Up @@ -885,6 +837,16 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
end
end
end
# Unwrap any PiNodes before continuing, since they weren't considered during our
# dominance frontier calculation and so have to be used locally in each BB.
for (i, (ival, idef)) in enumerate(incoming_vals)
if ival isa NewSSAValue
stmt = ir[new_to_regular(ival::NewSSAValue, length(ir.stmts))][:stmt]
if stmt isa PiNode
incoming_vals[i] = Pair{Any, Any}(stmt.val, idef)
end
end
end
for succ in cfg.blocks[item].succs
push!(worklist, (succ, item, copy(incoming_vals)))
end
Expand All @@ -904,7 +866,6 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
nstmts = length(ir.stmts)
new_code = Vector{Any}(undef, nstmts)
ssavalmap = fill(SSAValue(-1), length(ssavaluetypes) + 1)
result_types = Any[Any for _ in 1:nstmts]
# Detect statement positions for assignments and construct array
for (bb, idx) in bbidxiter(ir)
stmt = code[idx]
Expand All @@ -927,7 +888,6 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
new_code[idx] = stmt
else
ssavalmap[idx] = SSAValue(idx)
result_types[idx] = ssavaluetypes[idx]
if isa(stmt, PhiNode)
edges = Int32[edge == 0 ? 0 : block_for_inst(cfg, Int(edge)) for edge in stmt.edges]
new_code[idx] = PhiNode(edges, stmt.values)
Expand All @@ -940,56 +900,23 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, sv::OptimizationState,
for (; insert) in nodes
(; node, ssaval) = insert
new_typ = Union{}
# TODO: This could just be the ones that depend on other phis
push!(type_refine_phi, ssaval.id)
new_idx = ssaval.id
node = new_nodes.stmts[new_idx]
phic_values = (node[:stmt]::PhiCNode).values
for i = 1:length(phic_values)
orig_typ = typ = typ_for_val(phic_values[i], ci, ir.sptypes, -1, sv.slottypes)
while isa(typ, DelayedTyp)
typ = types(ir)[new_to_regular(typ.phi::NewSSAValue, nstmts)]
end
typ = typ_for_val(phic_values[i], ci, ir, -1, sv.slottypes)
new_typ = tmerge(𝕃ₒ, new_typ, typ)
end
node[:type] = new_typ
end
end
# This is a bit awkward, because it basically duplicates what type
# inference does. Ideally, we'd just use this representation earlier
# to make sure phi nodes have accurate types
changed = true
while changed
changed = false
for new_idx in type_refine_phi
node = new_nodes.stmts[new_idx]
new_typ = recompute_type(node[:inst]::Union{PhiNode,PhiCNode,PiNode}, ci, ir, ir.sptypes, sv.slottypes, nstmts, 𝕃ₒ)
if !βŠ‘(𝕃ₒ, node[:type], new_typ) || !βŠ‘(𝕃ₒ, new_typ, node[:type])
node[:type] = new_typ
changed = true
end
end
end
for i in 1:length(result_types)
rt_i = result_types[i]
if rt_i isa DelayedTyp
result_types[i] = types(ir)[new_to_regular(rt_i.phi::NewSSAValue, nstmts)]
end
end
for i = 1:length(new_nodes)
local node = new_nodes.stmts[i]
local typ = node[:type]
if isa(typ, DelayedTyp)
node[:type] = types(ir)[new_to_regular(typ.phi::NewSSAValue, nstmts)]
end
end
# Renumber SSA values
@assert isempty(ir.stmts.type)
resize!(ir.stmts.type, nstmts)
for i in 1:nstmts
local node = ir.stmts[i]
node[:stmt] = new_to_regular(renumber_ssa!(new_code[i], ssavalmap), nstmts)
node[:type] = result_types[i]
node[:type] = ssavaluetypes[i]
end
for i = 1:length(new_nodes)
local node = new_nodes.stmts[i]
Expand Down

0 comments on commit b83d84f

Please sign in to comment.