
Prototype implementation of recursion #164

Merged
6 commits merged on Aug 2, 2023

Conversation

willtebbutt
Member

@willtebbutt willtebbutt commented Jun 21, 2023

A basic implementation of recursion against which tweaks to Turing's submodel interface can be tested.

TODO:

  • add context variable to enable different primitives in different situations
  • improve documentation for is_primitive
  • bump version
  • implement cache for TapedFunction based on func and argument types to avoid re-tracing each time

Maybe TODO:

  • find a better way of testing

@torfjelde when you get a minute, could you please check to see whether this does what you need it to do?

@torfjelde
Member

This is really awesome work @willtebbutt! Yeah, it seems like it works perfectly :)

julia> using Libtask

julia> f(x) = (produce(2x); return 2x)
f (generic function with 1 method)

julia> g(x) = f(f(x))
g (generic function with 1 method)

julia> Libtask.is_primitive(::typeof(f), x) = false

julia> task = Libtask.TapedTask(g, 1)
TapedTask{typeof(g), Tuple{Int64}}(Task (runnable) @0x00007fe88b81eef0, TapedFunction:
* .func => g
* .ir   =>
------------------
CodeInfo(
1 ─ %1 = Main.f(x)::Int64
│   %2 = Main.f(%1)::Int64
└──      return %2
)
------------------
, (1,), Channel{Any}(0), Channel{Int64}(0), Any[])

julia> consume(task)
2

julia> consume(task)
4

julia> consume(task)

@torfjelde
Member

implement cache for TapedFunction based on func and argument types to avoid re-tracing each time

Isn't this already there in Libtask.TRCache?
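If it isn't, a minimal sketch of the kind of cache being discussed (keyed on function type and argument types so re-tracing only happens the first time a signature is seen) might look like the following. The names `TRACE_CACHE` and `cached_trace` are hypothetical, and `trace` stands in for whatever actually constructs a `TapedFunction`; none of this is the real `Libtask.TRCache` API.

```julia
# Hypothetical sketch: cache traced functions by (function type, argument
# types). `trace` is a stand-in for the TapedFunction constructor.
const TRACE_CACHE = IdDict{Any,Any}()

function cached_trace(trace, f, args...)
    key = (typeof(f), typeof(args))
    # Only invoke `trace` the first time this type signature is seen.
    return get!(() -> trace(f, args...), TRACE_CACHE, key)
end
```

Because the key is type-based, calling with different values of the same types reuses the cached trace.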

@torfjelde
Member

torfjelde commented Jun 21, 2023

I did start with this idea at some point but then disregarded it because I realized I'd have to call TapedFunction(...) every time, but maybe you're right that we can just deal with this by caching the methods. No matter, it's still waaaay more preferable that we have correctness vs. performance :) I also didn't realize how simple adding the support in this way would be 👍

Also, I guess in the current state of things it probably breaks the Libtask.compile (but it's unclear to me how helpful the current version of this is anyways, since it doesn't specialize on the instruction).

@devmotion
Member

Isn't it annoying/problematic in practice that one has to define is_primitive for all methods manually? Could we use e.g. Cassette to figure out if a function contains produce statements and has to be recursed into?

@torfjelde
Member

We should only need to do this for _evaluate!! and potentially a few other methods; otherwise this shouldn't be necessary.

And would we really consider using Cassette as "simplifying things"? 😅

@devmotion
Member

As soon as more intermediate functions (e.g., custom handwritten models) are involved, more definitions of is_primitive are needed. And for general non-Turing/DynamicPPL use, they should be needed even more commonly, I assume.

I'd say an automatic solution using Cassette is definitely simpler for a user than having to work with is_primitive. For instance, we also use Cassette-based analysis in SciML to figure out if differential equations have branches that prevent tape compilation with ReverseDiff: https://github.com/SciML/SciMLSensitivity.jl/blob/72eed5c4b21d2815dfdbccea97475eaf68f9b6a2/src/hasbranching.jl https://github.com/SciML/SciMLSensitivity.jl/blob/72eed5c4b21d2815dfdbccea97475eaf68f9b6a2/src/concrete_solve.jl#L31
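To make the idea concrete, here is a hedged sketch (not part of this PR, and much simpler than what Cassette does) of deciding primitiveness automatically by scanning a function's lowered code for calls to `produce`. A Cassette-based version would perform this check during overdubbing and could recurse into callees; this static check deliberately does not, and the function name `might_produce` is hypothetical.

```julia
# Hypothetical sketch: a function "might produce" if its lowered code
# contains a direct call to `produce`. This does NOT recurse into callees,
# which is the part a Cassette-style analysis would handle.
function might_produce(f, argtypes::Type{<:Tuple})
    for ci in code_lowered(f, argtypes)
        for stmt in ci.code
            stmt isa Expr && stmt.head === :call || continue
            callee = stmt.args[1]
            if (callee isa GlobalRef && callee.name === :produce) || callee === :produce
                return true
            end
        end
    end
    return false
end

# One could then default to, e.g.:
# is_primitive(f, args...) = !might_produce(f, Tuple{map(typeof, args)...})
```

Whether this is simpler for users than writing `is_primitive` overloads by hand is exactly the trade-off being debated here.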

@yebai
Member

yebai commented Jun 22, 2023

As soon as more intermediate functions (e.g., custom handwritten models) are involved, more definitions of is_primitive are needed. And for general non-Turing/DynamicPPL-use they should be needed even more commonly, I assume.

Libtask is intended to be transparent to end users, so we don't anticipate users will overload is_primitive. Only DynamicPPL developers need to overload this function for models/submodel evaluation functions. So keeping the code here lightweight is preferable to reduce maintenance burdens.
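For illustration, the only kind of overload anticipated here would live in DynamicPPL itself, along the lines of the `_evaluate!!` method mentioned above. This fragment is a sketch only: the exact method and signature are illustrative, not verified against DynamicPPL.

```julia
# Hypothetical sketch: mark DynamicPPL's model-evaluation entry point as
# non-primitive so that Libtask recurses into it and reaches any `produce`
# statements inside the model. Signature is illustrative.
Libtask.is_primitive(::typeof(DynamicPPL._evaluate!!), args...) = false
```

End users of Turing would never need to write such a definition themselves.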

@willtebbutt
Member Author

@torfjelde any idea what the test failures are about? I can't tell whether it's something that I've done or something unrelated.

@yebai
Member

yebai commented Jul 3, 2023

The test failures in Benchmarks and Microintegration are known and reproducible on the master branch -- they are probably not related to this PR.

@willtebbutt
Member Author

I'm trying to fix the benchmarks stuff so that I can check that this PR hasn't caused regressions. I've fixed the first issue, but appear to have encountered another one.

ERROR: LoadError: KeyError: key :__trace not found
Stacktrace:
  [1] getindex(d::IdDict{Any, Any}, key::Any)
    @ Base ./iddict.jl:108
  [2] current_trace()
    @ AdvancedPS ~/.julia/packages/AdvancedPS/Vox9w/src/model.jl:55
  [3] assume(rng::MersenneTwister, spl::Sampler{SMC{(), AdvancedPS.ResampleWithESSThreshold{typeof(AdvancedPS.resample_systematic), Int64}}}, dist::InverseGamma{Float64}, vn::VarName{:σ, Setfield.IdentityLens}, __vi__::UntypedVarInfo{DynamicPPL.Metadata{Dict{VarName, Int64}, Vector{Distribution}, Vector{VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64})
    @ Turing.Inference ~/work/Libtask.jl/Libtask.jl/downstream/src/inference/AdvancedSMC.jl:335
  [4] tilde_assume
    @ ~/.julia/packages/DynamicPPL/oJMmE/src/context_implementations.jl:49 [inlined]
  [5] tilde_assume
    @ ~/.julia/packages/DynamicPPL/oJMmE/src/context_implementations.jl:46 [inlined]
  [6] tilde_assume
    @ ~/.julia/packages/DynamicPPL/oJMmE/src/context_implementations.jl:31 [inlined]
  [7] tilde_assume!!(context::SamplingContext{Sampler{SMC{(), AdvancedPS.ResampleWithESSThreshold{typeof(AdvancedPS.resample_systematic), Int64}}}, DefaultContext, MersenneTwister}, right::InverseGamma{Float64}, vn::VarName{:σ, Setfield.IdentityLens}, vi::UntypedVarInfo{DynamicPPL.Metadata{Dict{VarName, Int64}, Vector{Distribution}, Vector{VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/oJMmE/src/context_implementations.jl:117
  [8] gdemo(__model__::Model{typeof(gdemo), (:x, :y), (), (), Tuple{Float64, Float64}, Tuple{}, DefaultContext}, __varinfo__::UntypedVarInfo{DynamicPPL.Metadata{Dict{VarName, Int64}, Vector{Distribution}, Vector{VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64}, __context__::SamplingContext{Sampler{SMC{(), AdvancedPS.ResampleWithESSThreshold{typeof(AdvancedPS.resample_systematic), Int64}}}, DefaultContext, MersenneTwister}, x::Float64, y::Float64)
    @ Main ~/work/Libtask.jl/Libtask.jl/perf/p0.jl:7
  [9] var"##core#457"()
    @ Main ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:489
 [10] var"##sample#458"(::Tuple{}, __params::BenchmarkTools.Parameters)
    @ Main ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:495
 [11] _run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; verbose::Bool, pad::String, kwargs::Base.Pairs{Symbol, Integer, NTuple{4, Symbol}, NamedTuple{(:samples, :evals, :gctrial, :gcsample), Tuple{Int64, Int64, Bool, Bool}}})
    @ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:99
 [12] #invokelatest#2
    @ ./essentials.jl:818 [inlined]
 [13] invokelatest
    @ ./essentials.jl:813 [inlined]
 [14] #run_result#45
    @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:34 [inlined]
 [15] run_result
    @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:34 [inlined]
 [16] run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; progressid::Nothing, nleaves::Float64, ndone::Float64, kwargs::Base.Pairs{Symbol, Integer, NTuple{5, Symbol}, NamedTuple{(:verbose, :samples, :evals, :gctrial, :gcsample), Tuple{Bool, Int64, Int64, Bool, Bool}}})
    @ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:117
 [17] run (repeats 2 times)
    @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:117 [inlined]
 [18] #warmup#54
    @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:169 [inlined]
 [19] warmup(item::BenchmarkTools.Benchmark)
    @ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:168
 [20] top-level scope
    @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:575
 [21] include(fname::String)
    @ Base.MainInclude ./client.jl:478
 [22] top-level scope
    @ ~/work/Libtask.jl/Libtask.jl/perf/runtests.jl:2
 [23] include(fname::String)
    @ Base.MainInclude ./client.jl:478
 [24] top-level scope
    @ ~/work/_temp/07fbdc08-5b3e-446e-b3f5-fa8632e46022:7
in expression starting at /home/runner/work/Libtask.jl/Libtask.jl/perf/p0.jl:38
in expression starting at /home/runner/work/Libtask.jl/Libtask.jl/perf/runtests.jl:2
in expression starting at /home/runner/work/_temp/07fbdc08-5b3e-446e-b3f5-fa8632e46022:2

I don't really understand the context for this error, as I've not worked with AdvancedPS. @yebai @torfjelde do you have any idea why this would occur?

@yebai
Member

yebai commented Jul 31, 2023

@willtebbutt It's a leftover issue in Turing. TuringLang/Turing.jl#2057 should fix it.

@yebai
Member

yebai commented Jul 31, 2023

The performance profiling code should be working now with Turing@v0.28.1.

@willtebbutt
Member Author

@yebai looks like CI is now passing. Thanks for sorting that.

Do you have thoughts on the kinds of additional benchmarks that we should add?

@yebai
Member

yebai commented Aug 1, 2023

Do we want to benchmark the overheads of sub-tapes?

@willtebbutt willtebbutt closed this Aug 2, 2023
@willtebbutt willtebbutt reopened this Aug 2, 2023
@willtebbutt
Member Author

willtebbutt commented Aug 2, 2023

Benchmark results for subtape:

benchmarking neural_net...
  Run Original Function:  571.228 ns (4 allocations: 576 bytes)
  Run TapedFunction:  1.530 μs (6 allocations: 608 bytes)
  Run TapedFunction (compiled):  1.170 μs (14 allocations: 864 bytes)
  Run TapedTask: #produce=1;   328.152 μs (99 allocations: 9.56 KiB)
benchmarking neural_net...
  Run Original Function:  554.124 ns (4 allocations: 576 bytes)
  Run TapedFunction:  12.464 μs (52 allocations: 2.66 KiB)
  Run TapedFunction (compiled):  11.418 μs (58 allocations: 2.84 KiB)
  Run TapedTask: #produce=1;   321.877 μs (99 allocations: 9.56 KiB)

Both functions do the same thing; the latter just makes use of the sub-tape mechanism (see the update to the benchmarks for details).

There's a moderate amount of overhead introduced via the primitive mechanism. I don't know how much we care. I also don't know if there's a good way to reduce it.

@yebai yebai changed the base branch from master to subtape August 2, 2023 17:46
@yebai
Member

yebai commented Aug 2, 2023

We can come back to optimise it if it causes pain.

I suggest that we merge this branch into the subtape branch (inside this repo). Then @torfjelde's branch can be merged into subtape so we can test more before merging into master. In the future, please feel free to create branches inside the TuringLang repos so others can more easily work with them.

@willtebbutt willtebbutt merged commit 1d07f38 into TuringLang:subtape Aug 2, 2023
36 checks passed
@willtebbutt willtebbutt deleted the wct/recursion branch August 2, 2023 17:53
4 participants