From 1aa60f4fc6b37cc2caa13e8b706f37613e8d947b Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Mon, 15 Nov 2021 19:59:19 +0900 Subject: [PATCH] a simple and flow-insensitive alias analysis MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit implements a simple, flow-insensitive alias analysis using an approach inspired by the escape analysis algorithm explained in the old JVM paper [^JVM05]. `EscapeLattice` is extended so that it also keeps track of possible field values. In more detail, `x::EscapeLattice` has the new field called `x.FieldSets::Union{Vector{IdSet{Any}},Bool}`, where: - `x.FieldSets === false` indicates the fields of `x` aren't analyzed yet - `x.FieldSets === true` indicates the fields of `x` can't be analyzed, e.g. the type of `x` is not concrete and thus the number of its fields can't be known precisely - otherwise `x.FieldSets::Vector{IdSet{Any}}` holds all the possible values of each field, where `x.FieldSets[i]` keeps all possibilities that the `i`th field can be And now, in addition to managing escape lattice elements, the analysis state also maintains an "alias set" `state.aliasset::IntDisjointSet{Int}`, which is implemented as a disjoint set of aliased arguments and SSA statements. When the fields of object `x` are known precisely (i.e. `x.FieldSets isa Vector{IdSet{Any}}` holds), the alias set is updated each time `z = getfield(x, y)` is encountered in a way that `z` is aliased to all values of `x.FieldSets[y]`, so that escape information imposed on `z` will be propagated to all the aliased values and `z` can be replaced with an aliased value later. Note that in a case when the fields of object `x` can't be known precisely (i.e. 
`x.FieldSets` is `true`), when `z = getfield(x, y)` is analyzed, escape information of `z` is propagated to `x` rather than any of `x`'s fields, which is the most conservative propagation since escape information imposed on `x` will end up being propagated to all of its fields anyway at definitions of `x` (i.e. `:new` expression or `setfield!` call). [^JVM05]: Escape Analysis in the Context of Dynamic Compilation and Deoptimization. Thomas Kotzmann and Hanspeter Mössenböck, 2005, June. . Now this alias analysis should allow us to implement a "stronger" SROA, which eliminates the allocation of `r` within the following code: ```julia julia> result = analyze_escapes((String,)) do s r = Ref(s) broadcast(identity, r) end \#3(_2::String *, _3::Base.RefValue{String} ◌) in Main at REPL[2]:2 2 ↓ 1 ─ %1 = %new(Base.RefValue{String}, _2)::Base.RefValue{String} │╻╷╷ Ref 3 ✓ │ %2 = Core.tuple(%1)::Tuple{Base.RefValue{String}} │╻ broadcast ↓ │ %3 = Core.getfield(%2, 1)::Base.RefValue{String} ││ ◌ └── goto #3 if not true ││╻╷ materialize ◌ 2 ─ nothing::Nothing │ * 3 ┄ %6 = Base.getfield(%3, :x)::String │││╻╷╷╷╷ copy ◌ └── goto #4 ││││┃ getindex ◌ 4 ─ goto #5 ││││ ◌ 5 ─ goto #6 │││ ◌ 6 ─ goto #7 ││ ◌ 7 ─ return %6 │ julia> EscapeAnalysis.get_aliases(result.state.aliasset, Core.SSAValue(6), result.ir) 2-element Vector{Union{Core.Argument, Core.SSAValue}}: Core.Argument(2) :(%6) ``` Note that the allocation `%1` isn't analyzed as `ReturnEscape`, whereas `_2` still is. 
--- README.md | 39 +++- src/EscapeAnalysis.jl | 414 +++++++++++++++++++++++++++++++++++------- src/disjoint_set.jl | 141 ++++++++++++++ test/runtests.jl | 130 ++++++++++--- 4 files changed, 630 insertions(+), 94 deletions(-) create mode 100644 src/disjoint_set.jl diff --git a/README.md b/README.md index 9ac4abe..3a149d7 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,12 @@ This analysis works on a lattice called `EscapeLattice`, which holds the followi simply because it's passed as call argument - `x.ThrownEscape::Bool`: indicates it may escape to somewhere through an exception (possibly as a field) - `x.GlobalEscape::Bool`: indicates it may escape to a global space an exception (possibly as a field) +- `x.FieldSets::Union{Vector{IdSet{Any}},Bool}`: maintains the sets of possible values of fields of `x`: + * `x.FieldSets === false` indicates the fields of `x` aren't analyzed yet + * `x.FieldSets === true` indicates the fields of `x` can't be analyzed, e.g. the type of `x` + is not concrete and thus the number of its fields can't be known precisely + * otherwise `x.FieldSets::Vector{IdSet{Any}}` holds all the possible values of each field, + where `x.FieldSets[i]` keeps all possibilities that the `i`th field can be - `x.ArgEscape::Int` (not implemented yet): indicates it will escape to the caller through `setfield!` on argument(s) * `-1` : no escape * `0` : unknown or multiple @@ -32,20 +38,33 @@ An abstract state will be initialized with the bottom(-like) elements: is slightly lower than `NoEscape`, but at the same time doesn't represent any meaning other than it's not analyzed yet (thus it's not formally part of the lattice). -Escape analysis implementation is based on the data-flow algorithm described in the paper [^MM02]. +Escape analysis implementation is based on the data-flow algorithm described in the old paper [^MM02]. 
The analysis works on the lattice of [`EscapeLattice`](@ref) and transitions lattice elements from the bottom to the top in a _backward_ way, i.e. data flows from usage cites to definitions, until every lattice gets converged to a fixed point by maintaining a (conceptual) working set that contains program counters corresponding to remaining SSA statements to be analyzed. -Note that the analysis only manages a single global state, with some flow-sensitivity -encoded as property of `EscapeLattice`. +The analysis only manages a single global state that tracks `EscapeLattice` of each argument +and SSA statement, but also note that some flow-sensitivity is encoded as properties of each +lattice element (like `ReturnEscape`). -[^MM02]: A Graph-Free approach to Data-Flow Analysis. +The analysis also collects alias information using an approach inspired by +the escape analysis algorithm explained in yet another old paper [^JVM05]. +In addition to managing escape lattice elements, the analysis state also maintains an "alias set", +which is implemented as a disjoint set of aliased arguments and SSA statements. +When the fields of object `x` are known precisely (i.e. `x.FieldSets isa Vector{IdSet{Any}}` holds), +the alias set is updated each time `z = getfield(x, y)` is encountered in a way that `z` is +aliased to all values of `x.FieldSets[y]`, so that escape information imposed on `z` will be +propagated to all the aliased values and `z` can be replaced with an aliased value later. +Note that in a case when the fields of object `x` can't be known precisely (i.e. `x.FieldSets` is `true`), +when `z = getfield(x, y)` is analyzed, escape information of `z` is propagated to `x` rather +than any of `x`'s fields, which is the most conservative propagation since escape information +imposed on `x` will end up being propagated to all of its fields anyway at definitions of `x` +(i.e. `:new` expression or `setfield!` call). 
+ +[^MM02]: _A Graph-Free approach to Data-Flow Analysis_. Markas Mohnen, 2002, April. - + . -TODO: -- [ ] implement more builtin function handlings, and make escape information more accurate -- [ ] make analysis take into account alias information -- [ ] implement `finalizer` elision optimization ([#17](https://github.com/aviatesk/EscapeAnalysis.jl/issues/17)) -- [ ] circumvent too conservative escapes through potential `throw` calls by copying stack-to-heap on exception ([#15](https://github.com/aviatesk/EscapeAnalysis.jl/issues/15)) +[^JVM05]: _Escape Analysis in the Context of Dynamic Compilation and Deoptimization_. + Thomas Kotzmann and Hanspeter Mössenböck, 2005, June. + . diff --git a/src/EscapeAnalysis.jl b/src/EscapeAnalysis.jl index c7febb9..26abe48 100644 --- a/src/EscapeAnalysis.jl +++ b/src/EscapeAnalysis.jl @@ -53,14 +53,18 @@ import .CC: argextype, singleton_type, IR_FLAG_EFFECT_FREE, - is_meta_expr_head + is_meta_expr_head, + fieldcount_noerror, + try_compute_fieldidx -import Base: == +import Base: ==, IdSet import Base.Meta: isexpr using InteractiveUtils +include("disjoint_set.jl") + let __init_hooks__ = [] global __init__() = foreach(f->f(), __init_hooks__) global register_init_hook!(@nospecialize(f)) = push!(__init_hooks__, f) @@ -153,6 +157,12 @@ A lattice for escape information, which holds the following properties: simply because it's passed as call argument - `x.ThrownEscape::Bool`: indicates it may escape to somewhere through an exception (possibly as a field) - `x.GlobalEscape::Bool`: indicates it may escape to a global space an exception (possibly as a field) +- `x.FieldSets::Union{Vector{IdSet{Any}},Bool}`: maintains the sets of possible values of fields of `x`: + * `x.FieldSets === false` indicates the fields of `x` isn't analyzed yet + * `x.FieldSets === true` indicates the fields of `x` can't be analyzed, e.g. 
the type of `x` + is not concrete and thus the number of its fields can't known precisely + * otherwise `x.FieldSets::Vector{IdSet{Any}}` holds all the possible values of each field, + where `x.FieldSets[i]` keeps all possibilities that the `i`th field can be - `x.ArgEscape::Int` (not implemented yet): indicates it will escape to the caller through `setfield!` on argument(s) * `-1` : no escape * `0` : unknown or multiple @@ -178,6 +188,7 @@ struct EscapeLattice ReturnEscape::BitSet ThrownEscape::Bool GlobalEscape::Bool + FieldSets::Union{Vector{IdSet{Any}},Bool} # TODO: ArgEscape::Int end @@ -185,18 +196,36 @@ end # precompute default values in order to eliminate computations at callsites const NO_RETURN = BitSet() const ARGUMENT_RETURN = BitSet(0) -NotAnalyzed() = EscapeLattice(false, NO_RETURN, false, false) # not formally part of the lattice -NoEscape() = EscapeLattice(true, NO_RETURN, false, false) -ReturnEscape(pcs::BitSet) = EscapeLattice(true, pcs, false, false) +NotAnalyzed() = EscapeLattice(false, NO_RETURN, false, false, false) # not formally part of the lattice +NoEscape() = EscapeLattice(true, NO_RETURN, false, false, false) +ReturnEscape(pcs::BitSet) = EscapeLattice(true, pcs, false, false, false) ReturnEscape(pc::Int) = ReturnEscape(BitSet(pc)) ArgumentReturnEscape() = ReturnEscape(ARGUMENT_RETURN) -ThrownEscape() = EscapeLattice(true, NO_RETURN, true, false) -GlobalEscape() = EscapeLattice(true, NO_RETURN, false, true) +ThrownEscape() = EscapeLattice(true, NO_RETURN, true, false, false) +GlobalEscape() = EscapeLattice(true, NO_RETURN, false, true, false) let all_return = BitSet(0:100_000) global AllReturnEscape() = ReturnEscape(all_return) # used for `show` - global AllEscape() = EscapeLattice(true, all_return, true, true) + global AllEscape() = EscapeLattice(true, all_return, true, true, true) +end + +# utility constructor +function EscapeLattice(x::EscapeLattice; + Analyzed::Bool = x.Analyzed, + ReturnEscape::BitSet = x.ReturnEscape, + 
ThrownEscape::Bool = x.ThrownEscape, + GlobalEscape::Bool = x.GlobalEscape, + FieldSets::Union{Vector{IdSet{Any}},Bool} = x.FieldSets, + ) + return EscapeLattice( + Analyzed, + ReturnEscape, + ThrownEscape, + GlobalEscape, + FieldSets, + ) end +ignore_fieldsets(info::EscapeLattice) = EscapeLattice(info; FieldSets = false) # Convenience names for some ⊑ queries export @@ -208,7 +237,7 @@ export has_all_escape, can_elide_finalizer has_not_analyzed(x::EscapeLattice) = x == NotAnalyzed() -has_no_escape(x::EscapeLattice) = x ⊑ NoEscape() +has_no_escape(x::EscapeLattice) = ignore_fieldsets(x) ⊑ NoEscape() has_return_escape(x::EscapeLattice) = !isempty(x.ReturnEscape) has_return_escape(x::EscapeLattice, pc::Int) = pc in x.ReturnEscape has_thrown_escape(x::EscapeLattice) = x.ThrownEscape @@ -232,6 +261,14 @@ end # we need to make sure this `==` operator corresponds to lattice equality rather than object equality, # otherwise `propagate_changes` can't detect the convergence x::EscapeLattice == y::EscapeLattice = begin + xf, yf = x.FieldSets, y.FieldSets + if isa(xf, Bool) + isa(yf, Bool) || return false + xf === yf || return false + else + isa(yf, Bool) && return false + xf == yf || return false + end return x.Analyzed === y.Analyzed && x.ReturnEscape == y.ReturnEscape && x.ThrownEscape === y.ThrownEscape && @@ -240,6 +277,21 @@ x::EscapeLattice == y::EscapeLattice = begin end x::EscapeLattice ⊑ y::EscapeLattice = begin + xf, yf = x.FieldSets, y.FieldSets + if isa(xf, Bool) + xf && yf !== true && return false + else + if isa(yf, Bool) + yf === false && return false + else + xf, yf = xf::Vector{IdSet{Any}}, yf::Vector{IdSet{Any}} + xn, yn = length(xf), length(yf) + xn > yn && return false + for i in 1:xn + xf[i] ⊆ yf[i] || return false + end + end + end if x.Analyzed ≤ y.Analyzed && x.ReturnEscape ⊆ y.ReturnEscape && x.ThrownEscape ≤ y.ThrownEscape && @@ -253,11 +305,32 @@ x::EscapeLattice ⊏ y::EscapeLattice = x ⊑ y && !(y ⊑ x) x::EscapeLattice ⋤ y::EscapeLattice = !(y ⊑ x) 
x::EscapeLattice ⊔ y::EscapeLattice = begin + xf, yf = x.FieldSets, y.FieldSets + if xf === true || yf === true + FieldSets = true + elseif xf === false + FieldSets = yf + elseif yf === false + FieldSets = xf + else + xf, yf = xf::Vector{IdSet{Any}}, yf::Vector{IdSet{Any}} + xn, yn = length(xf), length(yf) + nmax, nmin = max(xn, yn), min(xn, yn) + FieldSets = Vector{IdSet{Any}}(undef, nmax) + for i in 1:nmax + if i > nmax + FieldSets[i] = (xn > yn ? xf : yf)[i] + else + FieldSets[i] = xf[i] ∪ yf[i] + end + end + end return EscapeLattice( x.Analyzed | y.Analyzed, x.ReturnEscape ∪ y.ReturnEscape, x.ThrownEscape | y.ThrownEscape, x.GlobalEscape | y.GlobalEscape, + FieldSets, ) end @@ -267,6 +340,7 @@ x::EscapeLattice ⊓ y::EscapeLattice = begin x.ReturnEscape ∩ y.ReturnEscape, x.ThrownEscape & y.ThrownEscape, x.GlobalEscape & y.GlobalEscape, + x.FieldSets, # FIXME ) end @@ -280,39 +354,95 @@ Extended lattice that maps arguments and SSA values to escape information repres - `state.arguments::Vector{EscapeLattice}`: escape information about "arguments" – note that "argument" can include both call arguments and slots appearing in analysis frame - `ssavalues::Vector{EscapeLattice}`: escape information about each SSA value +- `aliaset::IntDisjointSet{Int}`: a disjoint set that represents aliased arguments and SSA values """ struct EscapeState arguments::Vector{EscapeLattice} ssavalues::Vector{EscapeLattice} + aliasset::IntDisjointSet{Int} end function EscapeState(nslots::Int, nargs::Int, nstmts::Int) arguments = EscapeLattice[ 1 ≤ i ≤ nargs ? 
ArgumentReturnEscape() : NotAnalyzed() for i in 1:nslots] ssavalues = EscapeLattice[NotAnalyzed() for _ in 1:nstmts] - return EscapeState(arguments, ssavalues) + aliaset = AliasSet(nslots+nstmts) + return EscapeState(arguments, ssavalues, aliaset) end # we preserve `IRCode` as well just for debugging purpose const GLOBAL_ESCAPE_CACHE = IdDict{MethodInstance,Tuple{EscapeState,IRCode}}() __clear_escape_cache!() = empty!(GLOBAL_ESCAPE_CACHE) -const Change = Pair{Union{Argument,SSAValue},EscapeLattice} -const Changes = Vector{Change} +const EscapeChange = Pair{Union{Argument,SSAValue},EscapeLattice} +const AliasChange = Pair{Int,Int} +const Changes = Vector{Union{EscapeChange,AliasChange}} + +const AliasSet = IntDisjointSet{Int} +function alias_index(@nospecialize(x), ir::IRCode) + if isa(x, Argument) + return x.n + elseif isa(x, SSAValue) + return x.id + length(ir.argtypes) + else + return nothing + end +end +function alias_val(idx::Int, ir::IRCode) + n = length(ir.argtypes) + return idx > n ? SSAValue(idx-n) : Argument(idx) +end +function get_aliases(aliasset::AliasSet, @nospecialize(key), ir::IRCode) + idx = alias_index(key, ir) + idx === nothing && return nothing + root = find_root!(aliasset, idx) + if idx ≠ root || aliasset.ranks[idx] > 0 + # the size of this alias set containing `key` is larger than 1, + # collect the entire alias set + aliases = Union{Argument,SSAValue}[] + for i in 1:length(aliasset.parents) + if aliasset.parents[i] == root + push!(aliases, alias_val(i, ir)) + end + end + return aliases + else + return nothing + end +end """ find_escapes(ir::IRCode, nargs::Int) -> EscapeState -Escape analysis implementation is based on the data-flow algorithm described in the paper [^MM02]. +Escape analysis implementation is based on the data-flow algorithm described in the old paper [^MM02]. The analysis works on the lattice of [`EscapeLattice`](@ref) and transitions lattice elements from the bottom to the top in a _backward_ way, i.e. 
data flows from usage cites to definitions, until every lattice gets converged to a fixed point by maintaining a (conceptual) working set that contains program counters corresponding to remaining SSA statements to be analyzed. -Note that the analysis only manages a single global state, with some flow-sensitivity -encoded as property of `EscapeLattice`. +The analysis only manages a single global state that tracks `EscapeLattice` of each argument +and SSA statement, but also note that some flow-sensitivity is encoded as properties of each +lattice element (like `ReturnEscape`). + +The analysis also collects alias information using an approach, which is inspired by +the escape analysis algorithm explained in yet another old paper [^JVM05]. +In addition to managing escape lattice elements, the analysis state also maintains an "alias set", +which is implemented as a disjoint set of aliased arguments and SSA statements. +When the fields of object `x` are known precisely (i.e. `x.FieldSets isa Vector{IdSet{Any}}` holds), +the alias set is updated each time `z = getfield(x, y)` is encountered in a way that `z` is +aliased to all values of `x.FieldSets[y]`, so that escape information imposed on `z` will be +propagated to all the aliased values and `z` can be replaced with an aliased value later. +Note that in a case when the fields of object `x` can't known precisely (i.e. `x.FieldSets` is `true`), +when `z = getfield(x, y)` is analyzed, escape information of `z` is propagated to `x` rather +than any of `x`'s fields, which is the most conservative propagation since escape information +imposed on `x` will end up being propagated to all of its fields anyway at definitions of `x` +(i.e. `:new` expression or `setfield!` call). [^MM02]: A Graph-Free approach to Data-Flow Analysis. Markas Mohnen, 2002, April. + +[^JVM05]: Escape Analysis in the Context of Dynamic Compilation and Deoptimization. + Thomas Kotzmann and Hanspeter Mössenböck, 2005, June. + . 
""" function find_escapes(ir::IRCode, nargs::Int) (; stmts, sptypes, argtypes) = ir @@ -346,10 +476,8 @@ function find_escapes(ir::IRCode, nargs::Int) end elseif head === :invoke escape_invoke!(stmt.args, pc, state, ir, changes) - elseif head === :new + elseif head === :new || head === :splatnew escape_new!(stmt.args, pc, state, ir, changes) - elseif head === :splatnew - escape_new!(stmt.args, pc, state, ir, changes, true) elseif head === :(=) lhs, rhs = stmt.args if isa(lhs, GlobalRef) # global store @@ -443,7 +571,7 @@ function find_escapes(ir::IRCode, nargs::Int) isempty(changes) && continue - anyupdate |= propagate_changes!(state, changes) + anyupdate |= propagate_changes!(state, changes, ir) empty!(changes) end @@ -455,47 +583,76 @@ function find_escapes(ir::IRCode, nargs::Int) end # propagate changes, and check convergence -function propagate_changes!(state::EscapeState, changes::Changes) +function propagate_changes!(state::EscapeState, changes::Changes, ir::IRCode) local anychanged = false - for (x, info) in changes - if isa(x, Argument) - old = state.arguments[x.n] - new = old ⊔ info - if old ≠ new - state.arguments[x.n] = new - anychanged |= true + for change in changes + if isa(change, EscapeChange) + anychanged |= propagate_escape_change!(state, change) + x, info = change + aliases = get_aliases(state.aliasset, x, ir) + if aliases !== nothing + for alias in aliases + morechange = EscapeChange(alias, info) + anychanged |= propagate_escape_change!(state, morechange) + end end else - x = x::SSAValue - old = state.ssavalues[x.id] - new = old ⊔ info - if old ≠ new - state.ssavalues[x.id] = new - anychanged |= true - end + anychanged |= propagate_alias_change!(state, change) end end return anychanged end -# function normalize(@nospecialize(x)) -# if isa(x, QuoteNode) -# return x.value -# else -# return x -# end -# end +function propagate_escape_change!(state::EscapeState, change::EscapeChange) + x, info = change + if isa(x, Argument) + old = 
state.arguments[x.n] + new = old ⊔ info + if old ≠ new + state.arguments[x.n] = new + return true + end + else + x = x::SSAValue + old = state.ssavalues[x.id] + new = old ⊔ info + if old ≠ new + state.ssavalues[x.id] = new + return true + end + end + return false +end + +function propagate_alias_change!(state::EscapeState, change::AliasChange) + x, y = change + xroot = find_root!(state.aliasset, x) + yroot = find_root!(state.aliasset, y) + if xroot ≠ yroot + union!(state.aliasset, xroot, yroot) + return true + end + return false +end function add_change!(@nospecialize(x), ir::IRCode, info::EscapeLattice, changes::Changes) if isa(x, Argument) || isa(x, SSAValue) if !isbitstype(widenconst(argextype(x, ir, ir.sptypes, ir.argtypes))) - push!(changes, Change(x, info)) + push!(changes, EscapeChange(x, info)) end end end +function add_alias!(@nospecialize(x), @nospecialize(y), ir::IRCode, changes::Changes) + xidx = alias_index(x, ir) + yidx = alias_index(y, ir) + if xidx !== nothing && yidx !== nothing + push!(changes, AliasChange(xidx, yidx)) + end +end + function escape_invoke!(args::Vector{Any}, pc::Int, state::EscapeState, ir::IRCode, changes::Changes) linfo = first(args)::MethodInstance @@ -531,7 +688,8 @@ end # context of the caller frame using the escape information imposed on the return value (`retinfo`) function from_interprocedural(arginfo::EscapeLattice, retinfo::EscapeLattice) ar = arginfo.ReturnEscape - newarginfo = EscapeLattice(true, NO_RETURN, arginfo.ThrownEscape, arginfo.GlobalEscape) + newarginfo = EscapeLattice( + true, NO_RETURN, arginfo.ThrownEscape, arginfo.GlobalEscape, arginfo.FieldSets) # FIXME if ar == ARGUMENT_RETURN # if this is simply passed as the call argument, we can discard the `ReturnEscape` # information and just propagate the other escape information @@ -567,22 +725,38 @@ end function escape_new!(args::Vector{Any}, pc::Int, state::EscapeState, ir::IRCode, changes::Changes, - splat_new::Bool = false) + ) info = state.ssavalues[pc] if 
info == NotAnalyzed() info = NoEscape() - add_change!(SSAValue(pc), ir, info, changes) # we will be interested in if this allocation escapes or not end - # we need to propagate escape information of this object to its fields as well, - # since they can be accessed through the object - if splat_new - # splatnew passes field values using a single tuple (args[2]) - add_change!(args[2], ir, info, changes) - else + newinfo = add_fieldsets(info, ir.stmts[pc][:type], args) + add_change!(SSAValue(pc), ir, newinfo, changes) + + # propagate the escape information of this object to all its fields as well + for i in 2:length(args) + add_change!(args[i], ir, ignore_fieldsets(info), changes) + end +end + +function add_fieldsets(info::EscapeLattice, @nospecialize(typ), args::Vector{Any}) + FieldSets = info.FieldSets + if isa(FieldSets, Bool) && !FieldSets + nfields = fieldcount_noerror(typ) + if nfields === nothing + # abstract type, can't propagate + FieldSets = true + else + FieldSets = IdSet{Any}[IdSet{Any}() for _ in 1:nfields] + end + end + if !isa(FieldSets, Bool) + @assert fieldcount_noerror(typ) == length(FieldSets) for i in 2:length(args) - add_change!(args[i], ir, info, changes) + push!(FieldSets[i-1], args[i]) end end + return EscapeLattice(info; FieldSets) end # NOTE error cases will be handled in `find_escapes` anyway, so we don't need to take care of them below @@ -626,22 +800,140 @@ function escape_builtin!(::typeof(tuple), args::Vector{Any}, pc::Int, state::Esc if info == NotAnalyzed() info = NoEscape() end + tupleinfo = add_fieldsets(info, ir.stmts[pc][:type], args) + add_change!(SSAValue(pc), ir, tupleinfo, changes) + # propagate the escape information of this object to all its fields as well, for i in 2:length(args) - add_change!(args[i], ir, info, changes) + add_change!(args[i], ir, ignore_fieldsets(info), changes) end end -# TODO don't propagate escape information to the 1st argument, but propagate information to aliased field function 
escape_builtin!(::typeof(getfield), args::Vector{Any}, pc::Int, state::EscapeState, ir::IRCode, changes::Changes) # only propagate info when the field itself is non-bitstype isbitstype(widenconst(ir.stmts.type[pc])) && return true + length(args) ≥ 3 || return info = state.ssavalues[pc] if info == NotAnalyzed() info = NoEscape() end - for i in 2:length(args) - add_change!(args[i], ir, info, changes) + + obj = args[2] + if isa(obj, SSAValue) + FieldSets = state.ssavalues[obj.id].FieldSets + elseif isa(obj, Argument) + FieldSets = state.arguments[obj.n].FieldSets + else + return + end + if isa(FieldSets, Bool) + if FieldSets + # the field can't be analyzed, escape the object itself (including all its fields) conservatively + add_change!(obj, ir, info, changes) + else + # this field hasn't been analyzed yet, do nothing for now + end + else + typ = argextype(obj, ir, ir.sptypes, ir.argtypes) + if isa(typ, DataType) + fld = args[3] + fldval = try_compute_fieldval(ir, fld) + fidx = try_compute_fieldidx(typ, fldval) + else + fidx = nothing + end + if fidx !== nothing + for x in FieldSets[fidx] + add_change!(x, ir, info, changes) + add_alias!(x, SSAValue(pc), ir, changes) + end + else + # when the field can't be known precisely, + # propagate this escape information to all the fields conservatively + for FieldSet in FieldSets, x in FieldSet + add_change!(x, ir, info, changes) + add_alias!(x, SSAValue(pc), ir, changes) + end + end + end +end + +function try_compute_fieldval(ir::IRCode, @nospecialize fld) + # fields are usually literals, handle them manually + if isa(fld, QuoteNode) + fld = fld.value + elseif isa(fld, Int) + else + # try to resolve other constants, e.g. global reference + fld = argextype(fld, ir, ir.sptypes, ir.argtypes) + if isa(fld, Const) + fld = fld.val + else + return nothing + end + end + return isa(fld, Union{Int, Symbol}) ? 
fld : nothing +end + +function escape_builtin!(::typeof(setfield!), args::Vector{Any}, pc::Int, state::EscapeState, ir::IRCode, changes::Changes) + # only propagate info when the field itself is non-bitstype + isbitstype(widenconst(ir.stmts.type[pc])) && return + length(args) ≥ 4 || return + + obj, fld, val = args[2:4] + if isa(obj, SSAValue) + objinfo = state.ssavalues[obj.id] + elseif isa(obj, Argument) + objinfo = state.arguments[obj.n] + else + # TODO add_change!(val, ir, AllEscape(), changes) ? + @goto add_ssa_escape + end + FieldSets = objinfo.FieldSets + typ = argextype(obj, ir, ir.sptypes, ir.argtypes) + if isa(FieldSets, Bool) + if FieldSets + # the field analysis on this object was already done and unsuccessful, + # nothing can't be done here + else + nfields = fieldcount_noerror(typ) + if nfields !== nothing + FieldSets = IdSet{Any}[IdSet{Any}() for _ in 1:nfields] + @goto add_field + else + # fields aren't known precisely + add_change!(obj, ir, EscapeLattice(objinfo; FieldSets=true), changes) + end + end + else + @label add_field # the field sets have been initialized, now add the alias information + if isa(typ, DataType) + fldval = try_compute_fieldval(ir, fld) + fidx = try_compute_fieldidx(typ, fldval) + else + fidx = nothing + end + if fidx !== nothing + push!(FieldSets[fidx], val) + else + # when the field can't be known precisely, + # add this alias information to all the field sets conservatively + for FieldSet in FieldSets + push!(FieldSet, val) + end + end + # update `obj`'s escape information with the new field sets + add_change!(obj, ir, EscapeLattice(objinfo; FieldSets), changes) + end + # propagate `obj`'s escape information to `val` as well + add_change!(val, ir, ignore_fieldsets(objinfo), changes) + + # propagate escape information imposed on the return value of this `setfield!` + @label add_ssa_escape + ssainfo = state.ssavalues[pc] + if ssainfo == NotAnalyzed() + ssainfo = NoEscape() end + add_change!(val, ir, ssainfo, changes) end # 
entries @@ -714,18 +1006,18 @@ __clear_caches!() = (__clear_code_cache!(); __clear_escape_cache!()) function get_name_color(x::EscapeLattice, symbol::Bool = false) getname(x) = string(nameof(x)) - if x == NotAnalyzed() + if ignore_fieldsets(x) == NotAnalyzed() name, color = (getname(NotAnalyzed), '◌'), :plain - elseif x == NoEscape() + elseif ignore_fieldsets(x) == NoEscape() name, color = (getname(NoEscape), '✓'), :green - elseif NoEscape() ⊏ x ⊑ AllReturnEscape() + elseif NoEscape() ⊏ ignore_fieldsets(x) ⊑ AllReturnEscape() pcs = sprint(show, collect(x.ReturnEscape); context=:limit=>true) name1 = string(getname(ReturnEscape), '(', pcs, ')') name = name1, '↑' color = :cyan - elseif NoEscape() ⊏ x ⊑ ThrownEscape() + elseif NoEscape() ⊏ ignore_fieldsets(x) ⊑ ThrownEscape() name, color = (getname(ThrownEscape), '↓'), :yellow - elseif NoEscape() ⊏ x ⊑ GlobalEscape() + elseif NoEscape() ⊏ ignore_fieldsets(x) ⊑ GlobalEscape() name, color = (getname(GlobalEscape), 'G'), :red elseif x == AllEscape() name, color = (getname(AllEscape), 'X'), :red diff --git a/src/disjoint_set.jl b/src/disjoint_set.jl new file mode 100644 index 0000000..1f99685 --- /dev/null +++ b/src/disjoint_set.jl @@ -0,0 +1,141 @@ +# a disjoint set implementation +# adapted from https://github.com/JuliaCollections/DataStructures.jl/blob/f57330a3b46f779b261e6c07f199c88936f28839/src/disjoint_set.jl +# under the MIT license: https://github.com/JuliaCollections/DataStructures.jl/blob/master/License.md + +import Base: + length, + eltype, + union!, + push! 
+ +using Base: OneTo + +# Disjoint-Set + +############################################################ +# +# A forest of disjoint sets of integers +# +# Since each element is an integer, we can use arrays +# instead of dictionary (for efficiency) +# +# Disjoint sets over other key types can be implemented +# based on an IntDisjointSet through a map from the key +# to an integer index +# +############################################################ + +_intdisjointset_bounds_err_msg(T) = "the maximum number of elements in IntDisjointSet{$T} is $(typemax(T))" + +""" + IntDisjointSet{T<:Integer}(n::Integer) + +A forest of disjoint sets of integers, which is a data structure +(also called a union–find data structure or merge–find set) +that tracks a set of elements partitioned +into a number of disjoint (non-overlapping) subsets. +""" +mutable struct IntDisjointSet{T<:Integer} + parents::Vector{T} + ranks::Vector{T} + ngroups::T +end + +IntDisjointSet(n::T) where {T<:Integer} = IntDisjointSet{T}(collect(OneTo(n)), zeros(T, n), n) +IntDisjointSet{T}(n::Integer) where {T<:Integer} = IntDisjointSet{T}(collect(OneTo(T(n))), zeros(T, T(n)), T(n)) +length(s::IntDisjointSet) = length(s.parents) + +""" + num_groups(s::IntDisjointSet) + +Get a number of groups. 
+""" +num_groups(s::IntDisjointSet) = s.ngroups +eltype(::Type{IntDisjointSet{T}}) where {T<:Integer} = T + +# find the root element of the subset that contains x +# path compression is implemented here +function find_root_impl!(parents::Vector{T}, x::Integer) where {T<:Integer} + p = parents[x] + @inbounds if parents[p] != p + parents[x] = p = _find_root_impl!(parents, p) + end + return p +end + +# unsafe version of the above +function _find_root_impl!(parents::Vector{T}, x::Integer) where {T<:Integer} + @inbounds p = parents[x] + @inbounds if parents[p] != p + parents[x] = p = _find_root_impl!(parents, p) + end + return p +end + +""" + find_root!(s::IntDisjointSet{T}, x::T) + +Find the root element of the subset that contains an member `x`. +Path compression happens here. +""" +find_root!(s::IntDisjointSet{T}, x::T) where {T<:Integer} = find_root_impl!(s.parents, x) + +""" + in_same_set(s::IntDisjointSet{T}, x::T, y::T) + +Returns `true` if `x` and `y` belong to the same subset in `s`, and `false` otherwise. +""" +in_same_set(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer} = find_root!(s, x) == find_root!(s, y) + +""" + union!(s::IntDisjointSet{T}, x::T, y::T) + +Merge the subset containing `x` and that containing `y` into one +and return the root of the new set. +""" +function union!(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer} + parents = s.parents + xroot = find_root_impl!(parents, x) + yroot = find_root_impl!(parents, y) + return xroot != yroot ? root_union!(s, xroot, yroot) : xroot +end + +""" + root_union!(s::IntDisjointSet{T}, x::T, y::T) + +Form a new set that is the union of the two sets whose root elements are +`x` and `y` and return the root of the new set. +Assume `x ≠ y` (unsafe). 
+""" +function root_union!(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer} + parents = s.parents + rks = s.ranks + @inbounds xrank = rks[x] + @inbounds yrank = rks[y] + + if xrank < yrank + x, y = y, x # make `x` the root with the higher rank + elseif xrank == yrank + rks[x] += one(T) + end + @inbounds parents[y] = x + s.ngroups -= one(T) + return x +end + +""" + push!(s::IntDisjointSet{T}) + +Make a new subset with an automatically chosen new element `x`. +Returns the new element. Throws an `ArgumentError` if the +capacity of the set would be exceeded. +""" +function push!(s::IntDisjointSet{T}) where {T<:Integer} + l = length(s) + l < typemax(T) || throw(ArgumentError(_intdisjointset_bounds_err_msg(T))) + x = l + one(T) + push!(s.parents, x) + push!(s.ranks, zero(T)) + s.ngroups += one(T) + return x +end diff --git a/test/runtests.jl b/test/runtests.jl index bc21f2a..81b591f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,11 @@ mutable struct MutableCondition cond::Bool end +mutable struct MutableFields{S,T} + field1::S + field2::T +end + @testset "EscapeAnalysis" begin @testset "basics" begin @@ -375,6 +380,31 @@ end end @testset "field analysis" begin + let # escape object => escape its fields + result = analyze_escapes((Any,)) do a + global o = MutableSome{Any}(a) + nothing + end + i = findfirst(==(MutableSome{Any}), result.ir.stmts.type) + @assert !isnothing(i) + @test has_global_escape(result.state.ssavalues[i]) + @test has_global_escape(result.state.arguments[2]) + end + + let # escape object => escape its fields (transitively, through a nested object) + result = analyze_escapes((Any,)) do a + o0 = MutableSome{Any}(a) + global o = MutableSome(o0) + nothing + end + i0 = findfirst(==(MutableSome{Any}), result.ir.stmts.type) + i1 = findfirst(==(MutableSome{MutableSome{Any}}), result.ir.stmts.type) + @assert !isnothing(i0) && !isnothing(i1) + @test has_global_escape(result.state.ssavalues[i0]) + @test has_global_escape(result.state.ssavalues[i1]) + @test has_global_escape(result.state.arguments[2]) + end + + let result =
analyze_escapes((String,)) do a # => ReturnEscape o = MutableSome(a) # no need to escape @@ -385,7 +415,7 @@ end r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst) @assert !isnothing(i) && !isnothing(r) @test has_return_escape(result.state.arguments[2], r) - @test_broken !has_return_escape(result.state.ssavalues[i], r) + @test !has_return_escape(result.state.ssavalues[i], r) end let @@ -400,10 +430,10 @@ end @assert !isnothing(i) && !isnothing(r) @test has_return_escape(result.state.arguments[2], r) @test has_global_escape(result.state.ssavalues[i]) - @test_broken !has_return_escape(result.state.ssavalues[i], r) + @test !has_return_escape(result.state.ssavalues[i], r) end - let + let # nested objects (not aliased) result = analyze_escapes((String,)) do a # => ReturnEscape o1 = MutableSome(a) # => ReturnEscape o2 = MutableSome(o1) # no need to escape @@ -414,36 +444,91 @@ end r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst) @assert !isnothing(i1) && !isnothing(i2) && !isnothing(4) @test has_return_escape(result.state.arguments[2], r) - @test has_return_escape(result.state.ssavalues[i1]) - @test_broken !has_return_escape(result.state.ssavalues[i2]) + @test has_return_escape(result.state.ssavalues[i1], r) + @test !has_return_escape(result.state.ssavalues[i2], r) end - let + let # multiple fields: both field values are returned, the object itself is not + result = analyze_escapes((String, String)) do a, b # => ReturnEscape, ReturnEscape + obj = MutableFields(a, b) # => NoEscape + fld1 = obj.field1 + fld2 = obj.field2 + return (fld1, fld2) + end + i = findfirst(==(MutableFields{String,String}), result.ir.stmts.type) + r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst) + @assert !isnothing(i) && !isnothing(r) + @test has_return_escape(result.state.arguments[2], r) # a + @test has_return_escape(result.state.arguments[3], r) # b + @test !has_return_escape(result.state.ssavalues[i], r) # obj + end + + let # tuple: the element is returned (checked as return escape below), the tuple itself is not result = analyze_escapes((Any,)) do a # => GlobalEscape t = tuple(a) # no need to
escape - global tt = t[1] - return nothing + return t[1] end - i = findfirst(t->t<:Tuple, result.ir.stmts.type) # allocation statement - @assert !isnothing(i) - @test has_global_escape(result.state.arguments[2]) - @test_broken !has_global_escape(result.state.ssavalues[i]) + i = findfirst(t->t<:Tuple, result.ir.stmts.type) # allocation statement + r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst) + @assert !isnothing(i) && !isnothing(r) + @test has_return_escape(result.state.arguments[2], r) + @test !has_return_escape(result.state.ssavalues[i], r) + end + + let # TODO inter-procedural conversion (the field is read inside the `@noinline`d callee `getvalue`) + m = Module() + @eval m @noinline getvalue(obj) = obj.value + result = @eval m $analyze_escapes((String,)) do a # => ReturnEscape + obj = $MutableSome(a) # no need to escape + fld = getvalue(obj) + return fld + end + i = findfirst(==(MutableSome{String}), result.ir.stmts.type) + r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst) + @assert !isnothing(i) && !isnothing(r) + @test has_return_escape(result.state.arguments[2], r) + @test !has_return_escape(result.state.ssavalues[i], r) + end + + let # `popfirst!(InvasiveLinkedList{Task})` within this `println` used to cause an infinite loop ...
+ result = analyze_escapes((String,)) do a + println(a) + nothing + end + @test true # reaching this line means `analyze_escapes` returned end end -# demonstrate a simple type level analysis can sometimes compensate the lack of yet unimplemented analysis -@testset "special-casing bitstype" begin - let # lack of field analysis - result = analyze_escapes((Int,)) do a - o = MutableSome(a) # no need to escape - f = getfield(o, :value) - return f +@testset "alias analysis" begin + let result = analyze_escapes((String,)) do a # => ReturnEscape + o1 = MutableSome(a) # => NoEscape + o2 = MutableSome(o1) # => NoEscape + o1′ = getfield(o2, :value) # reads back o1 + a′ = getfield(o1′, :value) # reads back a + return a′ end - i = findfirst(==(MutableSome{Int}), result.ir.stmts.type) # allocation statement - @assert !isnothing(i) - @test has_no_escape(result.state.ssavalues[i]) + i1 = findfirst(==(MutableSome{String}), result.ir.stmts.type) + i2 = findfirst(==(MutableSome{MutableSome{String}}), result.ir.stmts.type) + r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst) + @assert !isnothing(i1) && !isnothing(i2) && !isnothing(r) + @test has_return_escape(result.state.arguments[2], r) + @test !has_return_escape(result.state.ssavalues[i1], r) + @test !has_return_escape(result.state.ssavalues[i2], r) + end + + let result = analyze_escapes((String,)) do x # => ReturnEscape + broadcast(identity, Ref(x)) # the `Ref` allocation is expected not to be return-escaped + end + i = findfirst(==(Base.RefValue{String}), result.ir.stmts.type) + r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst) + @assert !isnothing(i) && !isnothing(r) + @test has_return_escape(result.state.arguments[2], r) + @test !has_return_escape(result.state.ssavalues[i], r) end +end +# demonstrate that a simple type-level analysis can sometimes compensate for the lack of yet-unimplemented analyses +@testset "special-casing bitstype" begin let # an escaped tuple stmt will not propagate to its Int argument (since Int is of bitstype) result = analyze_escapes((Int, Any, )) do a, b t = tuple(a, b) @@ -526,7 +611,6 @@ end function
function_filter(@nospecialize(ft)) ft === typeof(Core.Compiler.widenconst) && return false # `widenconst` is very untyped, ignore ft === typeof(EscapeAnalysis.escape_builtin!) && return false # `escape_builtin!` is very untyped, ignore - ft === typeof(isbitstype) && return false # `isbitstype` is very untyped, ignore return true end