Skip to content

Commit

Permalink
finite lattice height
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Oct 5, 2021
1 parent 122c600 commit fe8eb05
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 73 deletions.
144 changes: 90 additions & 54 deletions src/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ end
# analysis
# ========

struct FieldLattice
ReturnEscape::BitSet
ThrownEscape::Bool
GlobalEscape::Bool
end

"""
x::EscapeLattice
Expand Down Expand Up @@ -181,31 +187,41 @@ struct EscapeLattice
ThrownEscape::Bool
GlobalEscape::Bool
# TODO: ArgEscape::Int
FieldsEscape::Vector{EscapeLattice}
FieldEscapes::Union{Nothing,Vector{FieldLattice}}
end

EscapeLattice(x::FieldLattice) = EscapeLattice(true, x.ReturnEscape, x.ThrownEscape, x.GlobalEscape, nothing)
FieldLattice(x::EscapeLattice) = FieldLattice(x.ReturnEscape, x.ThrownEscape, x.GlobalEscape)

# we need to make sure this `==` operator corresponds to lattice equality rather than object equality,
# otherwise `propagate_changes` can't detect the convergence
function Base.:(==)(x::EscapeLattice, y::EscapeLattice)
return x.Analyzed === y.Analyzed &&
x.ReturnEscape == y.ReturnEscape &&
x.ThrownEscape === y.ThrownEscape &&
x.GlobalEscape === y.GlobalEscape &&
x.FieldsEscape == y.FieldsEscape
x.FieldEscapes == y.FieldEscapes
end
function Base.:(==)(x::FieldLattice, y::FieldLattice)
return x.ReturnEscape == y.ReturnEscape &&
x.ThrownEscape === y.ThrownEscape &&
x.GlobalEscape === y.GlobalEscape
end

# lattice constructors
# precompute default values in order to eliminate computations at callsites
const NO_RETURN = BitSet()
const ARGUMENT_RETURN = BitSet(0)
const NO_FIELD_ESCAPE = EscapeLattice[]
const NO_FIELD_ESCAPE = nothing # EscapeLattice[]
NotAnalyzed() = EscapeLattice(false, NO_RETURN, false, false, NO_FIELD_ESCAPE) # not formally part of the lattice
NoEscape() = EscapeLattice(true, NO_RETURN, false, false, NO_FIELD_ESCAPE)
ReturnEscape(pcs::BitSet) = EscapeLattice(true, pcs, false, false, NO_FIELD_ESCAPE)
ReturnEscape(pc::Int) = ReturnEscape(BitSet(pc))
ArgumentReturnEscape() = ReturnEscape(ARGUMENT_RETURN)
ThrownEscape() = EscapeLattice(true, NO_RETURN, true, false, NO_FIELD_ESCAPE)
GlobalEscape() = EscapeLattice(true, NO_RETURN, false, true, NO_FIELD_ESCAPE)
FieldsEscape(finfos::Vector{EscapeLattice}) = EscapeLattice(true, NO_RETURN, false, false, finfos)
NoFieldsEscape(info::EscapeLattice) = EscapeLattice(info.Analyzed, info.ReturnEscape, info.ThrownEscape, info.GlobalEscape, NO_FIELD_ESCAPE)
FieldEscapes(finfos::Vector{FieldLattice}) = EscapeLattice(true, NO_RETURN, false, false, finfos)
NoFieldEscapes(info::EscapeLattice) = EscapeLattice(info.Analyzed, info.ReturnEscape, info.ThrownEscape, info.GlobalEscape, NO_FIELD_ESCAPE)
let
all_return = BitSet(0:100_000)
global AllReturnEscape() = ReturnEscape(all_return) # used for `show`
Expand All @@ -216,15 +232,13 @@ end
export
has_not_analyzed,
has_no_escape,
has_no_escape′,
has_return_escape,
has_thrown_escape,
has_global_escape,
has_all_escape,
can_elide_finalizer
has_not_analyzed(x::EscapeLattice) = x == NotAnalyzed()
has_no_escape(x::EscapeLattice) = x NoEscape()
has_no_escape′(x::EscapeLattice) = NoFieldsEscape(x) NoEscape()
has_return_escape(x::EscapeLattice) = !isempty(x.ReturnEscape)
has_return_escape(x::EscapeLattice, pc::Int) = pc in x.ReturnEscape
has_thrown_escape(x::EscapeLattice) = x.ThrownEscape
Expand All @@ -245,12 +259,8 @@ function can_elide_finalizer(x::EscapeLattice, pc::Int)
return pc x.ReturnEscape
end

# NOTE this partial order doens't take `FieldEscapes` into account at all
function (x::EscapeLattice, y::EscapeLattice)
xf, yf = x.FieldsEscape, y.FieldsEscape
length(xf) length(yf) || return false
for (x′, y′) in zip(xf, yf)
x′ y′ || return false
end
if x.Analyzed y.Analyzed &&
x.ReturnEscape y.ReturnEscape &&
x.ThrownEscape y.ThrownEscape &&
Expand All @@ -262,43 +272,44 @@ end
(x::EscapeLattice, y::EscapeLattice) = (x, y) && !(y, x)

function (x::EscapeLattice, y::EscapeLattice)
xf, yf = x.FieldsEscape, y.FieldsEscape
if isempty(xf) # fast pass
FieldsEscape = yf
elseif isempty(yf) # fast pass
FieldsEscape = xf
xf, yf = x.FieldEscapes, y.FieldEscapes
if isnothing(xf)
FieldEscapes = yf
elseif isnothing(yf)
FieldEscapes = xf
else
xn, yn = length(xf), length(yf)
xfn, yfn = (xf, xn), (yf, yn)
(sf, sn), (lf, ln) = xn yn ? (xfn, yfn) : (yfn, xfn)
FieldsEscape = Vector{EscapeLattice}(undef, ln)
for i in 1:ln
if i > sn
FieldEscape = lf[i]
else
FieldEscape = sf[i] lf[i]
end
FieldsEscape[i] = FieldEscape
n = min(length(xf), length(yf))
FieldEscapes = Vector{FieldLattice}(undef, n)
for i in 1:n
FieldEscapes[i] = xf[i] yf[i]
end
end
return EscapeLattice(
x.Analyzed | y.Analyzed,
x.ReturnEscape y.ReturnEscape,
x.ThrownEscape | y.ThrownEscape,
x.GlobalEscape | y.GlobalEscape,
FieldsEscape,
FieldEscapes,
)
end

function (x::EscapeLattice, y::EscapeLattice)
return EscapeLattice(
x.Analyzed & y.Analyzed,
x.ReturnEscape y.ReturnEscape,
x.ThrownEscape & y.ThrownEscape,
x.GlobalEscape & y.GlobalEscape,
function (x::FieldLattice, y::FieldLattice)
return FieldLattice(
x.ReturnEscape y.ReturnEscape,
x.ThrownEscape | y.ThrownEscape,
x.GlobalEscape | y.GlobalEscape,
)
end

# NOTE unmaintained now
# function ⊓(x::EscapeLattice, y::EscapeLattice)
# return EscapeLattice(
# x.Analyzed & y.Analyzed,
# x.ReturnEscape ∩ y.ReturnEscape,
# x.ThrownEscape & y.ThrownEscape,
# x.GlobalEscape & y.GlobalEscape,
# )
# end

# TODO setup a more effient struct for cache
# which can discard escape information on SSS values and arguments that don't join dispatch signature

Expand Down Expand Up @@ -350,10 +361,20 @@ function find_escapes(ir::IRCode, nargs::Int)
state = EscapeState(length(ir.argtypes), nargs, nstmts)
changes = Changes() # stashes changes that happen at current statement

local assertion_counter = 0
while true
local anyupdate = false

for pc in nstmts:-1:1
if (assertion_counter += 1) > 10nstmts
Core.eval(Main, quote
ir = $ir
nargs = $nargs
state = $state
end)
throw((assertion_counter, nstmts))
end

stmt = stmts.inst[pc]

# we escape statements with the `ThrownEscape` property using the effect-freeness
Expand All @@ -379,12 +400,15 @@ function find_escapes(ir::IRCode, nargs::Int)
add_change!(SSAValue(pc), ir, info, changes) # we will be interested in if this allocation escapes or not
end
args = stmt.args[2:end]
add_changes!(args, ir, NoFieldsEscape(info), changes)
n = min(length(args), length(info.FieldsEscape))
for i in 1:n
arg = args[i]
finfo = info.FieldsEscape[i]
add_change!(arg, ir, finfo, changes)
add_changes!(args, ir, NoFieldEscapes(info), changes)
finfos = info.FieldEscapes
if !isnothing(finfos)
n = min(length(args), length(finfos))
for i in 1:n
arg = args[i]
finfo = EscapeLattice(finfos[i])
add_change!(arg, ir, finfo, changes)
end
end
elseif head === :splatnew
info = state.ssavalues[pc]
Expand Down Expand Up @@ -572,13 +596,18 @@ end
# context of the caller frame using the escape information imposed on the return value (`retinfo`)
function from_interprocedural(arginfo::EscapeLattice, retinfo::EscapeLattice)
ar = arginfo.ReturnEscape
newarginfo = EscapeLattice(true, NO_RETURN, arginfo.ThrownEscape, arginfo.GlobalEscape, arginfo.FieldsEscape)
fieldsinfo = arginfo.FieldEscapes
newfieldsinfo = isnothing(fieldsinfo) ? nothing : FieldLattice[
FieldLattice(from_interprocedural(EscapeLattice(finfo), retinfo)) for finfo in fieldsinfo]
newarginfo = EscapeLattice(
true, NO_RETURN, arginfo.ThrownEscape, arginfo.GlobalEscape, newfieldsinfo)
if ar == ARGUMENT_RETURN
# if this is simply passed as the call argument, we can discard the `ReturnEscape`
# information and just propagate the other escape information
return newarginfo
else
# if this can be a return value, we have to merge it with the escape information
# at the current SSA value
return newarginfo retinfo
end
end
Expand Down Expand Up @@ -635,11 +664,14 @@ function escape_builtin!(::typeof(tuple), args::Vector{Any}, pc::Int, state::Esc
end
args = args[2:end]
add_changes!(args, ir, info, changes)
n = min(length(args), length(info.FieldsEscape))
for i in 1:n
arg = args[i]
finfo = info.FieldsEscape[i]
add_change!(arg, ir, finfo, changes)
fieldsinfo = info.FieldEscapes
if !isnothing(fieldsinfo)
n = min(length(args), length(fieldsinfo))
for i in 1:n
arg = args[i]
finfo = EscapeLattice(fieldsinfo[i])
add_change!(arg, ir, finfo, changes)
end
end
return true
end
Expand Down Expand Up @@ -677,8 +709,8 @@ function escape_builtin!(::typeof(getfield), args::Vector{Any}, pc::Int, state::
if !isnothing(idx)
n = CC.fieldcount_noerror(objt)
if !isnothing(n)
newinfo = FieldsEscape(EscapeLattice[
i == idx ? info : NoEscape() for i in 1:n])
newinfo = FieldEscapes(FieldLattice[
FieldLattice(i == idx ? info : NoEscape()) for i in 1:n])
add_change!(obj, ir, newinfo, changes)
return true
end
Expand Down Expand Up @@ -762,8 +794,7 @@ __clear_caches!() = (__clear_code_cache!(); __clear_escape_cache!())

function get_name_color(x::EscapeLattice, short::Bool = false)
getname(x) = string(nameof(x))
FieldsEscape = x.FieldsEscape
x = NoFieldsEscape(x)
FieldEscapes = x.FieldEscapes
if x == NotAnalyzed()
name, color = (getname(NotAnalyzed), ''), :plain
elseif x == NoEscape()
Expand All @@ -782,8 +813,13 @@ function get_name_color(x::EscapeLattice, short::Bool = false)
else
name, color = (nothing, '*'), :red
end
s = string((short ? last : first)(name), (isempty(FieldsEscape) ? ' ' : ''))
return s, color
if short
s = last(name)
return string(s, isnothing(FieldEscapes) ? ' ' : ''), color
else
s = first(name)
return (isnothing(s) ? s : string(s, isnothing(FieldEscapes) ? "" : '')), color
end
end

function Base.show(io::IO, x::EscapeLattice)
Expand Down
54 changes: 35 additions & 19 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ end
inds = findall(==(MutableSome{T}), result.ir.stmts.type) # find allocation statement
@assert !isempty(inds)
for i in inds
@test has_no_escape(result.state.ssavalues[i])
@test has_no_escape(result.state.ssavalues[i])
end
end

Expand Down Expand Up @@ -266,23 +266,6 @@ end
@assert !isnothing(i) && !isnothing(r)
@test !has_return_escape(result.state.ssavalues[i], r)
end

# # FIXME! `println` causes infinite loop ...
# @eval m @noinline f_no_return_escape2(a) = begin
# println("hi")
# identity("hi")
# end
# let
# result = @eval m $analyze_escapes() do
# obj = Ref("foo") # better to not be "return escape"
# ret = @noinline f_no_return_escape2(obj)
# return ret # must not alias to `obj`
# end
# i = findfirst(==(Base.RefValue{String}), result.ir.stmts.type) # find allocation statement
# r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst)
# @assert !isnothing(i) && !isnothing(r)
# @test !has_return_escape(result.state.ssavalues[i], r)
# end
end

@testset "builtins" begin
Expand Down Expand Up @@ -414,7 +397,7 @@ end
@test !has_return_escape(result.state.ssavalues[i], r)
end

let
let # nested unwrap
result = analyze_escapes((String,)) do a # => ReturnEscape
o1 = MutableSome(a) # => ReturnEscape
o2 = MutableSome(o1) # no need to escape
Expand All @@ -429,6 +412,16 @@ end
@test !has_return_escape(result.state.ssavalues[i2], r)
end

let # TODO nested wrap (NOTE: we're interested in the value of field)
result = analyze_escapes((String,)) do a # => ReturnEscape
o1 = MutableSome(a) # => ReturnEscape
o2 = MutableSome(o1) # => NoEscape
o1′ = getfield(o2, :value) # => FieldEscapes(ReturnEscape)
a′ = getfield(o1′, :value) # => ReturnEscape
return a′
end
end

let # multiple fields
result = analyze_escapes((String, String)) do a, b # => ReturnEscape, ReturnEscape
obj = MutableFields(a, b) # => NoEscape
Expand All @@ -455,6 +448,29 @@ end
@test has_global_escape(result.state.arguments[2])
@test !has_global_escape(result.state.ssavalues[i])
end

let # inter-procedural conversion
m = Module()
@eval m @noinline getvalue(obj) = obj.value
result = @eval m $analyze_escapes((String,)) do a # => ReturnEscape
obj = $MutableSome(a) # no need to escape
fld = getvalue(obj)
return fld
end
i = findfirst(==(MutableSome{String}), result.ir.stmts.type)
r = findfirst(x->isa(x, Core.ReturnNode), result.ir.stmts.inst)
@assert !isnothing(i) && !isnothing(r)
@test has_return_escape(result.state.arguments[2], r)
@test !has_return_escape(result.state.ssavalues[i], r)
end

let # `popfirst!(InvasiveLinkedList{Task})` within this `println` used to cause infinite loop ...
result = analyze_escapes((String,)) do a
println(a)
nothing
end
@test true
end
end

# demonstrate a simple type level analysis can sometimes compensate the lack of yet unimplemented analysis
Expand Down

0 comments on commit fe8eb05

Please sign in to comment.