From 9de71cbbfcc4d20a7ef7554e2200c08b0a60b3bb Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Tue, 5 Oct 2021 01:53:30 +0900 Subject: [PATCH] field analysis --- src/EscapeAnalysis.jl | 116 +++++++++++++++++++++++++++++++++++------- test/runtests.jl | 92 ++++++++++++++++++++------------- 2 files changed, 156 insertions(+), 52 deletions(-) diff --git a/src/EscapeAnalysis.jl b/src/EscapeAnalysis.jl index 4fa6ddf..50a04eb 100644 --- a/src/EscapeAnalysis.jl +++ b/src/EscapeAnalysis.jl @@ -180,36 +180,42 @@ struct EscapeLattice ThrownEscape::Bool GlobalEscape::Bool # TODO: ArgEscape::Int + FieldEscapes::Vector{EscapeLattice} end function Base.:(==)(x::EscapeLattice, y::EscapeLattice) return x.Analyzed === y.Analyzed && x.ReturnEscape == y.ReturnEscape && x.ThrownEscape === y.ThrownEscape && - x.GlobalEscape === y.GlobalEscape + x.GlobalEscape === y.GlobalEscape && + x.FieldEscapes == y.FieldEscapes end # lattice constructors # precompute default values in order to eliminate computations at callsites const NO_RETURN = BitSet() const ARGUMENT_RETURN = BitSet(0) -NotAnalyzed() = EscapeLattice(false, NO_RETURN, false, false) # not formally part of the lattice -NoEscape() = EscapeLattice(true, NO_RETURN, false, false) -ReturnEscape(returns::BitSet) = EscapeLattice(true, returns, false, false) +const NO_FIELD_ESCAPE = EscapeLattice[] +NotAnalyzed() = EscapeLattice(false, NO_RETURN, false, false, NO_FIELD_ESCAPE) # not formally part of the lattice +NoEscape() = EscapeLattice(true, NO_RETURN, false, false, NO_FIELD_ESCAPE) +ReturnEscape(returns::BitSet) = EscapeLattice(true, returns, false, false, NO_FIELD_ESCAPE) ReturnEscape(pc::Int) = ReturnEscape(BitSet(pc)) ArgumentReturnEscape() = ReturnEscape(ARGUMENT_RETURN) -ThrownEscape() = EscapeLattice(true, NO_RETURN, true, false) -GlobalEscape() = EscapeLattice(true, NO_RETURN, false, true) +ThrownEscape() = EscapeLattice(true, NO_RETURN, true, false, NO_FIELD_ESCAPE) +GlobalEscape() = EscapeLattice(true, NO_RETURN, false, true, NO_FIELD_ESCAPE) +FieldEscapes(finfos::Vector{EscapeLattice}) = EscapeLattice(true, NO_RETURN, false, false, finfos) +NoFieldEscapes(info::EscapeLattice) = EscapeLattice(info.Analyzed, info.ReturnEscape, info.ThrownEscape, info.GlobalEscape, NO_FIELD_ESCAPE) let all_return = BitSet(0:1000000) global AllReturnEscape() = ReturnEscape(all_return) # used for `show` - global AllEscape() = EscapeLattice(true, all_return, true, true) + global AllEscape() = EscapeLattice(true, all_return, true, true, NO_FIELD_ESCAPE) end # Convenience names for some ⊑ queries export has_not_analyzed, has_no_escape, + has_no_escape′, has_return_escape, has_thrown_escape, has_global_escape, @@ -217,11 +223,12 @@ export can_elide_finalizer has_not_analyzed(x::EscapeLattice) = x == NotAnalyzed() has_no_escape(x::EscapeLattice) = x ⊑ NoEscape() +has_no_escape′(x::EscapeLattice) = NoFieldEscapes(x) ⊑ NoEscape() has_return_escape(x::EscapeLattice) = !isempty(x.ReturnEscape) has_return_escape(x::EscapeLattice, pc::Int) = pc in x.ReturnEscape has_thrown_escape(x::EscapeLattice) = x.ThrownEscape has_global_escape(x::EscapeLattice) = x.GlobalEscape -has_all_escape(x::EscapeLattice) = AllEscape() == x +has_all_escape(x::EscapeLattice) = AllEscape() ⊑ x """ can_elide_finalizer(x::EscapeLattice, pc::Int) -> Bool @@ -238,6 +245,11 @@ function can_elide_finalizer(x::EscapeLattice, pc::Int) end function ⊑(x::EscapeLattice, y::EscapeLattice) + xf, yf = x.FieldEscapes, y.FieldEscapes + length(xf) ≤ length(yf) || return false + for (x′, y′) in zip(xf, yf) + x′ ⊑ y′ || return false + end if x.Analyzed ≤ y.Analyzed && x.ReturnEscape ⊆ y.ReturnEscape && x.ThrownEscape ≤ y.ThrownEscape && @@ -249,11 +261,30 @@ end ⋤(x::EscapeLattice, y::EscapeLattice) = ⊑(x, y) && !⊑(y, x) function ⊔(x::EscapeLattice, y::EscapeLattice) + xf, yf = x.FieldEscapes, y.FieldEscapes + if isempty(xf) # fast pass + FieldEscapes = yf + elseif isempty(yf) # fast pass + FieldEscapes = xf + else + xn, yn = length(xf), length(yf) + (sf, sn), (lf, ln) = xn ≤ yn ? ((xf, xn), (yf, yn)) : ((yf, yn), (xf, xn)) + FieldEscapes = Vector{EscapeLattice}(undef, ln) + for i in 1:ln + if i > sn + FieldEscape = lf[i] + else + FieldEscape = sf[i] ⊔ lf[i] + end + FieldEscapes[i] = FieldEscape + end + end return EscapeLattice( x.Analyzed | y.Analyzed, x.ReturnEscape ∪ y.ReturnEscape, x.ThrownEscape | y.ThrownEscape, x.GlobalEscape | y.GlobalEscape, + FieldEscapes, ) end @@ -345,7 +376,14 @@ function find_escapes(ir::IRCode, nargs::Int) info = NoEscape() add_change!(SSAValue(pc), ir, info, changes) # we will be interested in if this allocation escapes or not end - add_changes!(stmt.args[2:end], ir, info, changes) + args = stmt.args[2:end] + add_changes!(args, ir, NoFieldEscapes(info), changes) + n = min(length(args), length(info.FieldEscapes)) + for i in 1:n + arg = args[i] + finfo = info.FieldEscapes[i] + add_change!(arg, ir, finfo, changes) + end elseif head === :splatnew info = state.ssavalues[pc] if info == NotAnalyzed() @@ -532,7 +570,7 @@ end # context of the caller frame using the escape information imposed on the return value (`retinfo`) function from_interprocedural(arginfo::EscapeLattice, retinfo::EscapeLattice) ar = arginfo.ReturnEscape - newarginfo = EscapeLattice(true, NO_RETURN, arginfo.ThrownEscape, arginfo.GlobalEscape) + newarginfo = EscapeLattice(true, NO_RETURN, arginfo.ThrownEscape, arginfo.GlobalEscape, arginfo.FieldEscapes) if ar == ARGUMENT_RETURN # if this is simply passed as the call argument, we can discard the `ReturnEscape` # information and just propagate the other escape information @@ -593,21 +631,60 @@ function escape_builtin!(::typeof(tuple), args::Vector{Any}, pc::Int, state::Esc if info == NotAnalyzed() info = NoEscape() end - add_changes!(args[2:end], ir, info, changes) + args = args[2:end] + add_changes!(args, ir, info, changes) + n = min(length(args), length(info.FieldEscapes)) + for i in 1:n + arg = args[i] + finfo = info.FieldEscapes[i] + add_change!(arg, ir, finfo, changes) + end return true end -# TODO don't propagate escape information to the 1st argument, but propagate information to aliased field +function try_compute_fieldidx_args(ir::IRCode, args::Vector{Any}, typ::DataType) + field = args[3] + if isa(field, QuoteNode) + field = field.value + elseif isa(field, Int) + # try to resolve other constants, e.g. global reference + else + field = argextype(field, ir, ir.sptypes, ir.argtypes) + if isa(field, Const) + field = field.val + else + return nothing + end + end + isa(field, Union{Int, Symbol}) || return nothing + return CC.try_compute_fieldidx(typ, field) +end + function escape_builtin!(::typeof(getfield), args::Vector{Any}, pc::Int, state::EscapeState, ir::IRCode, changes::Changes) + # only propagate info when the field itself is non-bitstype + isbitstype(widenconst(ir.stmts.type[pc])) && return true info = state.ssavalues[pc] if info == NotAnalyzed() info = NoEscape() end - # only propagate info when the field itself is non-bitstype - if !isbitstype(widenconst(ir.stmts.type[pc])) - add_changes!(args[2:end], ir, info, changes) + if length(args) ≥ 2 + obj = args[2] + objt = widenconst(argextype(obj, ir, ir.sptypes, ir.argtypes)) + if isa(objt, DataType) + idx = try_compute_fieldidx_args(ir, args, objt) + if !isnothing(idx) + n = CC.fieldcount_noerror(objt) + if !isnothing(n) + newinfo = FieldEscapes(EscapeLattice[ + i == idx ? info : NoEscape() for i in 1:n]) + add_change!(obj, ir, newinfo, changes) + return true + end + end + end + add_change!(obj, ir, info, changes) end - return true + return true # may throw, but we will handle that later end # entries @@ -680,8 +757,10 @@ end # in order to run a whole analysis from ground zero (e.g. for benchmarking, etc.) __clear_caches!() = (__clear_code_cache!(); __clear_escape_cache!()) -function get_name_color(x::EscapeLattice, symbol::Bool = false) +function get_name_color(x::EscapeLattice, short::Bool = false) getname(x) = string(nameof(x)) + FieldEscapes = x.FieldEscapes + x = NoFieldEscapes(x) if x == NotAnalyzed() name, color = (getname(NotAnalyzed), '◌'), :plain elseif x == NoEscape() @@ -698,7 +777,8 @@ function get_name_color(x::EscapeLattice, symbol::Bool = false) else name, color = (nothing, '*'), :red end - return (symbol ? last(name) : first(name), color) + s = string((short ? last : first)(name), (isempty(FieldEscapes) ? ' ' : '′')) + return s, color end function Base.show(io::IO, x::EscapeLattice) diff --git a/test/runtests.jl b/test/runtests.jl index 5b9318d..a3554f2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,10 +3,16 @@ using EscapeAnalysis, InteractiveUtils, Test, JET mutable struct MutableSome{T} value::T end + mutable struct MutableCondition cond::Bool end +mutable struct MutableFields{S,T} + field1::S + field2::T +end + @testset "EscapeAnalysis" begin @testset "basics" begin @@ -205,7 +211,7 @@ end inds = findall(==(MutableSome{T}), result.ir.stmts.type) # find allocation statement @assert !isempty(inds) for i in inds - @test has_no_escape(result.state.ssavalues[i]) + @test has_no_escape′(result.state.ssavalues[i]) end end @@ -229,40 +235,54 @@ end return aaa end i = findfirst(==(Base.RefValue{String}), result.ir.stmts.type) # find allocation statement - @assert !isnothing(i) - @test has_return_escape(result.state.ssavalues[i]) + r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst) + @assert !isnothing(i) && !isnothing(r) + @test has_return_escape(result.state.ssavalues[i], r) end # should propagate escape information imposed on return value to the aliased call argument - @eval m @noinline function f_return_escape(a) - println("hi") # prevent inlining - return a - end + @eval m @noinline f_return_escape(a) = a let result = @eval m $analyze_escapes() do obj = Ref("foo") # should be "return escape" - ret = f_return_escape(obj) + ret = @noinline f_return_escape(obj) return ret # alias of `obj` end i = findfirst(==(Base.RefValue{String}), result.ir.stmts.type) # find allocation statement - @assert !isnothing(i) - @test has_return_escape(result.state.ssavalues[i]) + r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst) + @assert !isnothing(i) && !isnothing(r) + @test has_return_escape(result.state.ssavalues[i], r) end - @eval m @noinline function f_no_return_escape(a) - println("hi") # prevent inlining - return "hi" - end + @eval m @noinline f_no_return_escape(a) = identity("hi") let result = @eval m $analyze_escapes() do obj = Ref("foo") # better to not be "return escape" - ret = f_no_return_escape(obj) + ret = @noinline f_no_return_escape(obj) return ret # must not alias to `obj` end i = findfirst(==(Base.RefValue{String}), result.ir.stmts.type) # find allocation statement - @assert !isnothing(i) - @test !has_return_escape(result.state.ssavalues[i]) - end + r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst) + @assert !isnothing(i) && !isnothing(r) + @test !has_return_escape(result.state.ssavalues[i], r) + end + + # # FIXME! `println` causes infinite loop ... + # @eval m @noinline f_no_return_escape2(a) = begin + # println("hi") + # identity("hi") + # end + # let + # result = @eval m $analyze_escapes() do + # obj = Ref("foo") # better to not be "return escape" + # ret = @noinline f_no_return_escape2(obj) + # return ret # must not alias to `obj` + # end + # i = findfirst(==(Base.RefValue{String}), result.ir.stmts.type) # find allocation statement + # r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst) + # @assert !isnothing(i) && !isnothing(r) + # @test !has_return_escape(result.state.ssavalues[i], r) + # end end @testset "builtins" begin @@ -376,7 +396,7 @@ end r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst) @assert !isnothing(i) && !isnothing(r) @test has_return_escape(result.state.arguments[2], r) - @test_broken !has_return_escape(result.state.ssavalues[i], r) + @test !has_return_escape(result.state.ssavalues[i], r) end let @@ -391,7 +411,7 @@ end @assert !isnothing(i) && !isnothing(r) @test has_return_escape(result.state.arguments[2], r) @test has_global_escape(result.state.ssavalues[i]) - @test_broken !has_return_escape(result.state.ssavalues[i], r) + @test !has_return_escape(result.state.ssavalues[i], r) end let @@ -405,8 +425,23 @@ end r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst) @assert !isnothing(i1) && !isnothing(i2) && !isnothing(4) @test has_return_escape(result.state.arguments[2], r) - @test has_return_escape(result.state.ssavalues[i1]) - @test_broken !has_return_escape(result.state.ssavalues[i2]) + @test has_return_escape(result.state.ssavalues[i1], r) + @test !has_return_escape(result.state.ssavalues[i2], r) + end + + let # multiple fields + result = analyze_escapes((String, String)) do a, b # => ReturnEscape, ReturnEscape + obj = MutableFields(a, b) # => NoEscape + fld1 = obj.field1 + fld2 = obj.field2 + return (fld1, fld2) + end + i = findfirst(==(MutableFields{String,String}), result.ir.stmts.type) + r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst) + @assert !isnothing(i) && !isnothing(r) + @test has_return_escape(result.state.arguments[2], r) # a + @test has_return_escape(result.state.arguments[3], r) # b + @test !has_return_escape(result.state.ssavalues[i], r) end let @@ -418,23 +453,12 @@ end i = findfirst(t->t<:Tuple, result.ir.stmts.type) # allocation statement @assert !isnothing(i) @test has_global_escape(result.state.arguments[2]) - @test_broken !has_global_escape(result.state.ssavalues[i]) + @test !has_global_escape(result.state.ssavalues[i]) end end # demonstrate a simple type level analysis can sometimes compensate the lack of yet unimplemented analysis @testset "special-casing bitstype" begin - let # lack of field analysis - result = analyze_escapes((Int,)) do a - o = MutableSome(a) # no need to escape - f = getfield(o, :value) - return f - end - i = findfirst(==(MutableSome{Int}), result.ir.stmts.type) # allocation statement - @assert !isnothing(i) - @test has_no_escape(result.state.ssavalues[i]) - end - let # an escaped tuple stmt will not propagate to its Int argument (since Int is of bitstype) result = analyze_escapes((Int, Any, )) do a, b t = tuple(a, b)