diff --git a/README.md b/README.md index 46675f0..383d878 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,12 @@ This analysis works on a lattice called `x::EscapeLattice`, which holds the foll the caller simply because it's passed as call argument - `x.ThrownEscape::Bool`: indicates `x` may escape to somewhere through an exception (possibly as a field) - `x.EscapeSites::BitSet`: records program counters (SSA numbers) where `x` can escape +- `x.FieldSets::Union{Vector{IdSet{Any}},Bool}`: maintains the sets of possible values of fields of `x`: + * `x.FieldSets === false` indicates the fields of `x` isn't analyzed yet + * `x.FieldSets === true` indicates the fields of `x` can't be analyzed, e.g. the type of `x` + is not concrete and thus the number of its fields can't known precisely + * otherwise `x.FieldSets::Vector{IdSet{Any}}` holds all the possible values of each field, + where `x.FieldSets[i]` keeps all possibilities that the `i`th field can be - `x.ArgEscape::Int` (not implemented yet): indicates it will escape to the caller through `setfield!` on argument(s) * `-1` : no escape * `0` : unknown or multiple @@ -30,7 +36,7 @@ An abstract state will be initialized with the bottom(-like) elements: is slightly lower than `NoEscape`, but at the same time doesn't represent any meaning other than it's not analyzed yet (thus it's not formally part of the lattice). -Escape analysis implementation is based on the data-flow algorithm described in the paper [^MM02]. +Escape analysis implementation is based on the data-flow algorithm described in the old paper [^MM02]. The analysis works on the lattice of [`EscapeLattice`](@ref) and transitions lattice elements from the bottom to the top in a _backward_ way, i.e. data flows from usage cites to definitions, until every lattice gets converged to a fixed point by maintaining a (conceptual) working set @@ -39,6 +45,24 @@ The analysis only manages a single global state that tracks `EscapeLattice` of e and SSA statement, but also note that some flow-sensitivity is encoded as program counters recorded in the `EscapeSites` property of each each lattice element. +The analysis also collects alias information using an approach, which is inspired by +the escape analysis algorithm explained in yet another old paper [^JVM05]. +In addition to managing escape lattice elements, the analysis state also maintains an "alias set", +which is implemented as a disjoint set of aliased arguments and SSA statements. +When the fields of object `x` are known precisely (i.e. `x.FieldSets isa Vector{IdSet{Any}}` holds), +the alias set is updated each time `z = getfield(x, y)` is encountered in a way that `z` is +aliased to all values of `x.FieldSets[y]`, so that escape information imposed on `z` will be +propagated to all the aliased values and `z` can be replaced with an aliased value later. +Note that in a case when the fields of object `x` can't known precisely (i.e. `x.FieldSets` is `true`), +when `z = getfield(x, y)` is analyzed, escape information of `z` is propagated to `x` rather +than any of `x`'s fields, which is the most conservative propagation since escape information +imposed on `x` will end up being propagated to all of its fields anyway at definitions of `x` +(i.e. `:new` expression or `setfield!` call). + [^MM02]: _A Graph-Free approach to Data-Flow Analysis_. Markas Mohnen, 2002, April. . + +[^JVM05]: _Escape Analysis in the Context of Dynamic Compilation and Deoptimization_. + Thomas Kotzmann and Hanspeter Mössenböck, 2005, June. + . diff --git a/src/EscapeAnalysis.jl b/src/EscapeAnalysis.jl index 96cf151..6c0740e 100644 --- a/src/EscapeAnalysis.jl +++ b/src/EscapeAnalysis.jl @@ -53,14 +53,18 @@ import .CC: argextype, singleton_type, IR_FLAG_EFFECT_FREE, - is_meta_expr_head + is_meta_expr_head, + fieldcount_noerror, + try_compute_fieldidx -import Base: == +import Base: ==, IdSet import Base.Meta: isexpr using InteractiveUtils +include("disjoint_set.jl") + let __init_hooks__ = [] global __init__() = foreach(f->f(), __init_hooks__) global register_init_hook!(@nospecialize(f)) = push!(__init_hooks__, f) @@ -151,6 +155,12 @@ A lattice for escape information, which holds the following properties: the caller simply because it's passed as call argument - `x.ThrownEscape::Bool`: indicates `x` may escape to somewhere through an exception (possibly as a field) - `x.EscapeSites::BitSet`: records program counters (SSA numbers) where `x` can escape +- `x.FieldSets::Union{Vector{IdSet{Any}},Bool}`: maintains the sets of possible values of fields of `x`: + * `x.FieldSets === false` indicates the fields of `x` isn't analyzed yet + * `x.FieldSets === true` indicates the fields of `x` can't be analyzed, e.g. the type of `x` + is not concrete and thus the number of its fields can't known precisely + * otherwise `x.FieldSets::Vector{IdSet{Any}}` holds all the possible values of each field, + where `x.FieldSets[i]` keeps all possibilities that the `i`th field can be - `x.ArgEscape::Int` (not implemented yet): indicates it will escape to the caller through `setfield!` on argument(s) * `-1` : no escape * `0` : unknown or multiple @@ -176,7 +186,42 @@ struct EscapeLattice ReturnEscape::Bool ThrownEscape::Bool EscapeSites::BitSet + FieldSets::Union{Vector{IdSet{Any}},Bool} # TODO: ArgEscape::Int + + function EscapeLattice(Analyzed::Bool, + ReturnEscape::Bool, + ThrownEscape::Bool, + EscapeSites::BitSet, + FieldSets, + ) + @nospecialize FieldSets + return new( + Analyzed, + ReturnEscape, + ThrownEscape, + EscapeSites, + FieldSets, + ) + end + function EscapeLattice(x::EscapeLattice, + # non-concrete fields should be passed as default arguments + # in order to avoid allocating non-concrete `NamedTuple`s + FieldSets = x.FieldSets; + Analyzed::Bool = x.Analyzed, + ReturnEscape::Bool = x.ReturnEscape, + ThrownEscape::Bool = x.ThrownEscape, + EscapeSites::BitSet = x.EscapeSites, + ) + @nospecialize FieldSets + return new( + Analyzed, + ReturnEscape, + ThrownEscape, + EscapeSites, + FieldSets, + ) + end end # precomputed default values in order to eliminate computations at each callsite @@ -184,18 +229,20 @@ const EMPTY_ESCAPE_SITES = BitSet() const ARGUMENT_ESCAPE_SITES = BitSet(0) # the constructors -NotAnalyzed() = EscapeLattice(false, false, false, EMPTY_ESCAPE_SITES) # not formally part of the lattice -NoEscape() = EscapeLattice(true, false, false, EMPTY_ESCAPE_SITES) -ReturnEscape(pc::Int) = EscapeLattice(true, true, false, BitSet(pc)) -ThrownEscape(pc::Int) = EscapeLattice(true, false, true, BitSet(pc)) -ArgumentReturnEscape() = EscapeLattice(true, true, false, ARGUMENT_ESCAPE_SITES) +NotAnalyzed() = EscapeLattice(false, false, false, EMPTY_ESCAPE_SITES, false) # not formally part of the lattice +NoEscape() = EscapeLattice(true, false, false, EMPTY_ESCAPE_SITES, false) +ReturnEscape(pc::Int) = EscapeLattice(true, true, false, BitSet(pc), false) +ThrownEscape(pc::Int) = EscapeLattice(true, false, true, BitSet(pc), false) +ArgumentReturnEscape() = EscapeLattice(true, true, false, ARGUMENT_ESCAPE_SITES, true) let - all_escape_sites = BitSet(0:100_000) - global AllEscape() = EscapeLattice(true, true, true, all_escape_sites) + ALL_ESCAPE_SITES = BitSet(0:100_000) + global AllEscape() = EscapeLattice(true, true, true, ALL_ESCAPE_SITES, true) # used for `show` - global AllReturnEscape() = EscapeLattice(true, true, false, all_escape_sites) - global AllThrownEscape() = EscapeLattice(true, false, true, all_escape_sites) + global AllReturnEscape() = EscapeLattice(true, true, false, ALL_ESCAPE_SITES, false) + global AllThrownEscape() = EscapeLattice(true, false, true, ALL_ESCAPE_SITES, false) end +ignore_fieldsets(info::EscapeLattice) = + EscapeLattice(info, info.FieldSets === true ? true : false) # Convenience names for some ⊑ queries export @@ -206,7 +253,7 @@ export has_all_escape, can_elide_finalizer has_not_analyzed(x::EscapeLattice) = x == NotAnalyzed() -has_no_escape(x::EscapeLattice) = x ⊑ NoEscape() +has_no_escape(x::EscapeLattice) = x ⊑ₑ NoEscape() has_return_escape(x::EscapeLattice) = x.ReturnEscape has_return_escape(x::EscapeLattice, pc::Int) = has_return_escape(x) && pc in x.EscapeSites has_thrown_escape(x::EscapeLattice) = x.ThrownEscape @@ -227,6 +274,14 @@ can_elide_finalizer(x::EscapeLattice, pc::Int) = # we need to make sure this `==` operator corresponds to lattice equality rather than object equality, # otherwise `propagate_changes` can't detect the convergence x::EscapeLattice == y::EscapeLattice = begin + xf, yf = x.FieldSets, y.FieldSets + if isa(xf, Bool) + isa(yf, Bool) || return false + xf === yf || return false + else + isa(yf, Bool) && return false + xf == yf || return false + end return x.Analyzed === y.Analyzed && x.ReturnEscape === y.ReturnEscape && x.ThrownEscape === y.ThrownEscape && @@ -235,6 +290,24 @@ x::EscapeLattice == y::EscapeLattice = begin end x::EscapeLattice ⊑ y::EscapeLattice = begin + xf, yf = x.FieldSets, y.FieldSets + if isa(xf, Bool) + xf && yf !== true && return false + else + if isa(yf, Bool) + yf === false && return false + else + xf, yf = xf::Vector{IdSet{Any}}, yf::Vector{IdSet{Any}} + xn, yn = length(xf), length(yf) + xn > yn && return false + for i in 1:xn + xf[i] ⊆ yf[i] || return false + end + end + end + return x ⊑ₑ y +end +x::EscapeLattice ⊑ₑ y::EscapeLattice = begin # partial order excluding the `FieldSets` order if x.Analyzed ≤ y.Analyzed && x.ReturnEscape ≤ y.ReturnEscape && x.ThrownEscape ≤ y.ThrownEscape && @@ -245,14 +318,37 @@ x::EscapeLattice ⊑ y::EscapeLattice = begin return false end x::EscapeLattice ⊏ y::EscapeLattice = x ⊑ y && !(y ⊑ x) +x::EscapeLattice ⊏ₑ y::EscapeLattice = x ⊑ₑ y && !(y ⊑ₑ x) x::EscapeLattice ⋤ y::EscapeLattice = !(y ⊑ x) +x::EscapeLattice ⋤ₑ y::EscapeLattice = !(y ⊑ₑ x) x::EscapeLattice ⊔ y::EscapeLattice = begin + xf, yf = x.FieldSets, y.FieldSets + if xf === true || yf === true + FieldSets = true + elseif xf === false + FieldSets = yf + elseif yf === false + FieldSets = xf + else + xf, yf = xf::Vector{IdSet{Any}}, yf::Vector{IdSet{Any}} + xn, yn = length(xf), length(yf) + nmax, nmin = max(xn, yn), min(xn, yn) + FieldSets = Vector{IdSet{Any}}(undef, nmax) + for i in 1:nmax + if i > nmax + FieldSets[i] = (xn > yn ? xf : yf)[i] + else + FieldSets[i] = xf[i] ∪ yf[i] + end + end + end return EscapeLattice( x.Analyzed | y.Analyzed, x.ReturnEscape | y.ReturnEscape, x.ThrownEscape | y.ThrownEscape, x.EscapeSites ∪ y.EscapeSites, + FieldSets, ) end @@ -262,6 +358,7 @@ x::EscapeLattice ⊓ y::EscapeLattice = begin x.ReturnEscape & y.ReturnEscape, x.ThrownEscape & y.ThrownEscape, x.EscapeSites ∩ y.EscapeSites, + x.FieldSets, # FIXME ) end @@ -275,29 +372,66 @@ Extended lattice that maps arguments and SSA values to escape information repres - `state.arguments::Vector{EscapeLattice}`: escape information about "arguments" – note that "argument" can include both call arguments and slots appearing in analysis frame - `ssavalues::Vector{EscapeLattice}`: escape information about each SSA value +- `aliaset::IntDisjointSet{Int}`: a disjoint set that represents aliased arguments and SSA values """ struct EscapeState arguments::Vector{EscapeLattice} ssavalues::Vector{EscapeLattice} + aliasset::IntDisjointSet{Int} end function EscapeState(nslots::Int, nargs::Int, nstmts::Int) arguments = EscapeLattice[ 1 ≤ i ≤ nargs ? ArgumentReturnEscape() : NotAnalyzed() for i in 1:nslots] ssavalues = EscapeLattice[NotAnalyzed() for _ in 1:nstmts] - return EscapeState(arguments, ssavalues) + aliaset = AliasSet(nslots+nstmts) + return EscapeState(arguments, ssavalues, aliaset) end # we preserve `IRCode` as well just for debugging purpose const GLOBAL_ESCAPE_CACHE = IdDict{MethodInstance,Tuple{EscapeState,IRCode}}() __clear_escape_cache!() = empty!(GLOBAL_ESCAPE_CACHE) -const Change = Pair{Union{Argument,SSAValue},EscapeLattice} -const Changes = Vector{Change} +const EscapeChange = Pair{Union{Argument,SSAValue},EscapeLattice} +const AliasChange = Pair{Int,Int} +const Changes = Vector{Union{EscapeChange,AliasChange}} + +const AliasSet = IntDisjointSet{Int} +function alias_idx(@nospecialize(x), ir::IRCode) + if isa(x, Argument) + return x.n + elseif isa(x, SSAValue) + return x.id + length(ir.argtypes) + else + return nothing + end +end +function alias_val(idx::Int, ir::IRCode) + n = length(ir.argtypes) + return idx > n ? SSAValue(idx-n) : Argument(idx) +end +function get_aliases(aliasset::AliasSet, @nospecialize(key), ir::IRCode) + idx = alias_idx(key, ir) + idx === nothing && return nothing + root = find_root!(aliasset, idx) + if idx ≠ root || aliasset.ranks[idx] > 0 + # the size of this alias set containing `key` is larger than 1, + # collect the entire alias set + aliases = Union{Argument,SSAValue}[] + for i in 1:length(aliasset.parents) + if aliasset.parents[i] == root + push!(aliases, alias_val(i, ir)) + end + end + return aliases + else + return nothing + end +end """ find_escapes(ir::IRCode, nargs::Int) -> EscapeState -Escape analysis implementation is based on the data-flow algorithm described in the paper [^MM02]. +Escape analysis implementation is based on the data-flow algorithm described in the old paper [^MM02]. The analysis works on the lattice of [`EscapeLattice`](@ref) and transitions lattice elements from the bottom to the top in a _backward_ way, i.e. data flows from usage cites to definitions, until every lattice gets converged to a fixed point by maintaining a (conceptual) working set @@ -306,9 +440,27 @@ The analysis only manages a single global state that tracks `EscapeLattice` of e and SSA statement, but also note that some flow-sensitivity is encoded as program counters recorded in the `EscapeSites` property of each each lattice element. +The analysis also collects alias information using an approach, which is inspired by +the escape analysis algorithm explained in yet another old paper [^JVM05]. +In addition to managing escape lattice elements, the analysis state also maintains an "alias set", +which is implemented as a disjoint set of aliased arguments and SSA statements. +When the fields of object `x` are known precisely (i.e. `x.FieldSets isa Vector{IdSet{Any}}` holds), +the alias set is updated each time `z = getfield(x, y)` is encountered in a way that `z` is +aliased to all values of `x.FieldSets[y]`, so that escape information imposed on `z` will be +propagated to all the aliased values and `z` can be replaced with an aliased value later. +Note that in a case when the fields of object `x` can't known precisely (i.e. `x.FieldSets` is `true`), +when `z = getfield(x, y)` is analyzed, escape information of `z` is propagated to `x` rather +than any of `x`'s fields, which is the most conservative propagation since escape information +imposed on `x` will end up being propagated to all of its fields anyway at definitions of `x` +(i.e. `:new` expression or `setfield!` call). + [^MM02]: _A Graph-Free approach to Data-Flow Analysis_. Markas Mohnen, 2002, April. . + +[^JVM05]: _Escape Analysis in the Context of Dynamic Compilation and Deoptimization_. + Thomas Kotzmann and Hanspeter Mössenböck, 2005, June. + . """ function find_escapes(ir::IRCode, nargs::Int) (; stmts, sptypes, argtypes) = ir @@ -335,7 +487,7 @@ function find_escapes(ir::IRCode, nargs::Int) has_changes = escape_call!(ir, pc, stmt.args, state, changes) if !is_effect_free for x in stmt.args - add_change!(x, ir, ThrownEscape(pc), changes) + add_escape_change!(x, ir, ThrownEscape(pc), changes) end else has_changes || continue @@ -347,12 +499,14 @@ function find_escapes(ir::IRCode, nargs::Int) elseif head === :(=) lhs, rhs = stmt.args if isa(lhs, GlobalRef) # global store - add_change!(rhs, ir, AllEscape(), changes) + add_escape_change!(rhs, ir, AllEscape(), changes) + else + invalid_escape_assignment!(ir, pc) end elseif head === :foreigncall escape_foreigncall!(ir, pc, stmt.args, state, changes) elseif head === :throw_undef_if_not # XXX when is this expression inserted ? - add_change!(stmt.args[1], ir, ThrownEscape(pc), changes) + add_escape_change!(stmt.args[1], ir, ThrownEscape(pc), changes) elseif is_meta_expr_head(head) # meta expressions doesn't account for any usages continue @@ -384,15 +538,15 @@ function find_escapes(ir::IRCode, nargs::Int) continue else for x in stmt.args - add_change!(x, ir, AllEscape(), changes) + add_escape_change!(x, ir, AllEscape(), changes) end end elseif isa(stmt, GlobalRef) # global load - add_change!(SSAValue(pc), ir, AllEscape(), changes) + add_escape_change!(SSAValue(pc), ir, AllEscape(), changes) elseif isa(stmt, PiNode) if isdefined(stmt, :val) info = state.ssavalues[pc] - add_change!(stmt.val, ir, info, changes) + add_escape_change!(stmt.val, ir, info, changes) end elseif isa(stmt, PhiNode) @inline escape_backedges!(ir, pc, stmt.values, state, changes) @@ -401,11 +555,11 @@ function find_escapes(ir::IRCode, nargs::Int) elseif isa(stmt, UpsilonNode) if isdefined(stmt, :val) info = state.ssavalues[pc] - add_change!(stmt.val, ir, info, changes) + add_escape_change!(stmt.val, ir, info, changes) end elseif isa(stmt, ReturnNode) if isdefined(stmt, :val) - add_change!(stmt.val, ir, ReturnEscape(pc), changes) + add_escape_change!(stmt.val, ir, ReturnEscape(pc), changes) end else @assert stmt isa GotoNode || stmt isa GotoIfNot || isnothing(stmt) # TODO remove me @@ -414,7 +568,7 @@ function find_escapes(ir::IRCode, nargs::Int) isempty(changes) && continue - anyupdate |= propagate_changes!(state, changes) + anyupdate |= propagate_changes!(state, changes, ir) empty!(changes) end @@ -426,45 +580,82 @@ function find_escapes(ir::IRCode, nargs::Int) end # propagate changes, and check convergence -function propagate_changes!(state::EscapeState, changes::Changes) +function propagate_changes!(state::EscapeState, changes::Changes, ir::IRCode) local anychanged = false - for (x, info) in changes - if isa(x, Argument) - old = state.arguments[x.n] - new = old ⊔ info - if old ≠ new - state.arguments[x.n] = new - anychanged |= true + for change in changes + if isa(change, EscapeChange) + anychanged |= propagate_escape_change!(state, change) + x, info = change + aliases = get_aliases(state.aliasset, x, ir) + if aliases !== nothing + for alias in aliases + morechange = EscapeChange(alias, info) + anychanged |= propagate_escape_change!(state, morechange) + end end else - x = x::SSAValue - old = state.ssavalues[x.id] - new = old ⊔ info - if old ≠ new - state.ssavalues[x.id] = new - anychanged |= true - end + anychanged |= propagate_alias_change!(state, change) end end return anychanged end -function add_change!(@nospecialize(x), ir::IRCode, info::EscapeLattice, changes::Changes) +function propagate_escape_change!(state::EscapeState, change::EscapeChange) + x, info = change + if isa(x, Argument) + old = state.arguments[x.n] + new = old ⊔ info + if old ≠ new + state.arguments[x.n] = new + return true + end + else + x = x::SSAValue + old = state.ssavalues[x.id] + new = old ⊔ info + if old ≠ new + state.ssavalues[x.id] = new + return true + end + end + return false +end + +function propagate_alias_change!(state::EscapeState, change::AliasChange) + x, y = change + xroot = find_root!(state.aliasset, x) + yroot = find_root!(state.aliasset, y) + if xroot ≠ yroot + union!(state.aliasset, xroot, yroot) + return true + end + return false +end + +function add_escape_change!(@nospecialize(x), ir::IRCode, info::EscapeLattice, changes::Changes) if isa(x, Argument) || isa(x, SSAValue) if !isbitstype(widenconst(argextype(x, ir, ir.sptypes, ir.argtypes))) - push!(changes, Change(x, info)) + push!(changes, EscapeChange(x, info)) end end end +function add_alias_change!(@nospecialize(x), @nospecialize(y), ir::IRCode, changes::Changes) + xidx = alias_idx(x, ir) + yidx = alias_idx(y, ir) + if xidx !== nothing && yidx !== nothing + push!(changes, AliasChange(xidx, yidx)) + end +end + function escape_backedges!(ir::IRCode, pc::Int, backedges::Vector{Any}, state::EscapeState, changes::Changes) info = state.ssavalues[pc] for i in 1:length(backedges) if isassigned(backedges, i) - add_change!(backedges[i], ir, info, changes) + add_escape_change!(backedges[i], ir, info, changes) end end end @@ -483,12 +674,10 @@ function escape_call!(ir::IRCode, pc::Int, args::Vector{Any}, # if this call hasn't been handled by any of pre-defined handlers, # we escape this call conservatively for i in 2:length(args) - add_change!(args[i], ir, AllEscape(), changes) + add_escape_change!(args[i], ir, AllEscape(), changes) end - return true - else - return true end + return true end function escape_invoke!(ir::IRCode, pc::Int, args::Vector{Any}, @@ -498,7 +687,7 @@ function escape_invoke!(ir::IRCode, pc::Int, args::Vector{Any}, args = args[2:end] if isnothing(cache) for x in args - add_change!(x, ir, AllEscape(), changes) + add_escape_change!(x, ir, AllEscape(), changes) end else (linfostate, _ #=ir::IRCode=#) = cache @@ -509,15 +698,12 @@ function escape_invoke!(ir::IRCode, pc::Int, args::Vector{Any}, arg = args[i] if i ≤ nargs arginfo = linfostate.arguments[i] - else # handle isva signature: COMBAK will this invalid once we encode alias information ? + else # handle isva signature: COMBAK will this be invalid once we take alias information into account ? arginfo = linfostate.arguments[nargs] end - if isempty(arginfo.ReturnEscape) - @eval Main (ir = $ir; linfo = $linfo) - error("invalid escape lattice element returned from inter-procedural context: inspect `Main.ir` and `Main.linfo`") - end + isempty(arginfo.ReturnEscape) && invalid_escape_invoke!(ir, linfo) info = from_interprocedural(arginfo, retinfo, pc) - add_change!(arg, ir, info, changes) + add_escape_change!(arg, ir, info, changes) end end end @@ -531,7 +717,14 @@ function from_interprocedural(arginfo::EscapeLattice, retinfo::EscapeLattice, pc else EscapeSites = EMPTY_ESCAPE_SITES end - newarginfo = EscapeLattice(true, false, arginfo.ThrownEscape, EscapeSites) + newarginfo = EscapeLattice( + #=Analyzed=#true, #=ReturnEscape=#false, arginfo.ThrownEscape, EscapeSites, + # FIXME implement inter-procedural effect-analysis + # currently, this essentially disables the entire field analysis + # it might be okay from the SROA point of view, since we can't remove the allocation + # as far as it's passed to a callee anyway, but still we may want some field analysis + # in order to stack allocate it + #=FieldSets=#true) if arginfo.EscapeSites === ARGUMENT_ESCAPE_SITES # if this is simply passed as the call argument, we can discard the `ReturnEscape` # information and just propagate the other escape information @@ -543,21 +736,53 @@ function from_interprocedural(arginfo::EscapeLattice, retinfo::EscapeLattice, pc end end +@noinline function invalid_escape_invoke!(ir::IRCode, linfo::MethodInstance) + @eval Main (ir = $ir; linfo = $linfo) + error("invalid escape lattice element returned from inter-procedural context: inspect `Main.ir` and `Main.linfo`") +end + +@noinline function invalid_escape_assignment!(ir::IRCode, pc::Int) + @eval Main (ir = $ir; pc = $pc) + error("unexpected assignment found: inspect `Main.pc` and `Main.pc`") +end + function escape_new!(ir::IRCode, pc::Int, args::Vector{Any}, state::EscapeState, changes::Changes) info = state.ssavalues[pc] if info == NotAnalyzed() info = NoEscape() - add_change!(SSAValue(pc), ir, info, changes) # we will be interested in if this allocation escapes or not end + newinfo = add_fieldsets(info, ir.stmts[pc][:type], args) + add_escape_change!(SSAValue(pc), ir, newinfo, changes) # propagate the escape information of this object to all its fields as well # since they can be accessed through the object for i in 2:length(args) - add_change!(args[i], ir, info, changes) + add_escape_change!(args[i], ir, ignore_fieldsets(info), changes) end end +function add_fieldsets(info::EscapeLattice, @nospecialize(typ), args::Vector{Any}) + FieldSets = info.FieldSets + nfields = fieldcount_noerror(typ) + if isa(FieldSets, Bool) && !FieldSets + if nfields === nothing + # abstract type, can't propagate + FieldSets = true + else + FieldSets = IdSet{Any}[IdSet{Any}() for _ in 1:nfields] + end + end + if !isa(FieldSets, Bool) + @assert nfields == length(FieldSets) + for i in 2:length(args) + i-1 > nfields::Int && break # see https://github.com/JuliaLang/julia/issues/43146 + push!(FieldSets[i-1], args[i]) + end + end + return EscapeLattice(info, FieldSets) +end + # escape every argument `(args[6:length(args[3])])` and the name `args[1]` # TODO: we can apply a similar strategy like builtin calls to specialize some foreigncalls function escape_foreigncall!(ir::IRCode, pc::Int, args::Vector{Any}, @@ -567,9 +792,9 @@ function escape_foreigncall!(ir::IRCode, pc::Int, args::Vector{Any}, # if normalize(name) === :jl_gc_add_finalizer_th # # add `FinalizerEscape` ? # end - add_change!(name, ir, ThrownEscape(pc), changes) + add_escape_change!(name, ir, ThrownEscape(pc), changes) for i in 6:5+foreigncall_nargs - add_change!(args[i], ir, ThrownEscape(pc), changes) + add_escape_change!(args[i], ir, ThrownEscape(pc), changes) end end @@ -594,13 +819,13 @@ function escape_builtin!(::typeof(Core.ifelse), ir::IRCode, pc::Int, args::Vecto condt = argextype(cond, ir, ir.sptypes, ir.argtypes) if isa(condt, Const) && (cond = condt.val; isa(cond, Bool)) if cond - add_change!(th, ir, info, changes) + add_escape_change!(th, ir, info, changes) else - add_change!(el, ir, info, changes) + add_escape_change!(el, ir, info, changes) end else - add_change!(th, ir, info, changes) - add_change!(el, ir, info, changes) + add_escape_change!(th, ir, info, changes) + add_escape_change!(el, ir, info, changes) end end @@ -608,7 +833,7 @@ function escape_builtin!(::typeof(typeassert), ir::IRCode, pc::Int, args::Vector length(args) == 3 || return f, obj, typ = args info = state.ssavalues[pc] - add_change!(obj, ir, info, changes) + add_escape_change!(obj, ir, info, changes) end function escape_builtin!(::typeof(tuple), ir::IRCode, pc::Int, args::Vector{Any}, state::EscapeState, changes::Changes) @@ -616,24 +841,143 @@ function escape_builtin!(::typeof(tuple), ir::IRCode, pc::Int, args::Vector{Any} if info == NotAnalyzed() info = NoEscape() end + tupleinfo = add_fieldsets(info, ir.stmts[pc][:type], args) + add_escape_change!(SSAValue(pc), ir, tupleinfo, changes) + # propagate the escape information of this object to all its fields as well, for i in 2:length(args) - add_change!(args[i], ir, info, changes) + add_escape_change!(args[i], ir, ignore_fieldsets(info), changes) end end -# TODO don't propagate escape information to the 1st argument, but propagate information to aliased field function escape_builtin!(::typeof(getfield), ir::IRCode, pc::Int, args::Vector{Any}, state::EscapeState, changes::Changes) - # only propagate info when the field itself is non-bitstype - isbitstype(widenconst(ir.stmts.type[pc])) && return true + length(args) ≥ 3 || return info = state.ssavalues[pc] if info == NotAnalyzed() info = NoEscape() end - for i in 2:length(args) - add_change!(args[i], ir, info, changes) + + obj = args[2] + if isa(obj, SSAValue) + FieldSets = state.ssavalues[obj.id].FieldSets + elseif isa(obj, Argument) + FieldSets = state.arguments[obj.n].FieldSets + else + return + end + if isa(FieldSets, Bool) + if FieldSets + # the field can't be analyzed, escape the object itself (including all its fields) conservatively + add_escape_change!(obj, ir, info, changes) + else + # this field hasn't been analyzed yet, do nothing for now + end + else + typ = argextype(obj, ir, ir.sptypes, ir.argtypes) + if isa(typ, DataType) + fld = args[3] + fldval = try_compute_fieldval(ir, fld) + fidx = try_compute_fieldidx(typ, fldval) + else + fidx = nothing + end + if fidx !== nothing + for x in FieldSets[fidx] + add_escape_change!(x, ir, info, changes) + add_alias_change!(x, SSAValue(pc), ir, changes) + end + else + # when the field can't be known precisely, + # propagate this escape information to all the fields conservatively + for FieldSet in FieldSets, x in FieldSet + add_escape_change!(x, ir, info, changes) + add_alias_change!(x, SSAValue(pc), ir, changes) + end + end end end +function try_compute_fieldval(ir::IRCode, @nospecialize fld) + # fields are usually literals, handle them manually + if isa(fld, QuoteNode) + fld = fld.value + elseif isa(fld, Int) + else + # try to resolve other constants, e.g. global reference + fld = argextype(fld, ir, ir.sptypes, ir.argtypes) + if isa(fld, Const) + fld = fld.val + else + return nothing + end + end + return isa(fld, Union{Int, Symbol}) ? fld : nothing +end + +function escape_builtin!(::typeof(setfield!), ir::IRCode, pc::Int, args::Vector{Any}, state::EscapeState, changes::Changes) + length(args) ≥ 4 || return + + obj, fld, val = args[2:4] + if isa(obj, SSAValue) + objinfo = state.ssavalues[obj.id] + elseif isa(obj, Argument) + objinfo = state.arguments[obj.n] + else + if isa(obj, GlobalRef) + add_escape_change!(val, ir, AllEscape(), changes) + return + else + # XXX add_escape_change!(val, ir, AllEscape(), changes) ? + @goto add_ssa_escape + end + end + FieldSets = objinfo.FieldSets + typ = argextype(obj, ir, ir.sptypes, ir.argtypes) + if isa(FieldSets, Bool) + if FieldSets + # the field analysis on this object was already done and unsuccessful, + # nothing can't be done here + else + nfields = fieldcount_noerror(typ) + if nfields !== nothing + FieldSets = IdSet{Any}[IdSet{Any}() for _ in 1:nfields] + @goto add_field + else + # fields aren't known precisely + add_escape_change!(obj, ir, EscapeLattice(objinfo, #=FieldSets=#true), changes) + end + end + else + @label add_field # the field sets have been initialized, now add the alias information + if isa(typ, DataType) + fldval = try_compute_fieldval(ir, fld) + fidx = try_compute_fieldidx(typ, fldval) + else + fidx = nothing + end + if fidx !== nothing + push!(FieldSets[fidx], val) + else + # when the field can't be known precisely, + # add this alias information to all the field sets conservatively + for FieldSet in FieldSets + push!(FieldSet, val) + end + end + # update `obj`'s escape information with the new field sets + add_escape_change!(obj, ir, EscapeLattice(objinfo, FieldSets), changes) + end + # propagate `obj`'s escape information to `val` as well + add_escape_change!(val, ir, ignore_fieldsets(objinfo), changes) + + # propagate escape information imposed on the return value of this `setfield!` + @label add_ssa_escape + ssainfo = state.ssavalues[pc] + if ssainfo == NotAnalyzed() + ssainfo = NoEscape() + end + add_escape_change!(val, ir, ssainfo, changes) +end + # entries # ======= @@ -706,13 +1050,13 @@ function get_name_color(x::EscapeLattice, symbol::Bool = false) getname(x) = string(nameof(x)) if x == NotAnalyzed() name, color = (getname(NotAnalyzed), '◌'), :plain - elseif x == NoEscape() + elseif has_no_escape(x) name, color = (getname(NoEscape), '✓'), :green - elseif NoEscape() ⊏ x ⊑ AllReturnEscape() + elseif NoEscape() ⊏ₑ x ⊑ₑ AllReturnEscape() name, color = (getname(ReturnEscape), '↑'), :cyan - elseif NoEscape() ⊏ x ⊑ AllThrownEscape() + elseif NoEscape() ⊏ₑ x ⊑ₑ AllThrownEscape() name, color = (getname(ThrownEscape), '↓'), :yellow - elseif x == AllEscape() + elseif has_all_escape(x) name, color = (getname(AllEscape), 'X'), :red else name, color = (nothing, '*'), :red @@ -749,6 +1093,9 @@ Base.show(io::IO, result::EscapeResult) = print_with_info(io, result.ir, result. @eval Base.iterate(res::EscapeResult, state=1) = return state > $(fieldcount(EscapeResult)) ? nothing : (getfield(res, state), state+1) +# utitlity queries +get_aliases(result::EscapeResult, @nospecialize(key)) = get_aliases(result.state.aliasset, key, result.ir) + # adapted from https://github.com/JuliaDebug/LoweredCodeUtils.jl/blob/4612349432447e868cf9285f647108f43bd0a11c/src/codeedges.jl#L881-L897 function print_with_info(io::IO, ir::IRCode, (; arguments, ssavalues)::EscapeState, linfo::Union{Nothing,MethodInstance}) diff --git a/src/disjoint_set.jl b/src/disjoint_set.jl new file mode 100644 index 0000000..1f99685 --- /dev/null +++ b/src/disjoint_set.jl @@ -0,0 +1,141 @@ +# a disjoint set implementation +# adapted from https://github.com/JuliaCollections/DataStructures.jl/blob/f57330a3b46f779b261e6c07f199c88936f28839/src/disjoint_set.jl +# under the MIT license: https://github.com/JuliaCollections/DataStructures.jl/blob/master/License.md + +import Base: + length, + eltype, + union!, + push! + +using Base: OneTo + +# Disjoint-Set + +############################################################ +# +# A forest of disjoint sets of integers +# +# Since each element is an integer, we can use arrays +# instead of dictionary (for efficiency) +# +# Disjoint sets over other key types can be implemented +# based on an IntDisjointSet through a map from the key +# to an integer index +# +############################################################ + +_intdisjointset_bounds_err_msg(T) = "the maximum number of elements in IntDisjointSet{$T} is $(typemax(T))" + +""" + IntDisjointSet{T<:Integer}(n::Integer) + +A forest of disjoint sets of integers, which is a data structure +(also called a union–find data structure or merge–find set) +that tracks a set of elements partitioned +into a number of disjoint (non-overlapping) subsets. +""" +mutable struct IntDisjointSet{T<:Integer} + parents::Vector{T} + ranks::Vector{T} + ngroups::T +end + +IntDisjointSet(n::T) where {T<:Integer} = IntDisjointSet{T}(collect(OneTo(n)), zeros(T, n), n) +IntDisjointSet{T}(n::Integer) where {T<:Integer} = IntDisjointSet{T}(collect(OneTo(T(n))), zeros(T, T(n)), T(n)) +length(s::IntDisjointSet) = length(s.parents) + +""" + num_groups(s::IntDisjointSet) + +Get a number of groups. +""" +num_groups(s::IntDisjointSet) = s.ngroups +eltype(::Type{IntDisjointSet{T}}) where {T<:Integer} = T + +# find the root element of the subset that contains x +# path compression is implemented here +function find_root_impl!(parents::Vector{T}, x::Integer) where {T<:Integer} + p = parents[x] + @inbounds if parents[p] != p + parents[x] = p = _find_root_impl!(parents, p) + end + return p +end + +# unsafe version of the above +function _find_root_impl!(parents::Vector{T}, x::Integer) where {T<:Integer} + @inbounds p = parents[x] + @inbounds if parents[p] != p + parents[x] = p = _find_root_impl!(parents, p) + end + return p +end + +""" + find_root!(s::IntDisjointSet{T}, x::T) + +Find the root element of the subset that contains an member `x`. +Path compression happens here. +""" +find_root!(s::IntDisjointSet{T}, x::T) where {T<:Integer} = find_root_impl!(s.parents, x) + +""" + in_same_set(s::IntDisjointSet{T}, x::T, y::T) + +Returns `true` if `x` and `y` belong to the same subset in `s`, and `false` otherwise. +""" +in_same_set(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer} = find_root!(s, x) == find_root!(s, y) + +""" + union!(s::IntDisjointSet{T}, x::T, y::T) + +Merge the subset containing `x` and that containing `y` into one +and return the root of the new set. +""" +function union!(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer} + parents = s.parents + xroot = find_root_impl!(parents, x) + yroot = find_root_impl!(parents, y) + return xroot != yroot ? root_union!(s, xroot, yroot) : xroot +end + +""" + root_union!(s::IntDisjointSet{T}, x::T, y::T) + +Form a new set that is the union of the two sets whose root elements are +`x` and `y` and return the root of the new set. +Assume `x ≠ y` (unsafe). +""" +function root_union!(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer} + parents = s.parents + rks = s.ranks + @inbounds xrank = rks[x] + @inbounds yrank = rks[y] + + if xrank < yrank + x, y = y, x + elseif xrank == yrank + rks[x] += one(T) + end + @inbounds parents[y] = x + s.ngroups -= one(T) + return x +end + +""" + push!(s::IntDisjointSet{T}) + +Make a new subset with an automatically chosen new element `x`. +Returns the new element. Throw an `ArgumentError` if the +capacity of the set would be exceeded. +""" +function push!(s::IntDisjointSet{T}) where {T<:Integer} + l = length(s) + l < typemax(T) || throw(ArgumentError(_intdisjointset_bounds_err_msg(T))) + x = l + one(T) + push!(s.parents, x) + push!(s.ranks, zero(T)) + s.ngroups += one(T) + return x +end diff --git a/test/runtests.jl b/test/runtests.jl index ad8406c..9219a8b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,6 +13,11 @@ mutable struct MutableCondition cond::Bool end +mutable struct MutableFields{S,T} + field1::S + field2::T +end + @testset "EscapeAnalysis" begin @testset "basics" begin @@ -248,7 +253,7 @@ end # appropriate conversion of inter-procedural context # https://github.com/aviatesk/EscapeAnalysis.jl/issues/7 - @eval M @noinline f_NoEscape_a(a) = (println("prevent inlining"); Base.inferencebarrier(nothing)) + @eval M @noinline f_NoEscape_a(a) = (println("prevent inlining"); nothing) let result = @eval M $analyze_escapes() do a = Ref("foo") # shouldn't be "return escape" @@ -265,7 +270,7 @@ end return a end i = findfirst(isT(Base.RefValue{String}), result.ir.stmts.type)::Int - r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst)::Int + r = findfirst(isreturn, result.ir.stmts.inst)::Int @test has_return_escape(result.state.ssavalues[i], r) end @@ -278,7 +283,7 @@ end return ret # alias of `obj` end i = findfirst(isT(Base.RefValue{String}), result.ir.stmts.type)::Int - r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst)::Int + r = findfirst(isreturn, result.ir.stmts.inst)::Int @test has_return_escape(result.state.ssavalues[i], r) end @@ -290,7 +295,7 @@ end return ret # must not alias to `obj` end i = findfirst(isT(Base.RefValue{String}), result.ir.stmts.type)::Int - r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst)::Int + r = findfirst(isreturn, result.ir.stmts.inst)::Int @test !has_return_escape(result.state.ssavalues[i], r) end end @@ -434,8 +439,92 @@ end end @testset "field analysis" begin - let - result = analyze_escapes((String,)) do a # => ReturnEscape + # definitions (:new, :splatnew, setfield!) + # ======================================== + + # escaped object should escape its fields + let result = analyze_escapes((Any,)) do a + global o = MutableSome{Any}(a) + nothing + end + i = findfirst(isT(MutableSome{Any}), result.ir.stmts.type)::Int + @test has_all_escape(result.state.ssavalues[i]) + @test has_all_escape(result.state.arguments[2]) + end + let result = analyze_escapes((Any,)) do a + global t = (a,) + nothing + end + i = findfirst(issubT(Tuple), result.ir.stmts.type)::Int + @test has_all_escape(result.state.ssavalues[i]) + @test has_all_escape(result.state.arguments[2]) + end + let result = analyze_escapes((Any,)) do a + o0 = MutableSome{Any}(a) + global o = MutableSome(o0) + nothing + end + i0 = findfirst(isT(MutableSome{Any}), result.ir.stmts.type)::Int + i1 = findfirst(isT(MutableSome{MutableSome{Any}}), result.ir.stmts.type)::Int + @test has_all_escape(result.state.ssavalues[i0]) + @test has_all_escape(result.state.ssavalues[i1]) + @test has_all_escape(result.state.arguments[2]) + end + let result = analyze_escapes((Any,)) do a + t0 = (a,) + global t = (t0,) + nothing + end + inds = findall(issubT(Tuple), result.ir.stmts.type) + @assert length(inds) == 2 + for i in inds; @test has_all_escape(result.state.ssavalues[i]); end + @test has_all_escape(result.state.arguments[2]) + end + let result = analyze_escapes((Any,)) do a + r = Ref{Any}() + global o = r + r[] = a + nothing + end + i = findfirst(isT(Base.RefValue{Any}), result.ir.stmts.type)::Int + @test has_all_escape(result.state.ssavalues[i]) + @test has_all_escape(result.state.arguments[2]) + end + let result = analyze_escapes((Any,Any)) do a0, a1 + r = Ref{Any}(a0) + global o = r + r[] = a1 + nothing + end + i = findfirst(isT(Base.RefValue{Any}), result.ir.stmts.type)::Int + @test has_all_escape(result.state.ssavalues[i]) + @test_broken !has_all_escape(result.state.arguments[2]) # requires flow-sensitivity ? + @test has_all_escape(result.state.arguments[3]) + end + let result = @eval Module() begin + const Rx = Ref{Any}() + $analyze_escapes((String,)) do s + Rx[] = s # global store => AllEscape + Core.sizeof(Rx[]) + end + end + @test has_all_escape(result.state.arguments[2]) + end + let result = @eval Module() begin + const Rx = Ref{Any}() + $analyze_escapes((String,)) do s + setfield!(Rx, :x, s) + Core.sizeof(Rx[]) + end + end + @test has_all_escape(result.state.arguments[2]) + end + + # usages (getfield) + # ================= + + # field escape doens't escape object if the field is known precisely + let result = analyze_escapes((String,)) do a # => ReturnEscape o = MutableSome(a) # no need to escape f = getfield(o, :value) return f @@ -443,25 +532,81 @@ end i = findfirst(isT(MutableSome{String}), result.ir.stmts.type)::Int r = findfirst(isreturn, result.ir.stmts.inst)::Int @test has_return_escape(result.state.arguments[2], r) - @test_broken !has_return_escape(result.state.ssavalues[i], r) + @test !has_return_escape(result.state.ssavalues[i], r) end - - let - result = analyze_escapes((String,Bool)) do a, cond # => ReturnEscape &&ThrownEscape - o = MutableSome(a) # no need to escape - cond && throw(o) - f = getfield(o, :value) + let result = analyze_escapes((String,)) do a # => ReturnEscape + t = (a,) # no need to escape + f = t[1] return f end - i = findfirst(isT(MutableSome{String}), result.ir.stmts.type) + i = findfirst(t->t<:Tuple, result.ir.stmts.type)::Int r = findfirst(isreturn, result.ir.stmts.inst)::Int - t = findfirst(isthrow, result.ir.stmts.inst)::Int @test has_return_escape(result.state.arguments[2], r) - @test has_thrown_escape(result.state.ssavalues[i], t) - @test_broken !has_return_escape(result.state.ssavalues[i], r) + @test !has_return_escape(result.state.ssavalues[i], r) + end + let # multiple fields + result = analyze_escapes((String, String)) do a, b # => ReturnEscape, ReturnEscape + obj = MutableFields(a, b) # => NoEscape + fld1 = obj.field1 + fld2 = obj.field2 + return (fld1, fld2) + end + i = findfirst(isT(MutableFields{String,String}), result.ir.stmts.type)::Int + r = findfirst(isreturn, result.ir.stmts.inst)::Int + @test has_return_escape(result.state.arguments[2], r) # a + @test has_return_escape(result.state.arguments[3], r) # b + @test !has_return_escape(result.state.ssavalues[i], r) end - let + # should work with `setfield!` + let result = analyze_escapes((String,)) do a # => ReturnEscape + o = Ref{String}() # no need to escape + o[] = a + f = o[] + return f + end + i = findfirst(isT(Base.RefValue{String}), result.ir.stmts.type) + r = findfirst(isreturn, result.ir.stmts.inst) + @assert !isnothing(i) && !isnothing(r) + @test has_return_escape(result.state.arguments[2], r) + @test !has_return_escape(result.state.ssavalues[i], r) + end + let result = analyze_escapes((String, Symbol)) do a, fld # => ReturnEscape + o = Ref{String}("foo") # no need to escape + setfield!(o, fld, a) + f = o[] + return f + end + i = findfirst(isT(Base.RefValue{String}), result.ir.stmts.type) + r = findfirst(isreturn, result.ir.stmts.inst) + @assert !isnothing(i) && !isnothing(r) + @test has_return_escape(result.state.arguments[2], r) + @test !has_return_escape(result.state.ssavalues[i], r) + end + let result = analyze_escapes((String, Symbol)) do a, fld + obj = MutableFields("foo", "bar") + setfield!(obj, fld, a) + return obj.field1 # this should escape `a` + end + i = findfirst(isT(MutableFields{String,String}), result.ir.stmts.type)::Int + r = findfirst(isreturn, result.ir.stmts.inst)::Int + @test has_return_escape(result.state.arguments[2], r) # a + @test !has_return_escape(result.state.ssavalues[i], r) + end + + let # unknown field + result = analyze_escapes((String, String, Symbol)) do a, b, fld # => ReturnEscape, ReturnEscape + obj = MutableFields(a, b) # => NoEscape + return getfield(obj, fld) + end + i = findfirst(isT(MutableFields{String,String}), result.ir.stmts.type)::Int + r = findfirst(isreturn, result.ir.stmts.inst)::Int + @test has_return_escape(result.state.arguments[2], r) # a + @test has_return_escape(result.state.arguments[3], r) # b + @test !has_return_escape(result.state.ssavalues[i], r) + end + + let # nested object instantiation (not aliased) result = analyze_escapes((String,)) do a # => ReturnEscape o1 = MutableSome(a) # => ReturnEscape o2 = MutableSome(o1) # no need to escape @@ -471,20 +616,95 @@ end i2 = findfirst(isT(MutableSome{MutableSome{String}}), result.ir.stmts.type)::Int r = findfirst(isreturn, result.ir.stmts.inst)::Int @test has_return_escape(result.state.arguments[2], r) - @test has_return_escape(result.state.ssavalues[i1]) - @test_broken !has_return_escape(result.state.ssavalues[i2]) + @test has_return_escape(result.state.ssavalues[i1], r) + @test !has_return_escape(result.state.ssavalues[i2], r) end - let - result = analyze_escapes((Any,Bool)) do a, cond # => ThrownEscape - t = tuple(a) # no need to escape - cond && throw(t[1]) - return nothing + # inter-procedural + # ================ + + let result = @eval Module() begin + @noinline getvalue(obj) = obj.value + $analyze_escapes((String,)) do a # => ReturnEscape + obj = $MutableSome(a) # no need to escape + fld = getvalue(obj) + return fld + end end - i = findfirst(issubT(Tuple), result.ir.stmts.type)::Int - t = findfirst(isthrow, result.ir.stmts.inst)::Int - @test has_thrown_escape(result.state.arguments[2], t) - @test_broken !has_thrown_escape(result.state.ssavalues[i], t) + i = findfirst(isT(MutableSome{String}), result.ir.stmts.type)::Int + r = findfirst(isreturn, result.ir.stmts.inst)::Int + @test has_return_escape(result.state.arguments[2], r) + # NOTE we can't scalar replace `obj`, but still we may want to stack allocate it + @test_broken !has_return_escape(result.state.ssavalues[i], r) + end + + # TODO + # flow-sensitivity + # ================ + let result = analyze_escapes((Any,Any)) do a1, a2 + r = Ref{Any}() + r[] = a1 + r[] = a2 + return r[] + end + i = findfirst(isT(Base.RefValue{Any}), result.ir.stmts.type)::Int + r = findfirst(isreturn, result.ir.stmts.inst)::Int + @test_broken !has_return_escape(result.state.arguments[2], r) + @test has_return_escape(result.state.arguments[3], r) + @test !has_return_escape(result.state.ssavalues[i], r) + end + let result = analyze_escapes((Any,Any,Bool)) do a1, a2, cond + r = Ref{Any}() + if cond + r[] = a1 + return r[] + else + r[] = a2 + return nothing + end + end + i = findfirst(isT(Base.RefValue{Any}), result.ir.stmts.type)::Int + r = findfirst(isreturn, result.ir.stmts.inst)::Int + @test has_return_escape(result.state.arguments[2], r) + @test_broken !has_return_escape(result.state.arguments[3], r) + @test !has_return_escape(result.state.ssavalues[i], r) + end + + # end to end + # ========== + + let # `popfirst!(InvasiveLinkedList{Task})` within this `println` used to cause infinite loop ... + result = analyze_escapes((String,)) do a + println(a) + nothing + end + @test true + end +end + +@testset "alias analysis" begin + let result = analyze_escapes((String,)) do a # => ReturnEscape + o1 = MutableSome(a) # => NoEscape + o2 = MutableSome(o1) # => NoEscape + o1′ = getfield(o2, :value) # o1 + a′ = getfield(o1′, :value) # a + return a′ + end + i1 = findfirst(isT(MutableSome{String}), result.ir.stmts.type)::Int + i2 = findfirst(isT(MutableSome{MutableSome{String}}), result.ir.stmts.type)::Int + r = findfirst(isreturn, result.ir.stmts.inst)::Int + @test has_return_escape(result.state.arguments[2], r) + @test !has_return_escape(result.state.ssavalues[i1], r) + @test !has_return_escape(result.state.ssavalues[i2], r) + end + + let result = analyze_escapes((String,)) do x + broadcast(identity, Ref(x)) + end + i = findfirst(isT(Base.RefValue{String}), result.ir.stmts.type)::Int + r = findfirst(isreturn, result.ir.stmts.inst)::Int + @test has_return_escape(result.state.arguments[2], r) + @test !has_return_escape(result.state.ssavalues[i], r) end end @@ -548,11 +768,16 @@ end ft === typeof(EscapeAnalysis.escape_builtin!) && return false # `escape_builtin!` is very untyped, ignore return true end - - test_opt(only(methods(EscapeAnalysis.find_escapes)).sig; function_filter) + test_opt(only(methods(EscapeAnalysis.find_escapes)).sig; + function_filter, + # skip_nonconcrete_calls=false, + ) for m in methods(EscapeAnalysis.escape_builtin!) Base._methods_by_ftype(m.sig, 1, Base.get_world_counter()) === false && continue - test_opt(m.sig; function_filter) + test_opt(m.sig; + function_filter, + # skip_nonconcrete_calls=false, + ) end end