Skip to content

Commit

Permalink
field analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Oct 4, 2021
1 parent a127a21 commit 9de71cb
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 52 deletions.
116 changes: 98 additions & 18 deletions src/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,48 +180,55 @@ struct EscapeLattice
ThrownEscape::Bool
GlobalEscape::Bool
# TODO: ArgEscape::Int
FieldEscapes::Vector{EscapeLattice}
end

function Base.:(==)(x::EscapeLattice, y::EscapeLattice)
return x.Analyzed === y.Analyzed &&
x.ReturnEscape == y.ReturnEscape &&
x.ThrownEscape === y.ThrownEscape &&
x.GlobalEscape === y.GlobalEscape
x.GlobalEscape === y.GlobalEscape &&
x.FieldEscapes == y.FieldEscapes
end

# lattice constructors
# precompute default values in order to eliminate computations at callsites
const NO_RETURN = BitSet()
const ARGUMENT_RETURN = BitSet(0)
NotAnalyzed() = EscapeLattice(false, NO_RETURN, false, false) # not formally part of the lattice
NoEscape() = EscapeLattice(true, NO_RETURN, false, false)
ReturnEscape(returns::BitSet) = EscapeLattice(true, returns, false, false)
const NO_FIELD_ESCAPE = EscapeLattice[]
NotAnalyzed() = EscapeLattice(false, NO_RETURN, false, false, NO_FIELD_ESCAPE) # not formally part of the lattice
NoEscape() = EscapeLattice(true, NO_RETURN, false, false, NO_FIELD_ESCAPE)
ReturnEscape(returns::BitSet) = EscapeLattice(true, returns, false, false, NO_FIELD_ESCAPE)
ReturnEscape(pc::Int) = ReturnEscape(BitSet(pc))
ArgumentReturnEscape() = ReturnEscape(ARGUMENT_RETURN)
ThrownEscape() = EscapeLattice(true, NO_RETURN, true, false)
GlobalEscape() = EscapeLattice(true, NO_RETURN, false, true)
ThrownEscape() = EscapeLattice(true, NO_RETURN, true, false, NO_FIELD_ESCAPE)
GlobalEscape() = EscapeLattice(true, NO_RETURN, false, true, NO_FIELD_ESCAPE)
FieldEscapes(finfos::Vector{EscapeLattice}) = EscapeLattice(true, NO_RETURN, false, false, finfos)
NoFieldEscapes(info::EscapeLattice) = EscapeLattice(info.Analyzed, info.ReturnEscape, info.ThrownEscape, info.GlobalEscape, NO_FIELD_ESCAPE)
let
all_return = BitSet(0:1000000)
global AllReturnEscape() = ReturnEscape(all_return) # used for `show`
global AllEscape() = EscapeLattice(true, all_return, true, true)
global AllEscape() = EscapeLattice(true, all_return, true, true, NO_FIELD_ESCAPE)
end

# Convenience names for some ⊑ queries
export
has_not_analyzed,
has_no_escape,
has_no_escape′,
has_return_escape,
has_thrown_escape,
has_global_escape,
has_all_escape,
can_elide_finalizer
has_not_analyzed(x::EscapeLattice) = x == NotAnalyzed()
has_no_escape(x::EscapeLattice) = x NoEscape()
has_no_escape′(x::EscapeLattice) = NoFieldEscapes(x) NoEscape()
has_return_escape(x::EscapeLattice) = !isempty(x.ReturnEscape)
has_return_escape(x::EscapeLattice, pc::Int) = pc in x.ReturnEscape
has_thrown_escape(x::EscapeLattice) = x.ThrownEscape
has_global_escape(x::EscapeLattice) = x.GlobalEscape
has_all_escape(x::EscapeLattice) = AllEscape() == x
has_all_escape(x::EscapeLattice) = AllEscape() x

"""
can_elide_finalizer(x::EscapeLattice, pc::Int) -> Bool
Expand All @@ -238,6 +245,11 @@ function can_elide_finalizer(x::EscapeLattice, pc::Int)
end

function (x::EscapeLattice, y::EscapeLattice)
xf, yf = x.FieldEscapes, y.FieldEscapes
length(xf) length(yf) || return false
for (x′, y′) in zip(xf, yf)
x′ y′ || return false
end
if x.Analyzed y.Analyzed &&
x.ReturnEscape y.ReturnEscape &&
x.ThrownEscape y.ThrownEscape &&
Expand All @@ -249,11 +261,30 @@ end
(x::EscapeLattice, y::EscapeLattice) = (x, y) && !(y, x)

function (x::EscapeLattice, y::EscapeLattice)
xf, yf = x.FieldEscapes, y.FieldEscapes
if isempty(xf) # fast pass
FieldEscapes = yf
elseif isempty(yf) # fast pass
FieldEscapes = xf
else
xn, yn = length(xf), length(yf)
(sf, sn), (lf, ln) = xn yn ? ((xf, xn), (yf, yn)) : ((yf, yn), (xf, xn))
FieldEscapes = Vector{EscapeLattice}(undef, ln)
for i in 1:ln
if i > sn
FieldEscape = lf[i]
else
FieldEscape = sf[i] lf[i]
end
FieldEscapes[i] = FieldEscape
end
end
return EscapeLattice(
x.Analyzed | y.Analyzed,
x.ReturnEscape y.ReturnEscape,
x.ThrownEscape | y.ThrownEscape,
x.GlobalEscape | y.GlobalEscape,
FieldEscapes,
)
end

Expand Down Expand Up @@ -345,7 +376,14 @@ function find_escapes(ir::IRCode, nargs::Int)
info = NoEscape()
add_change!(SSAValue(pc), ir, info, changes) # we will be interested in if this allocation escapes or not
end
add_changes!(stmt.args[2:end], ir, info, changes)
args = stmt.args[2:end]
add_changes!(args, ir, NoFieldEscapes(info), changes)
n = min(length(args), length(info.FieldEscapes))
for i in 1:n
arg = args[i]
finfo = info.FieldEscapes[i]
add_change!(arg, ir, finfo, changes)
end
elseif head === :splatnew
info = state.ssavalues[pc]
if info == NotAnalyzed()
Expand Down Expand Up @@ -532,7 +570,7 @@ end
# context of the caller frame using the escape information imposed on the return value (`retinfo`)
function from_interprocedural(arginfo::EscapeLattice, retinfo::EscapeLattice)
ar = arginfo.ReturnEscape
newarginfo = EscapeLattice(true, NO_RETURN, arginfo.ThrownEscape, arginfo.GlobalEscape)
newarginfo = EscapeLattice(true, NO_RETURN, arginfo.ThrownEscape, arginfo.GlobalEscape, arginfo.FieldEscapes)
if ar == ARGUMENT_RETURN
# if this is simply passed as the call argument, we can discard the `ReturnEscape`
# information and just propagate the other escape information
Expand Down Expand Up @@ -593,21 +631,60 @@ function escape_builtin!(::typeof(tuple), args::Vector{Any}, pc::Int, state::Esc
if info == NotAnalyzed()
info = NoEscape()
end
add_changes!(args[2:end], ir, info, changes)
args = args[2:end]
add_changes!(args, ir, info, changes)
n = min(length(args), length(info.FieldEscapes))
for i in 1:n
arg = args[i]
finfo = info.FieldEscapes[i]
add_change!(arg, ir, finfo, changes)
end
return true
end

# TODO don't propagate escape information to the 1st argument, but propagate information to aliased field
function try_compute_fieldidx_args(ir::IRCode, args::Vector{Any}, typ::DataType)
field = args[3]
if isa(field, QuoteNode)
field = field.value
elseif isa(field, Int)
# try to resolve other constants, e.g. global reference
else
field = argextype(field, ir, ir.sptypes, ir.argtypes)
if isa(field, Const)
field = field.val
else
return nothing
end
end
isa(field, Union{Int, Symbol}) || return nothing
return CC.try_compute_fieldidx(typ, field)
end

function escape_builtin!(::typeof(getfield), args::Vector{Any}, pc::Int, state::EscapeState, ir::IRCode, changes::Changes)
# only propagate info when the field itself is non-bitstype
isbitstype(widenconst(ir.stmts.type[pc])) && return true
info = state.ssavalues[pc]
if info == NotAnalyzed()
info = NoEscape()
end
# only propagate info when the field itself is non-bitstype
if !isbitstype(widenconst(ir.stmts.type[pc]))
add_changes!(args[2:end], ir, info, changes)
if length(args) 2
obj = args[2]
objt = widenconst(argextype(obj, ir, ir.sptypes, ir.argtypes))
if isa(objt, DataType)
idx = try_compute_fieldidx_args(ir, args, objt)
if !isnothing(idx)
n = CC.fieldcount_noerror(objt)
if !isnothing(n)
newinfo = FieldEscapes(EscapeLattice[
i == idx ? info : NoEscape() for i in 1:n])
add_change!(obj, ir, newinfo, changes)
return true
end
end
end
add_change!(obj, ir, info, changes)
end
return true
return true # may throw, but we will handle that later
end

# entries
Expand Down Expand Up @@ -680,8 +757,10 @@ end
# in order to run a whole analysis from ground zero (e.g. for benchmarking, etc.)
__clear_caches!() = (__clear_code_cache!(); __clear_escape_cache!())

function get_name_color(x::EscapeLattice, symbol::Bool = false)
function get_name_color(x::EscapeLattice, short::Bool = false)
getname(x) = string(nameof(x))
FieldEscapes = x.FieldEscapes
x = NoFieldEscapes(x)
if x == NotAnalyzed()
name, color = (getname(NotAnalyzed), ''), :plain
elseif x == NoEscape()
Expand All @@ -698,7 +777,8 @@ function get_name_color(x::EscapeLattice, symbol::Bool = false)
else
name, color = (nothing, '*'), :red
end
return (symbol ? last(name) : first(name), color)
s = string((short ? last : first)(name), (isempty(FieldEscapes) ? ' ' : ''))
return s, color
end

function Base.show(io::IO, x::EscapeLattice)
Expand Down
92 changes: 58 additions & 34 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@ using EscapeAnalysis, InteractiveUtils, Test, JET
mutable struct MutableSome{T}
value::T
end

mutable struct MutableCondition
cond::Bool
end

mutable struct MutableFields{S,T}
field1::S
field2::T
end

@testset "EscapeAnalysis" begin

@testset "basics" begin
Expand Down Expand Up @@ -205,7 +211,7 @@ end
inds = findall(==(MutableSome{T}), result.ir.stmts.type) # find allocation statement
@assert !isempty(inds)
for i in inds
@test has_no_escape(result.state.ssavalues[i])
@test has_no_escape(result.state.ssavalues[i])
end
end

Expand All @@ -229,40 +235,54 @@ end
return aaa
end
i = findfirst(==(Base.RefValue{String}), result.ir.stmts.type) # find allocation statement
@assert !isnothing(i)
@test has_return_escape(result.state.ssavalues[i])
r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst)
@assert !isnothing(i) && !isnothing(r)
@test has_return_escape(result.state.ssavalues[i], r)
end

# should propagate escape information imposed on return value to the aliased call argument
@eval m @noinline function f_return_escape(a)
println("hi") # prevent inlining
return a
end
@eval m @noinline f_return_escape(a) = a
let
result = @eval m $analyze_escapes() do
obj = Ref("foo") # should be "return escape"
ret = f_return_escape(obj)
ret = @noinline f_return_escape(obj)
return ret # alias of `obj`
end
i = findfirst(==(Base.RefValue{String}), result.ir.stmts.type) # find allocation statement
@assert !isnothing(i)
@test has_return_escape(result.state.ssavalues[i])
r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst)
@assert !isnothing(i) && !isnothing(r)
@test has_return_escape(result.state.ssavalues[i], r)
end

@eval m @noinline function f_no_return_escape(a)
println("hi") # prevent inlining
return "hi"
end
@eval m @noinline f_no_return_escape(a) = identity("hi")
let
result = @eval m $analyze_escapes() do
obj = Ref("foo") # better to not be "return escape"
ret = f_no_return_escape(obj)
ret = @noinline f_no_return_escape(obj)
return ret # must not alias to `obj`
end
i = findfirst(==(Base.RefValue{String}), result.ir.stmts.type) # find allocation statement
@assert !isnothing(i)
@test !has_return_escape(result.state.ssavalues[i])
end
r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst)
@assert !isnothing(i) && !isnothing(r)
@test !has_return_escape(result.state.ssavalues[i], r)
end

# # FIXME! `println` causes infinite loop ...
# @eval m @noinline f_no_return_escape2(a) = begin
# println("hi")
# identity("hi")
# end
# let
# result = @eval m $analyze_escapes() do
# obj = Ref("foo") # better to not be "return escape"
# ret = @noinline f_no_return_escape2(obj)
# return ret # must not alias to `obj`
# end
# i = findfirst(==(Base.RefValue{String}), result.ir.stmts.type) # find allocation statement
# r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst)
# @assert !isnothing(i) && !isnothing(r)
# @test !has_return_escape(result.state.ssavalues[i], r)
# end
end

@testset "builtins" begin
Expand Down Expand Up @@ -376,7 +396,7 @@ end
r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst)
@assert !isnothing(i) && !isnothing(r)
@test has_return_escape(result.state.arguments[2], r)
@test_broken !has_return_escape(result.state.ssavalues[i], r)
@test !has_return_escape(result.state.ssavalues[i], r)
end

let
Expand All @@ -391,7 +411,7 @@ end
@assert !isnothing(i) && !isnothing(r)
@test has_return_escape(result.state.arguments[2], r)
@test has_global_escape(result.state.ssavalues[i])
@test_broken !has_return_escape(result.state.ssavalues[i], r)
@test !has_return_escape(result.state.ssavalues[i], r)
end

let
Expand All @@ -405,8 +425,23 @@ end
r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst)
@assert !isnothing(i1) && !isnothing(i2) && !isnothing(4)
@test has_return_escape(result.state.arguments[2], r)
@test has_return_escape(result.state.ssavalues[i1])
@test_broken !has_return_escape(result.state.ssavalues[i2])
@test has_return_escape(result.state.ssavalues[i1], r)
@test !has_return_escape(result.state.ssavalues[i2], r)
end

let # multiple fields
result = analyze_escapes((String, String)) do a, b # => ReturnEscape, ReturnEscape
obj = MutableFields(a, b) # => NoEscape
fld1 = obj.field1
fld2 = obj.field2
return (fld1, fld2)
end
i = findfirst(==(MutableFields{String,String}), result.ir.stmts.type)
r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst)
@assert !isnothing(i) && !isnothing(r)
@test has_return_escape(result.state.arguments[2], r) # a
@test has_return_escape(result.state.arguments[3], r) # b
@test !has_return_escape(result.state.ssavalues[i], r)
end

let
Expand All @@ -418,23 +453,12 @@ end
i = findfirst(t->t<:Tuple, result.ir.stmts.type) # allocation statement
@assert !isnothing(i)
@test has_global_escape(result.state.arguments[2])
@test_broken !has_global_escape(result.state.ssavalues[i])
@test !has_global_escape(result.state.ssavalues[i])
end
end

# demonstrate a simple type level analysis can sometimes compensate the lack of yet unimplemented analysis
@testset "special-casing bitstype" begin
let # lack of field analysis
result = analyze_escapes((Int,)) do a
o = MutableSome(a) # no need to escape
f = getfield(o, :value)
return f
end
i = findfirst(==(MutableSome{Int}), result.ir.stmts.type) # allocation statement
@assert !isnothing(i)
@test has_no_escape(result.state.ssavalues[i])
end

let # an escaped tuple stmt will not propagate to its Int argument (since Int is of bitstype)
result = analyze_escapes((Int, Any, )) do a, b
t = tuple(a, b)
Expand Down

0 comments on commit 9de71cb

Please sign in to comment.