diff --git a/base/compiler/ssair/driver.jl b/base/compiler/ssair/driver.jl index 18dfb82161653..119801763e243 100644 --- a/base/compiler/ssair/driver.jl +++ b/base/compiler/ssair/driver.jl @@ -17,7 +17,7 @@ include("compiler/ssair/passes.jl") include("compiler/ssair/inlining2.jl") include("compiler/ssair/verify.jl") include("compiler/ssair/legacy.jl") -@isdefined(Base) && include("compiler/ssair/show.jl") +#@isdefined(Base) && include("compiler/ssair/show.jl") function normalize_expr(stmt::Expr) if stmt.head === :gotoifnot @@ -165,8 +165,9 @@ function run_passes(ci::CodeInfo, nargs::Int, linetable::Vector{LineInfoNode}, s @timeit "Inlining" ir = ssa_inlining_pass!(ir, linetable, sv) #@timeit "verify 2" verify_ir(ir) @timeit "domtree 2" domtree = construct_domtree(ir.cfg) + ir = compact!(ir) @timeit "SROA" ir = getfield_elim_pass!(ir, domtree) - @timeit "compact 2" ir = compact!(ir) + ir = adce_pass!(ir) @timeit "type lift" ir = type_lift_pass!(ir) @timeit "compact 3" ir = compact!(ir) #@Base.show ir diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index 97a202855a1f6..13ce7737b2830 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -413,6 +413,12 @@ mutable struct IncrementalCompact # This could be Stateful, but bootstrapping doesn't like that perm::Vector{Int} new_nodes_idx::Int + # This supports insertion while compacting + new_new_nodes::Vector{NewNode} # New nodes that were before the compaction point at insertion time + # TODO: Switch these two to a min-heap of some sort + pending_nodes::Vector{NewNode} # New nodes that were after the compaction point at insertion time + pending_perm::Vector{Int} + # State idx::Int result_idx::Int active_result_bb::Int @@ -427,7 +433,12 @@ mutable struct IncrementalCompact used_ssas = fill(0, new_len) ssa_rename = Any[SSAValue(i) for i = 1:new_len] late_fixup = Vector{Int}() - return new(code, result, result_types, result_lines, result_flags, code.cfg.blocks, ssa_rename, used_ssas, 
late_fixup, perm, 1, 1, 1, 1) + new_new_nodes = NewNode[] + pending_nodes = NewNode[] + pending_perm = Int[] + return new(code, result, result_types, result_lines, result_flags, code.cfg.blocks, ssa_rename, used_ssas, late_fixup, perm, 1, + new_new_nodes, pending_nodes, pending_perm, + 1, 1, 1) end # For inlining @@ -437,8 +448,14 @@ mutable struct IncrementalCompact ssa_rename = Any[SSAValue(i) for i = 1:new_len] used_ssas = fill(0, new_len) late_fixup = Vector{Int}() - return new(code, parent.result, parent.result_types, parent.result_lines, parent.result_flags, parent.result_bbs, - ssa_rename, parent.used_ssas, late_fixup, perm, 1, 1, result_offset, parent.active_result_bb) + new_new_nodes = NewNode[] + pending_nodes = NewNode[] + pending_perm = Int[] + return new(code, parent.result, parent.result_types, parent.result_lines, parent.result_flags, + parent.result_bbs, ssa_rename, parent.used_ssas, + late_fixup, perm, 1, + new_new_nodes, pending_nodes, pending_perm, + 1, result_offset, parent.active_result_bb) end end @@ -455,7 +472,30 @@ function getindex(compact::IncrementalCompact, idx::Int) end end +function getindex(compact::IncrementalCompact, ssa::SSAValue) + @assert ssa.id < compact.result_idx + return compact.result[ssa.id] +end + +function getindex(compact::IncrementalCompact, ssa::OldSSAValue) + id = ssa.id + if id <= length(compact.ir.stmts) + return compact.ir.stmts[id] + end + id -= length(compact.ir.stmts) + if id <= length(compact.ir.new_nodes) + return compact.ir.new_nodes[id].node + end + id -= length(compact.ir.new_nodes) + return compact.pending_nodes[id].node +end + +function getindex(compact::IncrementalCompact, ssa::NewSSAValue) + return compact.new_new_nodes[ssa.id].node +end + function count_added_node!(compact::IncrementalCompact, @nospecialize(v)) + needs_late_fixup = isa(v, NewSSAValue) if isa(v, SSAValue) compact.used_ssas[v.id] += 1 else @@ -463,9 +503,57 @@ function count_added_node!(compact::IncrementalCompact, @nospecialize(v)) val 
= ops[] if isa(val, SSAValue) compact.used_ssas[val.id] += 1 + elseif isa(val, NewSSAValue) + needs_late_fixup = true end end end + needs_late_fixup +end + +function resort_pending!(compact) + sort!(compact.pending_perm, DEFAULT_STABLE, Order.By(x->compact.pending_nodes[x].pos)) +end + +function insert_node!(compact::IncrementalCompact, before, @nospecialize(typ), @nospecialize(val), reverse_affinity::Bool=false) + if isa(before, SSAValue) + if before.id < compact.result_idx + count_added_node!(compact, val) + line = compact.result_lines[before.id] + push!(compact.new_new_nodes, NewNode(before.id, reverse_affinity, typ, val, line)) + return NewSSAValue(length(compact.new_new_nodes)) + else + line = compact.ir.lines[before.id] + push!(compact.pending_nodes, NewNode(before.id, reverse_affinity, typ, val, line)) + push!(compact.pending_perm, length(compact.pending_nodes)) + resort_pending!(compact) + os = OldSSAValue(length(compact.ir.stmts) + length(compact.ir.new_nodes) + length(compact.pending_nodes)) + push!(compact.ssa_rename, os) + push!(compact.used_ssas, 0) + return os + end + elseif isa(before, OldSSAValue) + pos = before.id + if pos > length(compact.ir.stmts) + @assert reverse_affinity + entry = compact.pending_nodes[pos - length(compact.ir.stmts) - length(compact.ir.new_nodes)] + pos, reverse_affinity = entry.pos, entry.reverse_affinity + end + line = 0 #compact.ir.lines[before.id] + push!(compact.pending_nodes, NewNode(pos, reverse_affinity, typ, val, line)) + push!(compact.pending_perm, length(compact.pending_nodes)) + resort_pending!(compact) + os = OldSSAValue(length(compact.ir.stmts) + length(compact.ir.new_nodes) + length(compact.pending_nodes)) + push!(compact.ssa_rename, os) + push!(compact.used_ssas, 0) + return os + elseif isa(before, NewSSAValue) + before_entry = compact.new_new_nodes[before.id] + push!(compact.new_new_nodes, NewNode(before_entry.pos, reverse_affinity, typ, val, before_entry.line)) + return 
NewSSAValue(length(compact.new_new_nodes)) + else + error("Unsupported") + end end function insert_node_here!(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typ), ltable_idx::Int) @@ -484,23 +572,39 @@ function insert_node_here!(compact::IncrementalCompact, @nospecialize(val), @nos end function getindex(view::TypesView, v::OldSSAValue) - return view.ir.ir.types[v.id] + id = v.id + if id <= length(view.ir.ir.types) + return view.ir.ir.types[id] + end + id -= length(view.ir.ir.types) + if id <= length(view.ir.ir.new_nodes) + return view.ir.ir.new_nodes[id].typ + end + id -= length(view.ir.ir.new_nodes) + return view.ir.pending_nodes[id].typ +end + +function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::SSAValue) + @assert idx.id < compact.result_idx + (compact.result[idx.id] === v) && return + # Kill count for current uses + for ops in userefs(compact.result[idx.id]) + val = ops[] + if isa(val, SSAValue) + @assert compact.used_ssas[val.id] >= 1 + compact.used_ssas[val.id] -= 1 + end + end + compact.result[idx.id] = v + # Add count for new use + if count_added_node!(compact, v) + push!(compact.late_fixup, idx.id) + end end function setindex!(compact::IncrementalCompact, @nospecialize(v), idx::Int) if idx < compact.result_idx - (compact.result[idx] === v) && return - # Kill count for current uses - for ops in userefs(compact.result[idx]) - val = ops[] - if isa(val, SSAValue) - @assert compact.used_ssas[val.id] >= 1 - compact.used_ssas[val.id] -= 1 - end - end - compact.result[idx] = v - # Add count for new use - count_added_node!(compact, v) + compact[SSAValue(idx)] = v else compact.ir.stmts[idx] = v end @@ -509,10 +613,14 @@ end function getindex(view::TypesView, idx) isa(idx, SSAValue) && (idx = idx.id) - ir = view.ir - if isa(ir, IncrementalCompact) - if idx < ir.result_idx - return ir.result_types[idx] + if isa(view.ir, IncrementalCompact) && idx < view.ir.result_idx + return view.ir.result_types[idx] + else + ir = isa(view.ir, 
IncrementalCompact) ? view.ir.ir : view.ir + if idx <= length(ir.types) + return ir.types[idx] + else + return ir.new_nodes[idx - length(ir.types)].typ end ir = ir.ir end @@ -523,6 +631,12 @@ function getindex(view::TypesView, idx) end end +function getindex(view::TypesView, idx::NewSSAValue) + @assert isa(view.ir, IncrementalCompact) + compact = view.ir + compact.new_new_nodes[idx.id].typ +end + start(compact::IncrementalCompact) = (compact.idx, 1) function done(compact::IncrementalCompact, (idx, _a)::Tuple{Int, Int}) return idx > length(compact.ir.stmts) && (compact.new_nodes_idx > length(compact.perm)) @@ -554,6 +668,8 @@ function process_node!(result::Vector{Any}, result_idx::Int, ssa_rename::Vector{ ssa_rename[idx] = SSAValue(result_idx) if stmt === nothing ssa_rename[idx] = stmt + elseif isa(stmt, OldSSAValue) + ssa_rename[idx] = ssa_rename[stmt.id] elseif isa(stmt, GotoNode) || isa(stmt, GlobalRef) result[result_idx] = stmt result_idx += 1 @@ -652,6 +768,13 @@ function next(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}) new_node_entry = compact.ir.new_nodes[new_idx] new_idx += length(compact.ir.stmts) return process_newnode!(compact, new_idx, new_node_entry, idx, active_bb) + elseif !isempty(compact.pending_perm) && + (entry = compact.pending_nodes[compact.pending_perm[1]]; + entry.attach_after ? 
entry.pos == idx - 1 : entry.pos == idx) + new_idx = popfirst!(compact.pending_perm) + new_node_entry = compact.pending_nodes[new_idx] + new_idx += length(compact.ir.stmts) + length(compact.ir.new_nodes) + return process_newnode!(compact, new_idx, new_node_entry, idx, active_bb) end # This will get overwritten in future iterations if # result_idx is not, incremented, but that's ok and expected @@ -673,10 +796,12 @@ function next(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}) return Pair{Int, Any}(old_result_idx, compact.result[old_result_idx]), (compact.idx, active_bb) end -function maybe_erase_unused!(extra_worklist, compact, idx) - effect_free = stmt_effect_free(compact.result[idx], compact, compact.ir.mod) +function maybe_erase_unused!(extra_worklist, compact, idx, callback = x->nothing) + stmt = compact.result[idx] + stmt === nothing && return false + effect_free = stmt_effect_free(stmt, compact, compact.ir.mod) if effect_free - for ops in userefs(compact.result[idx]) + for ops in userefs(stmt) val = ops[] if isa(val, SSAValue) if compact.used_ssas[val.id] == 1 @@ -685,10 +810,13 @@ function maybe_erase_unused!(extra_worklist, compact, idx) end end compact.used_ssas[val.id] -= 1 + callback(val) end end compact.result[idx] = nothing + return true end + return false end function fixup_phinode_values!(compact, old_values) @@ -701,47 +829,81 @@ function fixup_phinode_values!(compact, old_values) if isa(val, SSAValue) compact.used_ssas[val.id] += 1 end + elseif isa(val, NewSSAValue) + val = SSAValue(length(compact.result) + val.id) end values[i] = val end values end +function fixup_node(compact, @nospecialize(stmt)) + if isa(stmt, PhiNode) + return PhiNode(stmt.edges, fixup_phinode_values!(compact, stmt.values)) + elseif isa(stmt, PhiCNode) + return PhiCNode(fixup_phinode_values!(compact, stmt.values)) + elseif isa(stmt, NewSSAValue) + return SSAValue(length(compact.result) + stmt.id) + else + urs = userefs(stmt) + urs === () && return stmt + for ur 
in urs + val = ur[] + if isa(val, NewSSAValue) + ur[] = SSAValue(length(compact.result) + val.id) + end + end + return urs[] + end +end + function just_fixup!(compact) for idx in compact.late_fixup stmt = compact.result[idx] - if isa(stmt, PhiNode) - compact.result[idx] = PhiNode(stmt.edges, fixup_phinode_values!(compact, stmt.values)) - else - stmt = stmt::PhiCNode - compact.result[idx] = PhiCNode(fixup_phinode_values!(compact, stmt.values)) - end + new_stmt = fixup_node(compact, stmt) + (stmt !== new_stmt) && (compact.result[idx] = new_stmt) + end + for idx in 1:length(compact.new_new_nodes) + node = compact.new_new_nodes[idx] + new_stmt = fixup_node(compact, node.node) + (node.node !== new_stmt) && (compact.new_new_nodes[idx] = NewNode(node, node=new_stmt)) end end -function finish(compact::IncrementalCompact) - just_fixup!(compact) - # Record this somewhere? - result_idx = compact.result_idx - resize!(compact.result, result_idx-1) - resize!(compact.result_types, result_idx-1) - resize!(compact.result_lines, result_idx-1) - resize!(compact.result_flags, result_idx-1) - bb = compact.result_bbs[end] - compact.result_bbs[end] = BasicBlock(bb, - StmtRange(first(bb.stmts), result_idx-1)) +function simple_dce!(compact) # Perform simple DCE for unused values extra_worklist = Int[] for (idx, nused) in Iterators.enumerate(compact.used_ssas) - idx >= result_idx && break + idx >= compact.result_idx && break nused == 0 || continue maybe_erase_unused!(extra_worklist, compact, idx) end while !isempty(extra_worklist) maybe_erase_unused!(extra_worklist, compact, pop!(extra_worklist)) end +end + +function non_dce_finish!(compact::IncrementalCompact) + result_idx = compact.result_idx + resize!(compact.result, result_idx-1) + resize!(compact.result_types, result_idx-1) + resize!(compact.result_lines, result_idx-1) + resize!(compact.result_flags, result_idx-1) + just_fixup!(compact) + bb = compact.result_bbs[end] + compact.result_bbs[end] = BasicBlock(bb, + 
StmtRange(first(bb.stmts), result_idx-1)) +end + +function finish(compact::IncrementalCompact) + non_dce_finish!(compact) + simple_dce!(compact) + complete(compact) +end + +function complete(compact) cfg = CFG(compact.result_bbs, Int[first(bb.stmts) for bb in compact.result_bbs[2:end]]) - return IRCode(compact.ir, compact.result, compact.result_types, compact.result_lines, compact.result_flags, cfg, NewNode[]) + return IRCode(compact.ir, compact.result, compact.result_types, compact.result_lines, compact.result_flags, cfg, compact.new_new_nodes) end function compact!(code::IRCode) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 228eb87d0e205..2b569aaf65de0 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -101,36 +101,63 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks, du, phin end end -function walk_to_def(compact::IncrementalCompact, @nospecialize(def), intermediaries=IdSet{Int}(), allow_phinode::Bool=true, phi_locs=Tuple{Int, Int}[]) - if !isa(def, SSAValue) - return (def, 0) +function simple_walk(compact::IncrementalCompact, defssa::Union{SSAValue, NewSSAValue, OldSSAValue}, pi_callback=(pi,idx)->nothing) + while true + if isa(defssa, OldSSAValue) && already_inserted(compact, defssa) + rename = compact.ssa_rename[defssa.id] + @assert rename != defssa + if isa(rename, Union{SSAValue, OldSSAValue, NewSSAValue}) + defssa = rename + continue + end + return rename + end + def = compact[defssa] + if isa(def, PiNode) + pi_callback(def, defssa) + if isa(def.val, SSAValue) + defssa = def.val + else + return def.val + end + elseif isa(def, Union{SSAValue, OldSSAValue, NewSSAValue}) + defssa = def + elseif isa(def, Union{PhiNode, Expr}) + return defssa + else + return def + end end - orig_defidx = defidx = def.id +end + +function simple_walk_constraint(compact, defidx, typeconstraint = types(compact)[defidx]) + def = simple_walk(compact, defidx, (pi,_)->(typeconstraint = 
typeintersect(typeconstraint, pi.typ))) + def, typeconstraint +end + +""" + walk_to_defs(compact, defssa, typeconstraint, visited_phinodes=Any[]) + +Starting at `defssa`, walk use-def chains to get all the leaves feeding into +this value (pruning those leaves ruled out by path conditions). +""" +function walk_to_defs(compact, defssa, typeconstraint, visited_phinodes=Any[]) # Step 2: Figure out what the struct is defined as - def = compact[defidx] - typeconstraint = types(compact)[defidx] + def = compact[defssa] ## Track definitions through PiNode/PhiNode found_def = false ## Track which PhiNodes, SSAValue intermediaries ## we forwarded through. - while true - if isa(def, PiNode) - push!(intermediaries, defidx) - typeconstraint = typeintersect(typeconstraint, def.typ) - if isa(def.val, SSAValue) - defidx = def.val.id - def = compact[defidx] - else - def = def.val - end - continue - elseif isa(def, FastForward) - append!(phi_locs, def.phi_locs) - def = def.to - elseif isa(def, PhiNode) - # For now, we don't track setfields structs through phi nodes - allow_phinode || break - push!(intermediaries, defidx) + visited = IdSet{Any}() + worklist = Tuple{Any, Any}[] + leaves = Any[] + push!(worklist, (defssa, typeconstraint)) + while !isempty(worklist) + defssa, typeconstraint = pop!(worklist) + push!(visited, defssa) + def = compact[defssa] + if isa(def, PhiNode) + push!(visited_phinodes, defssa) possible_predecessors = let def=def, typeconstraint=typeconstraint collect(Iterators.filter(1:length(def.edges)) do n isassigned(def.values, n) || return false @@ -139,35 +166,33 @@ function walk_to_def(compact::IncrementalCompact, @nospecialize(def), intermedia return typeintersect(edge_typ, typeconstraint) !== Union{} end) end - # For now, only look at unique predecessors - if length(possible_predecessors) == 1 - n = possible_predecessors[1] + for n in possible_predecessors pred = def.edges[n] val = def.values[n] - if isa(val, SSAValue) - push!(phi_locs, (pred, defidx)) - defidx = val.id - def = compact[defidx] 
- elseif def == val + if isa(val, Union{SSAValue, OldSSAValue, NewSSAValue}) + new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint) + if isa(new_def, Union{SSAValue, OldSSAValue, NewSSAValue}) + if !(new_def in visited) + push!(worklist, (new_def, new_constraint)) + end + continue + end + end + if def == val # This shouldn't really ever happen, but # patterns like this can occur in dead code, # so bail out. break else - def = val + push!(leaves, val) end continue end - elseif isa(def, SSAValue) - push!(intermediaries, defidx) - defidx = def.id - def = compact[def.id] - continue + else + push!(leaves, defssa) end - found_def = true - break end - found_def ? (def, defidx) : nothing + leaves end function process_immutable_preserve(new_preserves::Vector{Any}, compact::IncrementalCompact, def::Expr) @@ -183,19 +208,137 @@ struct FastForward phi_locs::Vector{Tuple{Int64, Int64}} end -function getfield_elim_pass!(ir::IRCode, domtree::DomTree) +function already_inserted(compact, old::OldSSAValue) + id = old.id + if id < length(compact.ir.stmts) + return id <= compact.idx + end + id -= length(compact.ir.stmts) + if id < length(compact.ir.new_nodes) + error() + end + id -= length(compact.ir.new_nodes) + @assert id <= length(compact.pending_nodes) + return !(id in compact.pending_perm) +end + +function lift_leaves(compact, stmt, result_t, field, leaves) + # For every leaf, the lifted value + lifted_leaves = IdDict{Any, Any}() + maybe_undef = false + for leaf in leaves + leaf_key = leaf + if isa(leaf, Union{SSAValue, OldSSAValue, NewSSAValue}) + if isa(leaf, OldSSAValue) && already_inserted(compact, leaf) + leaf = compact.ssa_rename[leaf.id] + if isa(leaf, Union{SSAValue, OldSSAValue, NewSSAValue}) + leaf = simple_walk(compact, leaf) + end + if isa(leaf, Union{SSAValue, OldSSAValue, NewSSAValue}) + def = compact[leaf] + else + def = leaf + end + else + def = compact[leaf] + end + if is_tuple_call(compact.ir, def) && isa(field, Int) && 1 <= field < 
length(def.args) + lifted_leaves[leaf_key] = Ref{Any}(def.args[1+field]) + continue + elseif isexpr(def, :new) + typ = def.typ + if isa(typ, UnionAll) + typ = unwrap_unionall(typ) + end + (isa(typ, DataType) && (!typ.abstract)) || return nothing + @assert !typ.mutable + field = try_compute_fieldidx(typ, stmt) + field === nothing && return nothing + if length(def.args) < 1 + field + ftyp = fieldtype(typ, field) + if !isbits(ftyp) + # On this branch, this will be a guaranteed UndefRefError. + # We use the regular undef mechanic to lift this to a boolean slot + maybe_undef = true + lifted_leaves[leaf_key] = nothing + continue + end + # Expand the Expr(:new) to include its element Expr(:new) nodes up until the one we want + compact[leaf] = nothing + for i = (length(def.args) + 1):(1+field) + ftyp = fieldtype(typ, i - 1) + isbits(ftyp) || return nothing + push!(def.args, insert_node!(compact, leaf, result_t, Expr(:new, ftyp))) + end + compact[leaf] = def + end + lifted_leaves[leaf_key] = Ref{Any}(def.args[1+field]) + continue + else + typ = compact_exprtype(compact, leaf) + if !isa(typ, Const) + # If the leaf is an old ssa value, insert a getfield here + # We will revisit this getfield later when compaction gets + # to the appropriate point. 
+ # N.B.: This can be a bit dangerous because it can lead to + # infinite loops if we accidentally insert a node just ahead + # of where we are + if isa(leaf, OldSSAValue) && (isa(field, Int) || isa(field, Symbol)) + (isa(typ, DataType) && (!typ.abstract)) || return nothing + @assert !typ.mutable + # If there's the potential for an undefref error on access, we cannot insert a getfield + if field > typ.ninitialized && !isbits(fieldtype(typ, field)) + lifted_leaves[leaf] = Ref{Any}(insert_node!(compact, leaf, make_MaybeUndef(result_t), Expr(:call, :unchecked_getfield, SSAValue(leaf.id), field), true)) + maybe_undef = true + else + lifted_leaves[leaf] = Ref{Any}(insert_node!(compact, leaf, result_t, Expr(:call, getfield, SSAValue(leaf.id), field), true)) + end + continue + end + return nothing + end + leaf = typ.val + # Fall through to below + end + elseif isa(leaf, Union{Argument, Expr}) + return nothing + end + isimmutable(leaf) || return nothing + isdefined(leaf, field) || return nothing + val = getfield(leaf, field) + is_inlineable_constant(val) || return nothing + lifted_leaves[leaf_key] = Ref{Any}(quoted(val)) + end + lifted_leaves, maybe_undef +end + +make_MaybeUndef(typ) = isa(typ, MaybeUndef) ? typ : MaybeUndef(typ) + +const AnySSAValue = Union{SSAValue, OldSSAValue, NewSSAValue} + +function getfield_elim_pass!(ir::IRCode, domtree) compact = IncrementalCompact(ir) insertions = Vector{Any}() defuses = IdDict{Int, Tuple{IdSet{Int}, SSADefUse}}() + lifting_cache = IdDict{Tuple{AnySSAValue, Int}, AnySSAValue}() + revisit_worklist = Int[] + #ndone, nmax = 0, 200 for (idx, stmt) in compact isa(stmt, Expr) || continue + #ndone >= nmax && continue + #ndone += 1 + result_t = compact_exprtype(compact, SSAValue(idx)) is_getfield = false is_ccall = false + is_unchecked = false # Step 1: Check whether the statement we're looking at is a getfield/setfield! 
if is_known_call(stmt, setfield!, compact) is_setfield = true elseif is_known_call(stmt, getfield, compact) is_getfield = true + elseif isexpr(stmt, :call) && stmt.args[1] == :unchecked_getfield + is_getfield = true + is_unchecked = true elseif isexpr(stmt, :foreigncall) nccallargs = stmt.args[5] new_preserves = Any[] @@ -203,9 +346,10 @@ function getfield_elim_pass!(ir::IRCode, domtree::DomTree) for (pidx, preserved_arg) in enumerate(old_preserves) intermediaries = IdSet() isa(preserved_arg, SSAValue) || continue - def = walk_to_def(compact, preserved_arg, intermediaries, false) - def !== nothing || continue - (def, defidx) = def + def = simple_walk(compact, preserved_arg, (pi, idx)->push!(intermediaries, idx)) + isa(def, SSAValue) || continue + defidx = def.id + def = compact[defidx] if is_tuple_call(compact, def) process_immutable_preserve(new_preserves, compact, def) old_preserves[pidx] = nothing @@ -244,61 +388,118 @@ function getfield_elim_pass!(ir::IRCode, domtree::DomTree) isa(field, QuoteNode) && (field = field.value) isa(field, Union{Int, Symbol}) || continue - intermediaries = IdSet() - phi_locs = Tuple{Int, Int}[] - def = walk_to_def(compact, stmt.args[2], intermediaries, is_getfield, phi_locs) - def !== nothing || continue - (def, defidx) = def + struct_typ = unwrap_unionall(widenconst(compact_exprtype(compact, stmt.args[2]))) + (isa(struct_typ, DataType) && !struct_typ.mutable) || continue - if !is_getfield - (defidx == 0) && continue - mid, defuse = get!(defuses, defidx, (IdSet{Int}(), SSADefUse())) + def, typeconstraint = stmt.args[2], struct_typ + + if struct_typ.mutable + isa(def, SSAValue) || continue + intermediaries = IdSet() + def = simple_walk(compact, def, (pi, idx)->push!(intermediaries, idx)) + # Mutable stuff here + isa(def, SSAValue) || continue + mid, defuse = get!(defuses, def.id, (IdSet{Int}(), SSADefUse())) push!(defuse.defs, idx) union!(mid, intermediaries) continue end - # Step 3: Check if the definition we eventually end up at is 
either - # a tuple(...) call or Expr(:new) and perform replacement. - if is_tuple_call(compact, def) && isa(field, Int) && 1 <= field < length(def.args) - forwarded = def.args[1+field] - elseif isexpr(def, :new) - typ = def.typ - if isa(typ, UnionAll) - typ = unwrap_unionall(typ) + + if isa(def, Union{OldSSAValue, SSAValue}) + def, typeconstraint = simple_walk_constraint(compact, def, typeconstraint) + end + + visited_phinodes = Any[] + if isa(def, Union{OldSSAValue, SSAValue, NewSSAValue}) && isa(compact[def], PhiNode) + leaves = walk_to_defs(compact, def, typeconstraint, visited_phinodes) + else + leaves = [def] + end + + isempty(leaves) && continue + + field = try_compute_fieldidx(struct_typ, stmt) + field === nothing && continue + + r = lift_leaves(compact, stmt, result_t, field, leaves) + r === nothing && continue + lifted_leaves, any_undef = r + + reverse_mapping = IdDict{Any, Any}(ssa => id for (id, ssa) in enumerate(visited_phinodes)) + + if any_undef + result_t = make_MaybeUndef(result_t) + end + + # Insert PhiNodes + lifted_phis = map(visited_phinodes) do item + if (item, field) in keys(lifting_cache) + ssa = lifting_cache[(item, field)] + return (ssa, compact[ssa], false) end - isa(typ, DataType) || continue - if typ.mutable - @assert defidx != 0 - mid, defuse = get!(defuses, defidx, (IdSet{Int}(), SSADefUse())) - push!(defuse.uses, idx) - union!(mid, intermediaries) - continue + n = PhiNode() + ssa = insert_node!(compact, item, result_t, n) + lifting_cache[(item, field)] = ssa + (ssa, n, true) + end + + # Fix up arguments + for (old_node, (_, new_node, need_argupdate)) in zip(map(x->compact[x], visited_phinodes), lifted_phis) + need_argupdate || continue + for i = 1:length(old_node.edges) + edge = old_node.edges[i] + isassigned(old_node.values, i) || continue + val = old_node.values[i] + if isa(val, Union{NewSSAValue, SSAValue, OldSSAValue}) + val = simple_walk(compact, val) + end + if val in keys(lifted_leaves) + push!(new_node.edges, edge) + 
lifted_val = lifted_leaves[val] + if lifted_val === nothing + resize!(new_node.values, length(new_node.values)+1) + continue + end + lifted_val = lifted_val.x + if isa(lifted_val, Union{NewSSAValue, SSAValue, OldSSAValue}) + lifted_val = simple_walk(compact, lifted_val) + end + push!(new_node.values, lifted_val) + elseif isa(val, Union{NewSSAValue, SSAValue, OldSSAValue}) && val in keys(reverse_mapping) + push!(new_node.edges, edge) + push!(new_node.values, lifted_phis[reverse_mapping[val]][1]) + else + # Probably ignored by path condition, skip this + end end - field = try_compute_fieldidx_expr(typ, stmt) - field === nothing && continue - forwarded = def.args[1+field] - else - obj = compact_exprtype(compact, def) - isa(obj, Const) || continue - obj = obj.val - isimmutable(obj) || continue - field = try_compute_fieldidx_expr(typeof(obj), stmt) - field === nothing && continue - isdefined(obj, field) || continue - val = getfield(obj, field) - is_inlineable_constant(val) || continue - forwarded = quoted(val) - end - # Step 4: Remember any phinodes we need to insert - if !isempty(phi_locs) && isa(forwarded, SSAValue) - # TODO: We have have to use BB ids for phi_locs - # to avoid index invalidation. 
- push!(insertions, idx) - compact[idx] = FastForward(forwarded, phi_locs) + end + + for (_, node) in lifted_phis + count_added_node!(compact, node) + end + + # Fixup the stmt itself + val = stmt.args[2] + if isa(val, Union{SSAValue, OldSSAValue}) + val = simple_walk(compact, val) + end + if val in keys(lifted_leaves) + val = lifted_leaves[val] + @assert val !== nothing + val = val.x else - compact[idx] = forwarded + isa(val, Union{SSAValue, OldSSAValue}) && val in keys(reverse_mapping) + val = lifted_phis[reverse_mapping[val]][1] end + + # Insert the undef check if necessary + if any_undef && !is_unchecked + insert_node!(compact, SSAValue(idx), Nothing, Expr(:undefcheck, :getfield, val)) + end + + compact[idx] = val end + ir = finish(compact) # Now go through any mutable structs and see which ones we can eliminate for (idx, (intermediaries, defuse)) in defuses @@ -389,25 +590,87 @@ function getfield_elim_pass!(ir::IRCode, domtree::DomTree) ir[SSAValue(use)] = new_expr end end - for idx in insertions - # For non-dominating load-store forward, we may have to insert extra phi nodes - # TODO: Can use the domtree to eliminate unnecessary phis, but ok for now - ff = ir.stmts[idx] - ff === nothing && continue # May have been DCE'd if there were no more uses - ff = ff::FastForward - forwarded = ff.to - if isa(forwarded, SSAValue) - forwarded_typ = ir.types[forwarded.id] - for (pred, pos) in reverse!(ff.phi_locs) - node = PhiNode() - push!(node.edges, pred) - push!(node.values, forwarded) - forwarded = insert_node!(ir, pos, forwarded_typ, node) + ir +end + +function adce_erase!(phi_uses, extra_worklist, compact, idx) + if isa(compact.result[idx], PhiNode) + maybe_erase_unused!(extra_worklist, compact, idx, val->phi_uses[val.id]-=1) + else + maybe_erase_unused!(extra_worklist, compact, idx) + end +end + +function count_uses(stmt, uses) + for ur in userefs(stmt) + if isa(ur[], SSAValue) + uses[ur[].id] += 1 + end + end +end + +function mark_phi_cycles(compact, safe_phis, 
phi) + worklist = Int[] + push!(worklist, phi) + while !isempty(worklist) + phi = pop!(worklist) + push!(safe_phis, phi) + for ur in userefs(compact.result[phi]) + val = ur[] + isa(val, SSAValue) || continue + isa(compact[val], PhiNode) || continue + (val.id in safe_phis) && continue + push!(worklist, val.id) + end + end +end + +function adce_pass!(ir) + phi_uses = fill(0, length(ir.stmts) + length(ir.new_nodes)) + all_phis = Int[] + compact = IncrementalCompact(ir) + for (idx, stmt) in compact + if isa(stmt, PhiNode) + push!(all_phis, idx) + end + end + non_dce_finish!(compact) + for phi in all_phis + count_uses(compact.result[phi], phi_uses) + end + # Perform simple DCE for unused values + extra_worklist = Int[] + for (idx, nused) in Iterators.enumerate(compact.used_ssas) + idx >= compact.result_idx && break + nused == 0 || continue + adce_erase!(phi_uses, extra_worklist, compact, idx) + end + while !isempty(extra_worklist) + adce_erase!(phi_uses, extra_worklist, compact, pop!(extra_worklist)) + end + # Go back and erase any phi cycles + changed = true + while changed + changed = false + safe_phis = IdSet{Int}() + for phi in all_phis + # Save any phi cycles that have non-phi uses + if compact.used_ssas[phi] - phi_uses[phi] != 0 + mark_phi_cycles(compact, safe_phis, phi) + end + end + for phi in all_phis + if !(phi in safe_phis) + push!(extra_worklist, phi) + end + end + while !isempty(extra_worklist) + if adce_erase!(phi_uses, extra_worklist, compact, pop!(extra_worklist)) + changed = true end end - ir.stmts[idx] = forwarded end - ir + complete(compact) end function type_lift_pass!(ir::IRCode) diff --git a/base/compiler/ssair/show.jl b/base/compiler/ssair/show.jl index d5322dad9b259..5eb16e43ab3ef 100644 --- a/base/compiler/ssair/show.jl +++ b/base/compiler/ssair/show.jl @@ -97,7 +97,6 @@ function Base.show(io::IO, code::IRCode) maxused = maximum(used) maxsize = length(string(maxused)) end - for idx in eachindex(code.stmts) if !isassigned(code.stmts, idx) # This 
is invalid, but do something useful rather