diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index e84f77ae1ea48..cdf1e0cf40f33 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -288,8 +288,7 @@ end function foreigncall_effect_free(stmt::Expr, src::Union{IRCode,IncrementalCompact}) args = stmt.args - name = args[1] - isa(name, QuoteNode) && (name = name.value) + name = normalize(args[1]) isa(name, Symbol) || return false ndims = alloc_array_ndims(name) if ndims !== nothing @@ -315,6 +314,17 @@ function alloc_array_ndims(name::Symbol) return nothing end +normalize(@nospecialize x) = isa(x, QuoteNode) ? x.value : x + +function is_array_alloc(@nospecialize stmt) + isa(stmt, Expr) || return false + if isexpr(stmt, :foreigncall) + name = normalize(stmt.args[1]) + return isa(name, Symbol) && alloc_array_ndims(name) !== nothing + end + return false +end + const FOREIGNCALL_ARG_START = 6 function alloc_array_no_throw(args::Vector{Any}, ndims::Int, src::Union{IRCode,IncrementalCompact}) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index c8037fac648fa..af1bc63bfd7a3 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -557,6 +557,9 @@ function linear_pass!(ir::IRCode) memory_opt = true end continue + elseif is_array_alloc(stmt) + memory_opt = true + continue elseif is_known_call(stmt, getfield, compact) 3 <= length(stmt.args) <= 5 || continue if length(stmt.args) == 5 @@ -687,7 +690,8 @@ function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preser end import .EscapeAnalysis: - EscapeState, EscapeInfo, IndexableFields, LivenessSet, getaliases, LocalUse, LocalDef + EscapeState, EscapeInfo, IndexableFields, IndexableElements, LivenessSet, getaliases, + LocalUse, LocalDef, ArrayInfo """ memory_opt_pass!(ir::IRCode, estate::EscapeState) -> newir::IRCode @@ -711,12 +715,12 @@ function memory_opt_pass!(ir::IRCode, estate::EscapeState) eliminated = BitSet() revisit = Tuple{#=related=#Vector{SSAValue}, #=Liveness=#LivenessSet}[] all_preserved = true - newpreserves = nothing + newpreserves = IdDict{Int,Vector{Any}}() while !isempty(wset) idx = pop!(wset) ssa = SSAValue(idx) stmt = ir[ssa][:inst] - isexpr(stmt, :new) || continue + isexpr(stmt, :new) || is_array_alloc(stmt) || continue einfo = estate[ssa] is_load_forwardable(einfo) || continue aliases = getaliases(ssa, estate) @@ -730,152 +734,48 @@ function memory_opt_pass!(ir::IRCode, estate::EscapeState) delete!(wset, alias.id) end end - finfos = (einfo.AliasInfo::IndexableFields).infos - nfields = length(finfos) - - # Partition defuses by field - fdefuses = Vector{FieldDefUse}(undef, nfields) - for i = 1:nfields - finfo = finfos[i] - fdu = FieldDefUse() - for fx in finfo - if isa(fx, LocalUse) - push!(fdu.uses, GetfieldLoad(fx.idx)) # use (getfield call) - else - @assert isa(fx, LocalDef) - push!(fdu.defs, fx.idx) # def (setfield! call or :new expression) - end - end - fdefuses[i] = fdu - end - - Liveness = einfo.Liveness - for livepc in Liveness - livestmt = ir[SSAValue(livepc)][:inst] - if is_known_call(livestmt, Core.ifelse, ir) - # the succeeding domination analysis doesn't account for conditional branching - # by ifelse branching at this moment - @goto next_itr - elseif is_known_call(livestmt, isdefined, ir) - args = livestmt.args - length(args) ≥ 3 || continue - obj = args[2] - isa(obj, SSAValue) || continue - obj in related || continue - fld = args[3] - fldval = try_compute_field(ir, fld) - fldval === nothing && continue - typ = unwrap_unionall(widenconst(argextype(obj, ir))) - isa(typ, DataType) || continue - fldidx = try_compute_fieldidx(typ, fldval) - fldidx === nothing && continue - push!(fdefuses[fldidx].uses, IsdefinedUse(livepc)) - elseif isexpr(livestmt, :foreigncall) - # we shouldn't eliminate this use if it's used as a direct argument - args = livestmt.args - nccallargs = length(args[3]::SimpleVector) - for i = 6:(5+nccallargs) - arg = args[i] - isa(arg, SSAValue) && arg in related && @goto next_liveness - end - # this use is preserve, and may be eliminable - for fidx in 1:nfields - push!(fdefuses[fidx].uses, PreserveUse(livepc)) - end - end - @label next_liveness - end - for fidx in 1:nfields - fdu = fdefuses[fidx] - isempty(fdu.uses) && @goto next_use - # check if all uses have safe definitions first, otherwise we should bail out - # since then we may fail to form new ϕ-nodes - ldu = compute_live_ins(ir.cfg, fdu) - if isempty(ldu.live_in_bbs) - phiblocks = Int[] - else - phiblocks = iterated_dominance_frontier(ir.cfg, ldu, domtree) - end - allblocks = sort!(vcat(phiblocks, ldu.def_bbs)) - for use in fdu.uses - isa(use, IsdefinedUse) && continue - if isa(use, PreserveUse) && isempty(fdu.defs) - # nothing to preserve, just ignore this use (may happen when there are unintialized fields) - continue - end - if !has_safe_def(ir, domtree, allblocks, fdu, getuseidx(use)) - all_preserved = false - @goto next_use - end - end - phinodes = IdDict{Int, SSAValue}() - for b in phiblocks - phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts), - NewInstruction(PhiNode(), Any)) - end - # Now go through all uses and rewrite them - for use in fdu.uses - if isa(use, GetfieldLoad) - use = getuseidx(use) - ir[SSAValue(use)][:inst] = compute_value_for_use( - ir, domtree, allblocks, fdu, phinodes, fidx, use) - push!(eliminated, use) - elseif all_preserved && isa(use, PreserveUse) - if newpreserves === nothing - newpreserves = IdDict{Int,Vector{Any}}() - end - # record this `use` as replaceable no matter if we preserve new value or not - use = getuseidx(use) - newvalues = get!(()->Any[], newpreserves, use) - isempty(fdu.defs) && continue # nothing to preserve (may happen when there are unintialized fields) - newval = compute_value_for_use( - ir, domtree, allblocks, fdu, phinodes, fidx, use) - if !isbitstype(widenconst(argextype(newval, ir))) - push!(newvalues, newval) - end - elseif isa(use, IsdefinedUse) - use = getuseidx(use) - if has_safe_def(ir, domtree, allblocks, fdu, use) - ir[SSAValue(use)][:inst] = true - push!(eliminated, use) - end - else - throw("unexpected use") - end - end - for b in phiblocks - ϕssa = phinodes[b] - n = ir[ϕssa][:inst]::PhiNode - t = Bottom - for p in ir.cfg.blocks[b].preds - push!(n.edges, p) - v = compute_value_for_block(ir, domtree, allblocks, fdu, phinodes, fidx, p) - push!(n.values, v) - if t !== Any - t = tmerge(t, argextype(v, ir)) - end - end - ir[ϕssa][:type] = t - end - @label next_use + AliasInfo = einfo.AliasInfo + if isa(AliasInfo, IndexableFields) + @assert isexpr(stmt, :new) "invalid escape analysis" + all_preserved &= load_forward_object!(ir, domtree, + eliminated, revisit, + newpreserves, related, + AliasInfo, einfo.Liveness) + else + @assert is_array_alloc(stmt) "invalid escape analysis" + arrayinfo = estate.arrayinfo + @assert isa(arrayinfo, ArrayInfo) && haskey(arrayinfo, idx) "invalid escape analysis" + dims = arrayinfo[idx] + all_preserved &= load_forward_array!(ir, domtree, + eliminated, revisit, + newpreserves, related, + AliasInfo::IndexableElements, einfo.Liveness, dims) end - push!(revisit, (related, Liveness)) - @label next_itr end # remove dead setfield! and :new allocs deadssas = IdSet{SSAValue}() - if all_preserved && newpreserves !== nothing + if all_preserved preserved = keys(newpreserves) else preserved = EMPTY_PRESERVED_SSAS end mark_dead_ssas!(ir, deadssas, revisit, eliminated, preserved) for ssa in deadssas + # stmt = ir[ssa][:inst] + # if is_known_call(stmt, setfield!, ir) + # println("[SROA] eliminated setfield!: ", argtypes_to_type(ir.argtypes[1:nargs]), " ", ssa, ": ", stmt) + # elseif isexpr(stmt, :new) + # println("[SROA] eliminated object alloc: ", argtypes_to_type(ir.argtypes[1:nargs]), " ", ssa, ": ", stmt) + # elseif is_known_call(stmt, arrayset, ir) + # println("[SROA] eliminated arrayset: ", argtypes_to_type(ir.argtypes[1:nargs]), " ", ssa, ": ", stmt) + # elseif is_array_alloc(stmt) + # println("[SROA] eliminated array alloc: ", argtypes_to_type(ir.argtypes[1:nargs]), " ", ssa, ": ", stmt) + # end ir[ssa][:inst] = nothing end - if all_preserved && newpreserves !== nothing + if all_preserved deadssas = Int[ssa.id for ssa in deadssas] for (idx, newuses) in newpreserves ir[SSAValue(idx)][:inst] = form_new_preserves( @@ -886,20 +786,291 @@ function memory_opt_pass!(ir::IRCode, estate::EscapeState) return ir end +function load_forward_object!(ir::IRCode, domtree::DomTree, + eliminated::BitSet, revisit::Vector{Tuple{Vector{SSAValue}, LivenessSet}}, + newpreserves::IdDict{Int,Vector{Any}}, related::Vector{SSAValue}, + AliasInfo::IndexableFields, Liveness::LivenessSet) + finfos = AliasInfo.infos + nfields = length(finfos) + + # Partition defuses by field + all_preserved = true + fdefuses = Vector{IndexedDefUse}(undef, nfields) + for i = 1:nfields + finfo = finfos[i] + idu = IndexedDefUse() + for fx in finfo + if isa(fx, LocalUse) + push!(idu.uses, LoadUse(fx.idx)) # use (getfield call) + else + @assert isa(fx, LocalDef) + push!(idu.defs, fx.idx) # def (setfield! call or :new expression) + end + end + fdefuses[i] = idu + end + + for livepc in Liveness + livestmt = ir[SSAValue(livepc)][:inst] + if is_known_call(livestmt, Core.ifelse, ir) + # the succeeding domination analysis doesn't account for conditional branching + # by ifelse branching at this moment + return false + elseif is_known_call(livestmt, isdefined, ir) + args = livestmt.args + length(args) ≥ 3 || continue + obj = args[2] + isa(obj, SSAValue) || continue + obj in related || continue + fld = args[3] + fldval = try_compute_field(ir, fld) + fldval === nothing && continue + typ = unwrap_unionall(widenconst(argextype(obj, ir))) + isa(typ, DataType) || continue + fldidx = try_compute_fieldidx(typ, fldval) + fldidx === nothing && continue + push!(fdefuses[fldidx].uses, IsdefinedUse(livepc)) + elseif isexpr(livestmt, :foreigncall) + # we shouldn't eliminate this use if it's used as a direct argument + args = livestmt.args + nccallargs = length(args[3]::SimpleVector) + for i = 6:(5+nccallargs) + arg = args[i] + isa(arg, SSAValue) && arg in related && @goto next_liveness + end + # this use is preserve, and may be eliminable + for fidx in 1:nfields + push!(fdefuses[fidx].uses, PreserveUse(livepc)) + end + end + @label next_liveness + end + + for fidx in 1:nfields + idu = fdefuses[fidx] + isempty(idu.uses) && @goto next_use + # check if all uses have safe definitions first, otherwise we should bail out + # since then we may fail to form new ϕ-nodes + ldu = compute_live_ins(ir.cfg, idu) + if isempty(ldu.live_in_bbs) + phiblocks = Int[] + else + phiblocks = iterated_dominance_frontier(ir.cfg, ldu, domtree) + end + allblocks = sort!(vcat(phiblocks, ldu.def_bbs)) + for use in idu.uses + isa(use, IsdefinedUse) && continue + if isa(use, PreserveUse) && isempty(idu.defs) + # nothing to preserve, just ignore this use (may happen when there are unintialized fields) + continue + end + if !has_safe_def(ir, domtree, allblocks, idu, getuseidx(use)) + all_preserved = false + @goto next_use + end + end + phinodes = IdDict{Int, SSAValue}() + for b in phiblocks + phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts), + NewInstruction(PhiNode(), Any)) + end + # Now go through all uses and rewrite them + for use in idu.uses + if isa(use, LoadUse) + use = getuseidx(use) + ir[SSAValue(use)][:inst] = compute_value_for_use( + ir, domtree, allblocks, idu, phinodes, fidx, use) + push!(eliminated, use) + elseif isa(use, PreserveUse) + all_preserved || continue + # record this `use` as replaceable no matter if we preserve new value or not + use = getuseidx(use) + newvalues = get!(()->Any[], newpreserves, use) + isempty(idu.defs) && continue # nothing to preserve (may happen when there are unintialized fields) + newval = compute_value_for_use( + ir, domtree, allblocks, idu, phinodes, fidx, use) + if !isbitstype(widenconst(argextype(newval, ir))) + push!(newvalues, newval) + end + elseif isa(use, IsdefinedUse) + use = getuseidx(use) + if has_safe_def(ir, domtree, allblocks, idu, use) + ir[SSAValue(use)][:inst] = true + push!(eliminated, use) + end + else + throw("load_forward_object!: unexpected use") + end + end + for b in phiblocks + ϕssa = phinodes[b] + n = ir[ϕssa][:inst]::PhiNode + t = Bottom + for p in ir.cfg.blocks[b].preds + push!(n.edges, p) + v = compute_value_for_block(ir, domtree, allblocks, idu, phinodes, fidx, p) + push!(n.values, v) + if t !== Any + t = tmerge(t, argextype(v, ir)) + end + end + ir[ϕssa][:type] = t + end + @label next_use + end + push!(revisit, (related, Liveness)) + + return all_preserved +end + +# TODO is_array_isassigned folding? +function load_forward_array!(ir::IRCode, domtree::DomTree, + eliminated::BitSet, revisit::Vector{Tuple{Vector{SSAValue}, LivenessSet}}, + newpreserves::IdDict{Int,Vector{Any}}, related::Vector{SSAValue}, + AliasInfo::IndexableElements, Liveness::LivenessSet, dims::Vector{Int}) + elminfos = AliasInfo.infos + elmkeys = keys(elminfos) + + # Partition defuses by index + all_preserved = true + edefuses = IdDict{Int,IndexedDefUse}() + for eidx in elmkeys + einfo = elminfos[eidx] + idu = IndexedDefUse() + for ex in einfo + if isa(ex, LocalUse) + push!(idu.uses, LoadUse(ex.idx)) # use (arrayref call) + else + @assert isa(ex, LocalDef) + push!(idu.defs, ex.idx) # def (arrayset call) + end + end + edefuses[eidx] = idu + end + + for livepc in Liveness + ssa = SSAValue(livepc) + livestmt = ir[ssa][:inst] + if is_known_call(livestmt, Core.ifelse, ir) + # the succeeding domination analysis doesn't account for conditional branching + # by ifelse branching at this moment + return false + elseif is_known_call(livestmt, arraylen, ir) + len = 1 + for dim in dims + len *= dim + end + ir[ssa][:inst] = len + push!(eliminated, livepc) + elseif is_known_call(livestmt, arraysize, ir) + length(livestmt.args) ≥ 3 || continue + dim = argextype(livestmt.args[3], ir) + isa(dim, Const) || continue + dim = dim.val + isa(dim, Int) || continue + checkbounds(Bool, dims, dim) || continue + ir[ssa][:inst] = dims[dim] + push!(eliminated, livepc) + elseif isexpr(livestmt, :foreigncall) + # we shouldn't eliminate this use if it's used as a direct argument + args = livestmt.args + nccallargs = length(args[3]::SimpleVector) + for i = 6:(5+nccallargs) + arg = args[i] + isa(arg, SSAValue) && arg in related && @goto next_liveness + end + # this use is preserve, and may be eliminable + for eidx in elmkeys + push!(edefuses[eidx].uses, PreserveUse(livepc)) + end + end + @label next_liveness + end + + for eidx in elmkeys + idu = edefuses[eidx] + isempty(idu.uses) && @goto next_use + # check if all uses have safe definitions first, otherwise we should bail out + # since then we may fail to form new ϕ-nodes + ldu = compute_live_ins(ir.cfg, idu) + if isempty(ldu.live_in_bbs) + phiblocks = Int[] + else + phiblocks = iterated_dominance_frontier(ir.cfg, ldu, domtree) + end + allblocks = sort!(vcat(phiblocks, ldu.def_bbs)) + for use in idu.uses + if isa(use, PreserveUse) && isempty(idu.defs) + # nothing to preserve, just ignore this use (may happen when there are unintialized fields) + continue + end + if !has_safe_def(ir, domtree, allblocks, idu, getuseidx(use)) + all_preserved = false + @goto next_use + end + end + phinodes = IdDict{Int, SSAValue}() + for b in phiblocks + phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts), + NewInstruction(PhiNode(), Any)) + end + # Now go through all uses and rewrite them + for use in idu.uses + if isa(use, LoadUse) + use = getuseidx(use) + ir[SSAValue(use)][:inst] = compute_value_for_use( + ir, domtree, allblocks, idu, phinodes, eidx, use) + push!(eliminated, use) + elseif isa(use, PreserveUse) + all_preserved || continue + # record this `use` as replaceable no matter if we preserve new value or not + use = getuseidx(use) + newvalues = get!(()->Any[], newpreserves, use) + isempty(idu.defs) && continue # nothing to preserve (may happen when there are unintialized fields) + newval = compute_value_for_use( + ir, domtree, allblocks, idu, phinodes, eidx, use) + if !isbitstype(widenconst(argextype(newval, ir))) + push!(newvalues, newval) + end + else + throw("load_forward_array!: unexpected use") + end + end + for b in phiblocks + ϕssa = phinodes[b] + n = ir[ϕssa][:inst]::PhiNode + t = Bottom + for p in ir.cfg.blocks[b].preds + push!(n.edges, p) + v = compute_value_for_block(ir, domtree, allblocks, idu, phinodes, eidx, p) + push!(n.values, v) + if t !== Any + t = tmerge(t, argextype(v, ir)) + end + end + ir[ϕssa][:type] = t + end + @label next_use + end + push!(revisit, (related, Liveness)) + + return all_preserved +end + const EMPTY_PRESERVED_SSAS = keys(IdDict{Int,Vector{Any}}()) const PreservedSets = typeof(EMPTY_PRESERVED_SSAS) function is_load_forwardable(x::EscapeInfo) AliasInfo = x.AliasInfo - return isa(AliasInfo, IndexableFields) + return isa(AliasInfo, IndexableFields) || isa(AliasInfo, IndexableElements) end -struct FieldDefUse +struct IndexedDefUse uses::Vector{Any} defs::Vector{Int} end -FieldDefUse() = FieldDefUse(Any[], Int[]) -struct GetfieldLoad +IndexedDefUse() = IndexedDefUse(Any[], Int[]) +struct LoadUse idx::Int end struct PreserveUse @@ -909,7 +1080,7 @@ struct IsdefinedUse idx::Int end function getuseidx(@nospecialize use) - if isa(use, GetfieldLoad) + if isa(use, LoadUse) return use.idx elseif isa(use, PreserveUse) return use.idx @@ -919,21 +1090,21 @@ function getuseidx(@nospecialize use) throw("getuseidx: unexpected use") end -function compute_live_ins(cfg::CFG, fdu::FieldDefUse) +function compute_live_ins(cfg::CFG, idu::IndexedDefUse) uses = Int[] - for use in fdu.uses + for use in idu.uses isa(use, IsdefinedUse) && continue push!(uses, getuseidx(use)) end - return compute_live_ins(cfg, fdu.defs, uses) + return compute_live_ins(cfg, idu.defs, uses) end # even when the allocation contains an uninitialized field, we try an extra effort to check # if this load at `idx` have any "safe" `setfield!` calls that define the field # try to find function has_safe_def(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, - fdu::FieldDefUse, use::Int) - dfu = find_def_for_use(ir, domtree, allblocks, fdu, use) + idu::IndexedDefUse, use::Int) + dfu = find_def_for_use(ir, domtree, allblocks, idu, use) dfu === nothing && return false def = dfu[1] def ≠ 0 && return true # found a "safe" definition @@ -949,7 +1120,7 @@ function has_safe_def(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, pred in seen && return false use = last(ir.cfg.blocks[pred].stmts) # NOTE this `use` isn't a load, and so the inclusive condition can be used - dfu = find_def_for_use(ir, domtree, allblocks, fdu, use, true) + dfu = find_def_for_use(ir, domtree, allblocks, idu, use, true) dfu === nothing && return false def = dfu[1] push!(seen, pred) @@ -964,12 +1135,12 @@ end # find the first dominating def for the given use function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, - fdu::FieldDefUse, use::Int, inclusive::Bool=false) + idu::IndexedDefUse, use::Int, inclusive::Bool=false) useblock = block_for_inst(ir.cfg, use) curblock = find_curblock(domtree, allblocks, useblock) curblock === nothing && return nothing local def = 0 - for idx in fdu.defs + for idx in idu.defs if block_for_inst(ir.cfg, idx) == curblock if curblock != useblock # Find the last def in this block @@ -998,15 +1169,15 @@ function find_curblock(domtree::DomTree, allblocks::Vector{Int}, curblock::Int) end function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, - fdu::FieldDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use::Int) - dfu = find_def_for_use(ir, domtree, allblocks, fdu, use) + idu::IndexedDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, use::Int) + dfu = find_def_for_use(ir, domtree, allblocks, idu, use) @assert dfu !== nothing "has_safe_def condition unsatisfied" def, useblock, curblock = dfu if def == 0 if !haskey(phinodes, curblock) # If this happens, we need to search the predecessors for defs. Which # one doesn't matter - if it did, we'd have had a phinode - return compute_value_for_block(ir, domtree, allblocks, fdu, phinodes, fidx, first(ir.cfg.blocks[useblock].preds)) + return compute_value_for_block(ir, domtree, allblocks, idu, phinodes, fidx, first(ir.cfg.blocks[useblock].preds)) end # The use is the phinode return phinodes[curblock] @@ -1016,11 +1187,11 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I end function compute_value_for_block(ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, - fdu::FieldDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, curblock::Int) + idu::IndexedDefUse, phinodes::IdDict{Int, SSAValue}, fidx::Int, curblock::Int) curblock = find_curblock(domtree, allblocks, curblock) @assert curblock !== nothing "has_safe_def condition unsatisfied" def = 0 - for stmt in fdu.defs + for stmt in idu.defs if block_for_inst(ir.cfg, stmt) == curblock def = max(def, stmt) end @@ -1032,9 +1203,12 @@ function val_for_def_expr(ir::IRCode, def::Int, fidx::Int) ex = ir[SSAValue(def)][:inst] if isexpr(ex, :new) return ex.args[1+fidx] - else - @assert is_known_call(ex, setfield!, ir) "invalid load forwarding" + elseif is_known_call(ex, setfield!, ir) return ex.args[4] + elseif is_known_call(ex, arrayset, ir) + return ex.args[4] + else + throw("invalid load forwarding") end end @@ -1103,6 +1277,34 @@ function mark_dead_ssas!(ir::IRCode, deadssas::IdSet{SSAValue}, end end return false + elseif is_known_call(stmt, arrayset, ir) + @assert length(stmt.args) ≥ 4 "invalid escape analysis" + ary = stmt.args[3] + val = stmt.args[4] + if isa(ary, SSAValue) + if ary in related + push!(eliminable, ssa) + @goto next_live + end + if isa(val, SSAValue) && val in related + if ary in deadssas + push!(eliminable, ssa) + @goto next_live + end + for new_revisit_idx in wset + if ary in revisit[new_revisit_idx][1] + delete!(wset, new_revisit_idx) + if mark_dead_ssas!(ir, deadssas, revisit, eliminated, preserved, wset, new_revisit_idx) + push!(eliminable, ssa) + @goto next_live + else + return false + end + end + end + end + end + return false elseif isexpr(stmt, :foreigncall) livepc in preserved && @goto next_live return false diff --git a/test/compiler/codegen.jl b/test/compiler/codegen.jl index ec89ac9cd72a4..d21765180a4b9 100644 --- a/test/compiler/codegen.jl +++ b/test/compiler/codegen.jl @@ -548,27 +548,27 @@ end # main use case function f1(cond) val = [1] - GC.@preserve val begin end + GC.@preserve val begin val end end @test occursin("llvm.julia.gc_preserve_begin", get_llvm(f1, Tuple{Bool}, true, false, false)) # stack allocated objects (JuliaLang/julia#34241) function f3(cond) val = ([1],) - GC.@preserve val begin end + GC.@preserve val begin val end end @test occursin("llvm.julia.gc_preserve_begin", get_llvm(f3, Tuple{Bool}, true, false, false)) # unions of immutables (JuliaLang/julia#39501) function f2(cond) val = cond ? 1 : 1f0 - GC.@preserve val begin end + GC.@preserve val begin val end end @test !occursin("llvm.julia.gc_preserve_begin", get_llvm(f2, Tuple{Bool}, true, false, false)) # make sure the fix for the above doesn't regress #34241 function f4(cond) val = cond ? ([1],) : ([1f0],) - GC.@preserve val begin end + GC.@preserve val begin val end end @test occursin("llvm.julia.gc_preserve_begin", get_llvm(f4, Tuple{Bool}, true, false, false)) end diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 7793489d0fc2b..0933bd3b49ca9 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -75,17 +75,26 @@ end # SROA # ==== -import Core.Compiler: widenconst - -is_load_forwarded(src::CodeInfo) = !any(iscall((src, getfield)), src.code) -is_scalar_replaced(src::CodeInfo) = - is_load_forwarded(src) && !any(iscall((src, setfield!)), src.code) && !any(isnew, src.code) +import Core.Compiler: widenconst, is_array_alloc + +is_load_forwarded(src::CodeInfo) = + !any(iscall((src, getfield)), src.code) && !any(iscall((src, Core.arrayref)), src.code) +function is_scalar_replaced(src::CodeInfo) + is_load_forwarded(src) || return false + any(iscall((src, setfield!)), src.code) && return false + any(isnew, src.code) && return false + any(iscall((src, Core.arrayset)), src.code) && return false + any(is_array_alloc, src.code) && return false + return true +end function is_load_forwarded(@nospecialize(T), src::CodeInfo) for i in 1:length(src.code) x = src.code[i] if iscall((src, getfield), x) widenconst(argextype(x.args[1], src)) <: T && return false + elseif iscall((src, Core.arrayref), x) + widenconst(argextype(x.args[1], src)) <: T && return false end end return true @@ -98,6 +107,10 @@ function is_scalar_replaced(@nospecialize(T), src::CodeInfo) widenconst(argextype(x.args[1], src)) <: T && return false elseif isnew(x) widenconst(argextype(SSAValue(i), src)) <: T && return false + elseif iscall((src, Core.arrayset), x) + widenconst(argextype(x.args[1], src)) <: T && return false + elseif is_array_alloc(x) + widenconst(argextype(SSAValue(i), src)) <: T && return false end end return true @@ -713,7 +726,7 @@ function mutable_ϕ_elim(x, xs) return r[] end let src = code_typed1(mutable_ϕ_elim, (String, Vector{String})) - @test is_scalar_replaced(src) + @test is_scalar_replaced(Ref{String}, src) xs = String[string(gensym()) for _ in 1:100] mutable_ϕ_elim("init", xs) @@ -852,7 +865,7 @@ function isdefined_elim() return arr end let src = code_typed1(isdefined_elim) - @test is_scalar_replaced(src) + @test count(isnew, src.code) == 0 # eliminates closure constructs end @test isdefined_elim() == Any[] @@ -907,6 +920,121 @@ let # immutable case @test count(isnew, src.code) == 0 end +# array SROA +# ---------- + +let src = code_typed1((Any,)) do s + a = Vector{Any}(undef, 1) + a[1] = s + return a[1] + end + @test is_scalar_replaced(src) +end +let src = code_typed1((Any,)) do s + a = Any[nothing] + a[1] = s + return a[1] + end + @test is_scalar_replaced(src) +end +let src = code_typed1((String,String)) do s, t + a = Vector{Any}(undef, 2) + a[1] = Ref(s) + a[2] = Ref(t) + return a[1] + end + @test count(isnew, src.code) == 1 +end +let src = code_typed1((String,)) do s + a = Vector{Base.RefValue{String}}(undef, 1) + a[1] = Ref(s) + return a[1][] + end + @test is_scalar_replaced(src) +end +let src = code_typed1((String,String)) do s, t + a = Vector{Base.RefValue{String}}(undef, 2) + a[1] = Ref(s) + a[2] = Ref(t) + return a[1][] + end + @test is_scalar_replaced(src) +end +let src = code_typed1((Any,)) do s + a = Vector{Any}[Any[nothing]] + a[1][1] = s + return a[1][1] + end + @test_broken is_scalar_replaced(src) +end +let src = code_typed1((Bool,Any,Any)) do c, s, t + a = Any[nothing] + if c + a[1] = s + else + a[1] = t + end + return a[1] + end + @test is_scalar_replaced(src) +end +let src = code_typed1((Bool,Any,Any,Any,Any,)) do c, s1, s2, t1, t2 + if c + a = Vector{Any}(undef, 2) + a[1] = s1 + a[2] = s2 + else + a = Vector{Any}(undef, 2) + a[1] = t1 + a[2] = t2 + end + return a[1] + end + @test is_scalar_replaced(src) +end +let src = code_typed1((Bool,Any,Any)) do c, s, t + # XXX this implicitly forms tuple to getfield chains + # and SROA on it produces complicated control flow + if c + a = Any[s] + else + a = Any[t] + end + return a[1] + end + @test_broken is_scalar_replaced(src) +end + +# arraylen / arraysize elimination +let src = code_typed1((Any,)) do s + a = Vector{Any}(undef, 1) + a[1] = s + return a[1], length(a) + end + @test is_scalar_replaced(src) +end +let src = code_typed1((Any,)) do s + a = Matrix{Any}(undef, 2, 2) + a[1, 1] = s + return a[1, 1], length(a) + end + @test is_scalar_replaced(src) +end +let src = code_typed1((Any,)) do s + a = Vector{Any}(undef, 1) + a[1] = s + return a[1], size(a, 1) + end + @test is_scalar_replaced(src) +end +let src = code_typed1((Any,)) do s + a = Matrix{Any}(undef, 2, 2) + a[1, 1] = s + return a[1, 1], size(a) + end + @test is_scalar_replaced(src) +end + # comparison lifting # ==================