Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimise DynamicPPL Slightly and Better Zero Adjoint Functionality #242

Merged
merged 22 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ jobs:
- 'integration_testing/array'
- 'integration_testing/turing'
- 'integration_testing/temporalgps'
- 'integration_testing/dynamic_ppl'
- 'interface'
steps:
- uses: actions/checkout@v4
Expand Down
9 changes: 6 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Tapir"
uuid = "07d77754-e150-4737-8c94-cd238a1fb45b"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.2.44"
version = "0.2.45"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -19,12 +19,14 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
TapirCUDAExt = "CUDA"
TapirDynamicPPLExt = "DynamicPPL"
TapirJETExt = "JET"
TapirLogDensityProblemsADExt = "LogDensityProblemsAD"
TapirSpecialFunctionsExt = "SpecialFunctions"
Expand All @@ -49,7 +51,7 @@ Setfield = "1"
SpecialFunctions = "2"
StableRNGs = "1"
TemporalGPs = "0.6"
Turing = "0.32"
Turing = "0.34"
julia = "1.10"

[extras]
Expand All @@ -59,6 +61,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
Expand All @@ -72,4 +75,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[targets]
test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"]
test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing", "TemporalGPs"]
1 change: 0 additions & 1 deletion bench/run_benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ using Tapir:
generate_hand_written_rrule!!_test_cases,
generate_derived_rrule!!_test_cases,
TestUtils,
PInterp,
_typeof

using Tapir.TestUtils: _deepcopy, to_benchmark
Expand Down
17 changes: 17 additions & 0 deletions ext/TapirDynamicPPLExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module TapirDynamicPPLExt

if isdefined(Base, :get_extension)
using DynamicPPL: DynamicPPL, istrans
using Tapir: Tapir
else
using ..DynamicPPL: DynamicPPL, istrans
using ..Tapir: Tapir
end

using Tapir: DefaultCtx, CoDual, simple_zero_adjoint

# This is purely an optimisation.
Tapir.@is_primitive DefaultCtx Tuple{typeof(istrans), Vararg}
Tapir.rrule!!(f::CoDual{typeof(istrans)}, x::CoDual...) = simple_zero_adjoint(f, x...)

end # module
17 changes: 17 additions & 0 deletions src/codual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,23 @@ end

@inline (pb::NoPullback)(_) = tuple_map(instantiate, pb.r)

"""
simple_zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}

Utility functionality for constructing `rrule!!`s for functions which produce adjoints which
always return zero. Equivalent to:
```julia
zero_fcodual(primal(f)(map(primal, x)...)), NoPullback(f, x...)
```

WARNING: this is only correct if the output of `primal(f)(map(primal, x)...)` does not alias
anything in `f` or `x`. This is always the case if the result is a bits type, but more care
may be required if it is not.
"""
@inline function simple_zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}
return zero_fcodual(primal(f)(map(primal, x)...)), NoPullback(f, x...)
end

to_fwds(x::CoDual) = CoDual(primal(x), fdata(tangent(x)))

to_fwds(x::CoDual{Type{P}}) where {P} = CoDual{Type{P}, NoFData}(primal(x), NoFData())
Expand Down
34 changes: 24 additions & 10 deletions src/interpreter/abstract_interpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ struct ClosureCacheKey
key::Any
end

const GLOBAL_CLOSURE_CACHE = Dict{ClosureCacheKey, Any}()

struct TICache
dict::IdDict{Core.MethodInstance, Core.CodeInstance}
end
Expand All @@ -35,7 +33,7 @@ struct TapirInterpreter{C} <: CC.AbstractInterpreter
opt_params::CC.OptimizationParams=CC.OptimizationParams(),
inf_cache::Vector{CC.InferenceResult}=CC.InferenceResult[],
code_cache::TICache=TICache(),
oc_cache::Dict{ClosureCacheKey, Any}=GLOBAL_CLOSURE_CACHE,
oc_cache::Dict{ClosureCacheKey, Any}=Dict{ClosureCacheKey, Any}(),
) where {C}
return new{C}(meta, world, inf_params, opt_params, inf_cache, code_cache, oc_cache)
end
Expand All @@ -52,13 +50,29 @@ end

TapirInterpreter() = TapirInterpreter(DefaultCtx)

const PInterp = TapirInterpreter
# Globally cached interpreter. Should only be accessed via `get_tapir_interpreter`.
const GLOBAL_INTERPRETER = Ref(TapirInterpreter())

"""
get_tapir_interpreter()

Returns a `TapirInterpreter` appropriate for the current world age. Will use a cached
interpreter if one already exists for the current world age, otherwise creates a new one.
This is a very conservative approach to caching the interpreter, which reflects the
approach taken the the closure cache.
"""
function get_tapir_interpreter()
if GLOBAL_INTERPRETER[].world != Base.get_world_counter()
GLOBAL_INTERPRETER[] = TapirInterpreter()
end
return GLOBAL_INTERPRETER[]
end

CC.InferenceParams(interp::PInterp) = interp.inf_params
CC.OptimizationParams(interp::PInterp) = interp.opt_params
CC.get_world_counter(interp::PInterp) = interp.world
CC.get_inference_cache(interp::PInterp) = interp.inf_cache
function CC.code_cache(interp::PInterp)
CC.InferenceParams(interp::TapirInterpreter) = interp.inf_params
CC.OptimizationParams(interp::TapirInterpreter) = interp.opt_params
CC.get_world_counter(interp::TapirInterpreter) = interp.world
CC.get_inference_cache(interp::TapirInterpreter) = interp.inf_cache
function CC.code_cache(interp::TapirInterpreter)
return CC.WorldView(interp.code_cache, CC.WorldRange(interp.world))
end
function CC.get(wvc::CC.WorldView{TICache}, mi::Core.MethodInstance, default)
Expand Down Expand Up @@ -103,4 +117,4 @@ function CC.inlining_policy(
)
end

context_type(::PInterp{C}) where {C} = C
context_type(::TapirInterpreter{C}) where {C} = C
85 changes: 72 additions & 13 deletions src/interpreter/bbcode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -612,8 +612,50 @@ function _remove_double_edges(ir::BBCode)
return BBCode(ir, new_blks)
end

"""
_build_graph_of_cfg(blks::Vector{BBlock})::Tuple{SimpleDiGraph, Dict{ID, Int}}

Builds a `SimpleDiGraph`, `g`, representing of the CFG associated to `blks`, where `blks`
comprises the collection of basic blocks associated to a `BBCode`.
This is a type from Graphs.jl, so constructing `g` makes it straightforward to analyse the
control flow structure of `ir` using algorithms from Graphs.jl.

Returns a 2-tuple, whose first element is `g`, and whose second element is a map from
the `ID` associated to each basic block in `ir`, to the `Int` corresponding to its node
index in `g`.
"""
function _build_graph_of_cfg(blks::Vector{BBlock})::Tuple{SimpleDiGraph, Dict{ID, Int}}
node_ints = collect(eachindex(blks))
id_to_int = Dict(zip(map(blk -> blk.id, blks), node_ints))
successors = _compute_all_successors(blks)
g = SimpleDiGraph(length(blks))
for blk in blks, successor in successors[blk.id]
add_edge!(g, id_to_int[blk.id], id_to_int[successor])
end
return g, id_to_int
end

"""
_distance_to_entry(blks::Vector{BBlock})::Vector{Int}

For each basic block in `blks`, compute the distance from it to the entry point (the first
block. The distance is `typemax(Int)` if no path from the entry point to a given node.
"""
function _distance_to_entry(blks::Vector{BBlock})::Vector{Int}
g, id_to_int = _build_graph_of_cfg(blks)
return dijkstra_shortest_paths(g, id_to_int[blks[1].id]).dists
end

"""
_is_reachable(blks::Vector{BBlock})::Vector{Bool}

Computes a `Vector` whose length is `length(blks)`. The `n`th element is `true` iff it is
possible for control flow to reach the `n`th block.
"""
_is_reachable(blks::Vector{BBlock})::Vector{Bool} = _distance_to_entry(blks) .< typemax(Int)

#=
_sort_blocks!(ir::BBCode)
_sort_blocks!(ir::BBCode)::BBCode

Ensure that blocks appear in order of distance-from-entry-point, where distance the
distance from block b to the entry point is defined to be the minimum number of basic
Expand All @@ -627,18 +669,8 @@ WARNING: use with care. Only use if you are confident that arbitrary re-ordering
blocks in `ir` is valid. Notably, this does not hold if you have any `IDGotoIfNot` nodes in
`ir`.
=#
function _sort_blocks!(ir::BBCode)

node_ints = collect(eachindex(ir.blocks))
id_to_int = Dict(zip(map(blk -> blk.id, ir.blocks), node_ints))
ps = compute_all_predecessors(ir)
direct_predecessors = map(ir.blocks) do blk
return map(b -> Edge(id_to_int[b], id_to_int[blk.id]), ps[blk.id])
end
g = SimpleDiGraph(reduce(vcat, direct_predecessors))

d = dijkstra_shortest_paths(g, id_to_int[ir.blocks[1].id]).dists
I = sortperm(d)
function _sort_blocks!(ir::BBCode)::BBCode
I = sortperm(_distance_to_entry(ir.blocks))
ir.blocks .= ir.blocks[I]
return ir
end
Expand Down Expand Up @@ -761,3 +793,30 @@ function _find_id_uses!(d::Dict{ID, Bool}, x::ReturnNode)
end
_find_id_uses!(d::Dict{ID, Bool}, x::QuoteNode) = nothing
_find_id_uses!(d::Dict{ID, Bool}, x) = nothing

"""
remove_unreachable_blocks(ir::BBCode)::BBCode

If a basic block in `ir` cannot possibly be reached during execution, then it can be safely
removed from `ir` without changing its functionality.
A block is unreachable if either:
1. it has no predecessors _and_ it is not the first block, or
2. all of its predecessors are themselves unreachable.

For example, consider the following IR:
```jldoctest remove_unreachable_blocks
julia> ir = Tapir.ircode(
Any[Core.ReturnNode(nothing), Expr(:call, sin, 5), Core.ReturnNode(Core.SSAValue(2))],
Any[Any, Any, Any],
);
```
There is no possible way to reach the second basic block (lines 2 and 3). Applying this
function will therefore remove it, yielding the following:
```jldoctest remove_unreachable_blocks
julia> Tapir.IRCode(Tapir.remove_unreachable_blocks(Tapir.BBCode(ir)))
1 1 ─ return nothing
```
"""
remove_unreachable_blocks(ir::BBCode) = BBCode(ir, _remove_unreachable_blocks(ir.blocks))

_remove_unreachable_blocks(blks::Vector{BBlock}) = blks[_is_reachable(blks)]
43 changes: 29 additions & 14 deletions src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@
associated to this information.
=#
struct ADInfo
interp::PInterp
interp::TapirInterpreter
block_stack_id::ID
block_stack::BlockStack
entry_id::ID
Expand All @@ -136,7 +136,7 @@
# The constructor that you should use for ADInfo if you don't have a BBCode lying around.
# See the definition of the ADInfo struct for info on the arguments.
function ADInfo(
interp::PInterp,
interp::TapirInterpreter,
arg_types::Dict{Argument, Any},
ssa_insts::Dict{ID, NewInstruction},
is_used_dict::Dict{ID, Bool},
Expand All @@ -163,7 +163,7 @@

# The constructor you should use for ADInfo if you _do_ have a BBCode lying around. See the
# ADInfo struct for information regarding `interp` and `safety_on`.
function ADInfo(interp::PInterp, ir::BBCode, safety_on::Bool)
function ADInfo(interp::TapirInterpreter, ir::BBCode, safety_on::Bool)
arg_types = Dict{Argument, Any}(
map(((n, t),) -> (Argument(n) => _type(t)), enumerate(ir.argtypes))
)
Expand Down Expand Up @@ -812,23 +812,33 @@
Helper method. Only uses static information from `args`.
"""
function build_rrule(args...; safety_on=false)
return build_rrule(PInterp(), _typeof(TestUtils.__get_primals(args)); safety_on)
interp = get_tapir_interpreter()
return build_rrule(interp, _typeof(TestUtils.__get_primals(args)); safety_on)
end

const TAPIR_INFERENCE_LOCK = ReentrantLock()

"""
build_rrule(interp::PInterp{C}, sig_or_mi; safety_on=false) where {C}
build_rrule(interp::TapirInterpreter{C}, sig_or_mi; safety_on=false) where {C}

Returns a `DerivedRule` which is an `rrule!!` for `sig_or_mi` in context `C`. See the
docstring for `rrule!!` for more info.

If `safety_on` is `true`, then all calls to rules are replaced with calls to `SafeRRule`s.
"""
function build_rrule(
interp::PInterp{C}, sig_or_mi; safety_on=false, silence_safety_messages=true
interp::TapirInterpreter{C}, sig_or_mi; safety_on=false, silence_safety_messages=true
) where {C}

# To avoid segfaults, ensure that we bail out if the interpreter's world age is greater
# than the current world age.
if Base.get_world_counter() > interp.world
throw(ArgumentError(

Check warning on line 836 in src/interpreter/s2s_reverse_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_reverse_mode_ad.jl#L836

Added line #L836 was not covered by tests
"World age associated to interp is behind current world age. Please " *
"a new interpreter for the current world age."
))
end

# If we're compiling in safe mode, let the user know by default.
if !silence_safety_messages && safety_on
@info "Compiling rule for $sig_or_mi in safe mode. Disable for best performance."
Expand Down Expand Up @@ -1062,14 +1072,15 @@
main_blocks = map(ad_stmts_blocks, enumerate(ir.blocks)) do (blk_id, ad_stmts), (n, blk)
if is_unreachable_return_node(terminator(blk))
rvs_stmts = [(ID(), new_inst(nothing))]
return BBlock(blk_id, rvs_stmts)
else
rvs_stmts = reduce(vcat, [x.rvs for x in reverse(ad_stmts)])
rvs_ad_stmts = reduce(vcat, [x.rvs for x in reverse(ad_stmts)])
pred_ids = vcat(ps[blk.id], n == 1 ? [info.entry_id] : ID[])
tmp = pred_is_unique_pred[blk_id]
additional_stmts, new_blocks = conclude_rvs_block(blk, pred_ids, tmp, info)
rvs_block = BBlock(blk_id, vcat(rvs_ad_stmts, additional_stmts))
return vcat(rvs_block, new_blocks)
end
pred_ids = vcat(ps[blk.id], n == 1 ? [info.entry_id] : ID[])
tmp = pred_is_unique_pred[blk_id]
additional_stmts, new_blocks = conclude_rvs_block(blk, pred_ids, tmp, info)
rvs_block = BBlock(blk_id, vcat(rvs_stmts, additional_stmts))
return vcat(rvs_block, new_blocks)
end
main_blocks = vcat(main_blocks...)

Expand Down Expand Up @@ -1127,9 +1138,13 @@
),
)

# Create and return `BBCode` for the pullback.
# Create and return `BBCode` for the pullback. Sort the blocks and remove any blocks
# which are unreachable, in the sense that they have no predecessors (except the entry
# block). This ought not to be necessary, but _appears_ to be necessary in order to
# avoid annoying the Julia compiler.
blks = vcat(entry_block, main_blocks, exit_block)
return _sort_blocks!(BBCode(blks, arg_types, ir.sptypes, ir.linetable, ir.meta))
pb_ir = BBCode(blks, arg_types, ir.sptypes, ir.linetable, ir.meta)
return remove_unreachable_blocks(_sort_blocks!(pb_ir))
end

#=
Expand Down
Loading
Loading