Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wip: field analysis #43

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 142 additions & 25 deletions src/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ end
# analysis
# ========

struct FieldLattice
ReturnEscape::BitSet
ThrownEscape::Bool
GlobalEscape::Bool
end

"""
x::EscapeLattice

Expand Down Expand Up @@ -181,11 +187,23 @@ struct EscapeLattice
ThrownEscape::Bool
GlobalEscape::Bool
# TODO: ArgEscape::Int
FieldEscapes::Union{Nothing,Vector{FieldLattice}}
end

EscapeLattice(x::FieldLattice) = EscapeLattice(true, x.ReturnEscape, x.ThrownEscape, x.GlobalEscape, nothing)
FieldLattice(x::EscapeLattice) = FieldLattice(x.ReturnEscape, x.ThrownEscape, x.GlobalEscape)

# we need to make sure this `==` operator corresponds to lattice equality rather than object equality,
# otherwise `propagate_changes` can't detect the convergence
function Base.:(==)(x::EscapeLattice, y::EscapeLattice)
return x.Analyzed === y.Analyzed &&
x.ReturnEscape == y.ReturnEscape &&
x.ThrownEscape === y.ThrownEscape &&
x.GlobalEscape === y.GlobalEscape &&
x.FieldEscapes == y.FieldEscapes
end
function Base.:(==)(x::FieldLattice, y::FieldLattice)
return x.ReturnEscape == y.ReturnEscape &&
x.ThrownEscape === y.ThrownEscape &&
x.GlobalEscape === y.GlobalEscape
end
Expand All @@ -194,17 +212,20 @@ end
# precompute default values in order to eliminate computations at callsites
const NO_RETURN = BitSet()
const ARGUMENT_RETURN = BitSet(0)
NotAnalyzed() = EscapeLattice(false, NO_RETURN, false, false) # not formally part of the lattice
NoEscape() = EscapeLattice(true, NO_RETURN, false, false)
ReturnEscape(pcs::BitSet) = EscapeLattice(true, pcs, false, false)
const NO_FIELD_ESCAPE = nothing # EscapeLattice[]
NotAnalyzed() = EscapeLattice(false, NO_RETURN, false, false, NO_FIELD_ESCAPE) # not formally part of the lattice
NoEscape() = EscapeLattice(true, NO_RETURN, false, false, NO_FIELD_ESCAPE)
ReturnEscape(pcs::BitSet) = EscapeLattice(true, pcs, false, false, NO_FIELD_ESCAPE)
ReturnEscape(pc::Int) = ReturnEscape(BitSet(pc))
ArgumentReturnEscape() = ReturnEscape(ARGUMENT_RETURN)
ThrownEscape() = EscapeLattice(true, NO_RETURN, true, false)
GlobalEscape() = EscapeLattice(true, NO_RETURN, false, true)
ThrownEscape() = EscapeLattice(true, NO_RETURN, true, false, NO_FIELD_ESCAPE)
GlobalEscape() = EscapeLattice(true, NO_RETURN, false, true, NO_FIELD_ESCAPE)
FieldEscapes(finfos::Vector{FieldLattice}) = EscapeLattice(true, NO_RETURN, false, false, finfos)
NoFieldEscapes(info::EscapeLattice) = EscapeLattice(info.Analyzed, info.ReturnEscape, info.ThrownEscape, info.GlobalEscape, NO_FIELD_ESCAPE)
let
all_return = BitSet(0:100_000)
global AllReturnEscape() = ReturnEscape(all_return) # used for `show`
global AllEscape() = EscapeLattice(true, all_return, true, true)
global AllEscape() = EscapeLattice(true, all_return, true, true, NO_FIELD_ESCAPE)
end

# Convenience names for some ⊑ queries
Expand All @@ -222,7 +243,7 @@ has_return_escape(x::EscapeLattice) = !isempty(x.ReturnEscape)
has_return_escape(x::EscapeLattice, pc::Int) = pc in x.ReturnEscape
has_thrown_escape(x::EscapeLattice) = x.ThrownEscape
has_global_escape(x::EscapeLattice) = x.GlobalEscape
has_all_escape(x::EscapeLattice) = AllEscape() == x
has_all_escape(x::EscapeLattice) = AllEscape() x

"""
can_elide_finalizer(x::EscapeLattice, pc::Int) -> Bool
Expand All @@ -238,35 +259,57 @@ function can_elide_finalizer(x::EscapeLattice, pc::Int)
return pc ∉ x.ReturnEscape
end

# NOTE this partial order doens't take `FieldEscapes` into account at all
function ⊑(x::EscapeLattice, y::EscapeLattice)
if x.Analyzed ≤ y.Analyzed &&
x.ReturnEscape ⊆ y.ReturnEscape &&
x.ThrownEscape ≤ y.ThrownEscape &&
x.GlobalEscape ≤ y.GlobalEscape
return true
return true
end
return false
end
⋤(x::EscapeLattice, y::EscapeLattice) = ⊑(x, y) && !⊑(y, x)

function ⊔(x::EscapeLattice, y::EscapeLattice)
xf, yf = x.FieldEscapes, y.FieldEscapes
if isnothing(xf)
FieldEscapes = yf
elseif isnothing(yf)
FieldEscapes = xf
else
n = min(length(xf), length(yf))
FieldEscapes = Vector{FieldLattice}(undef, n)
for i in 1:n
FieldEscapes[i] = xf[i] ⊔ yf[i]
end
end
return EscapeLattice(
x.Analyzed | y.Analyzed,
x.ReturnEscape ∪ y.ReturnEscape,
x.ThrownEscape | y.ThrownEscape,
x.GlobalEscape | y.GlobalEscape,
FieldEscapes,
)
end

function ⊓(x::EscapeLattice, y::EscapeLattice)
return EscapeLattice(
x.Analyzed & y.Analyzed,
x.ReturnEscape ∩ y.ReturnEscape,
x.ThrownEscape & y.ThrownEscape,
x.GlobalEscape & y.GlobalEscape,
function ⊔(x::FieldLattice, y::FieldLattice)
return FieldLattice(
x.ReturnEscape ∪ y.ReturnEscape,
x.ThrownEscape | y.ThrownEscape,
x.GlobalEscape | y.GlobalEscape,
)
end

# NOTE unmaintained now
# function ⊓(x::EscapeLattice, y::EscapeLattice)
# return EscapeLattice(
# x.Analyzed & y.Analyzed,
# x.ReturnEscape ∩ y.ReturnEscape,
# x.ThrownEscape & y.ThrownEscape,
# x.GlobalEscape & y.GlobalEscape,
# )
# end

# TODO setup a more effient struct for cache
# which can discard escape information on SSS values and arguments that don't join dispatch signature

Expand Down Expand Up @@ -319,10 +362,20 @@ function find_escapes(ir::IRCode, nargs::Int)
state = EscapeState(length(ir.argtypes), nargs, nstmts)
changes = Changes() # stashes changes that happen at current statement

local assertion_counter = 0
while true
local anyupdate = false

for pc in nstmts:-1:1
if (assertion_counter += 1) > 10nstmts
Core.eval(Main, quote
ir = $ir
nargs = $nargs
state = $state
end)
throw((assertion_counter, nstmts))
end

stmt = stmts.inst[pc]

# we escape statements with the `ThrownEscape` property using the effect-freeness
Expand All @@ -347,7 +400,17 @@ function find_escapes(ir::IRCode, nargs::Int)
info = NoEscape()
add_change!(SSAValue(pc), ir, info, changes) # we will be interested in if this allocation escapes or not
end
add_changes!(stmt.args[2:end], ir, info, changes)
args = stmt.args[2:end]
add_changes!(args, ir, NoFieldEscapes(info), changes)
finfos = info.FieldEscapes
if !isnothing(finfos)
n = min(length(args), length(finfos))
for i in 1:n
arg = args[i]
finfo = EscapeLattice(finfos[i])
add_change!(arg, ir, finfo, changes)
end
end
elseif head === :splatnew
info = state.ssavalues[pc]
if info == NotAnalyzed()
Expand Down Expand Up @@ -537,13 +600,18 @@ end
# context of the caller frame using the escape information imposed on the return value (`retinfo`)
function from_interprocedural(arginfo::EscapeLattice, retinfo::EscapeLattice)
ar = arginfo.ReturnEscape
newarginfo = EscapeLattice(true, NO_RETURN, arginfo.ThrownEscape, arginfo.GlobalEscape)
fieldsinfo = arginfo.FieldEscapes
newfieldsinfo = isnothing(fieldsinfo) ? nothing : FieldLattice[
FieldLattice(from_interprocedural(EscapeLattice(finfo), retinfo)) for finfo in fieldsinfo]
newarginfo = EscapeLattice(
true, NO_RETURN, arginfo.ThrownEscape, arginfo.GlobalEscape, newfieldsinfo)
if ar == ARGUMENT_RETURN
# if this is simply passed as the call argument, we can discard the `ReturnEscape`
# information and just propagate the other escape information
return newarginfo
else
# if this can be a return value, we have to merge it with the escape information
# at the current SSA value
return newarginfo ⊔ retinfo
end
end
Expand Down Expand Up @@ -606,21 +674,63 @@ function escape_builtin!(::typeof(tuple), args::Vector{Any}, pc::Int, state::Esc
if info == NotAnalyzed()
info = NoEscape()
end
add_changes!(args[2:end], ir, info, changes)
args = args[2:end]
add_changes!(args, ir, info, changes)
fieldsinfo = info.FieldEscapes
if !isnothing(fieldsinfo)
n = min(length(args), length(fieldsinfo))
for i in 1:n
arg = args[i]
finfo = EscapeLattice(fieldsinfo[i])
add_change!(arg, ir, finfo, changes)
end
end
return true
end

# TODO don't propagate escape information to the 1st argument, but propagate information to aliased field
function try_compute_fieldidx_args(ir::IRCode, args::Vector{Any}, typ::DataType)
field = args[3]
if isa(field, QuoteNode)
field = field.value
elseif isa(field, Int)
# try to resolve other constants, e.g. global reference
else
field = argextype(field, ir, ir.sptypes, ir.argtypes)
if isa(field, Const)
field = field.val
else
return nothing
end
end
isa(field, Union{Int, Symbol}) || return nothing
return CC.try_compute_fieldidx(typ, field)
end

function escape_builtin!(::typeof(getfield), args::Vector{Any}, pc::Int, state::EscapeState, ir::IRCode, changes::Changes)
# only propagate info when the field itself is non-bitstype
isbitstype(widenconst(ir.stmts.type[pc])) && return true
info = state.ssavalues[pc]
if info == NotAnalyzed()
info = NoEscape()
end
# only propagate info when the field itself is non-bitstype
if !isbitstype(widenconst(ir.stmts.type[pc]))
add_changes!(args[2:end], ir, info, changes)
if length(args) ≥ 2
obj = args[2]
objt = widenconst(argextype(obj, ir, ir.sptypes, ir.argtypes))
if isa(objt, DataType)
idx = try_compute_fieldidx_args(ir, args, objt)
if !isnothing(idx)
n = CC.fieldcount_noerror(objt)
if !isnothing(n)
newinfo = FieldEscapes(FieldLattice[
FieldLattice(i == idx ? info : NoEscape()) for i in 1:n])
add_change!(obj, ir, newinfo, changes)
return true
end
end
end
add_change!(obj, ir, info, changes)
end
return true
return true # may throw, but we will handle that later
end

# entries
Expand Down Expand Up @@ -697,8 +807,9 @@ end
# in order to run a whole analysis from ground zero (e.g. for benchmarking, etc.)
__clear_caches!() = (__clear_code_cache!(); __clear_escape_cache!())

function get_name_color(x::EscapeLattice, symbol::Bool = false)
function get_name_color(x::EscapeLattice, short::Bool = false)
getname(x) = string(nameof(x))
FieldEscapes = x.FieldEscapes
if x == NotAnalyzed()
name, color = (getname(NotAnalyzed), '◌'), :plain
elseif x == NoEscape()
Expand All @@ -717,7 +828,13 @@ function get_name_color(x::EscapeLattice, symbol::Bool = false)
else
name, color = (nothing, '*'), :red
end
return (symbol ? last(name) : first(name), color)
if short
s = last(name)
return string(s, isnothing(FieldEscapes) ? ' ' : '′'), color
else
s = first(name)
return (isnothing(s) ? s : string(s, isnothing(FieldEscapes) ? "" : '′')), color
end
end

function Base.show(io::IO, x::EscapeLattice)
Expand Down
Loading