Prototype implementation of recursion #164
This is really awesome work @willtebbutt! Yeah, it seems like it works perfectly :)

```julia
julia> using Libtask

julia> f(x) = (produce(2x); return 2x)
f (generic function with 1 method)

julia> g(x) = f(f(x))
g (generic function with 1 method)

julia> Libtask.is_primitive(::typeof(f), x) = false

julia> task = Libtask.TapedTask(g, 1)
TapedTask{typeof(g), Tuple{Int64}}(Task (runnable) @0x00007fe88b81eef0, TapedFunction:
* .func => g
* .ir =>
------------------
CodeInfo(
1 ─ %1 = Main.f(x)::Int64
│   %2 = Main.f(%1)::Int64
└── return %2
)
------------------
, (1,), Channel{Any}(0), Channel{Int64}(0), Any[])

julia> consume(task)
2

julia> consume(task)
4

julia> consume(task)
```
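To illustrate the same mechanism with deeper nesting, here is a minimal sketch (hypothetical functions `h` and `m`; it assumes, as in the session above, that overloading `is_primitive` makes the tracer recurse into the callee, forwarding its `produce` calls to the outer task):

```julia
using Libtask

# Hypothetical helper: produces its argument, then returns it incremented.
h(x) = (produce(x); return x + 1)

# Calls h twice; with h marked non-primitive, the tracer should recurse
# into both calls, so each produce is visible to the consumer.
m(x) = h(h(x))

Libtask.is_primitive(::typeof(h), x) = false  # recurse into h

t = Libtask.TapedTask(m, 1)
consume(t)  # expected to yield 1 (produced by the inner h)
consume(t)  # then 2 (produced by the outer h)
```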
Isn't this already there in
I did start with this idea at some point but then disregarded it, because I realized I'd have to call

Also, I guess in the current state of things it probably breaks the
Isn't it annoying/problematic in practice that one has to define
We should only need to do this for

And would we really consider using Cassette as "simplifying things"? 😅
As soon as more intermediate functions (e.g., custom handwritten models) are involved, more definitions of `is_primitive` are needed. And for general non-Turing/DynamicPPL use they should be needed even more commonly, I assume. I'd say an automatic solution using Cassette is definitely simpler for a user than having to work with `is_primitive`. For instance, we also use Cassette-based analysis in SciML to figure out whether differential equations have branches that prevent tape compilation with ReverseDiff:
https://github.com/SciML/SciMLSensitivity.jl/blob/72eed5c4b21d2815dfdbccea97475eaf68f9b6a2/src/hasbranching.jl
https://github.com/SciML/SciMLSensitivity.jl/blob/72eed5c4b21d2815dfdbccea97475eaf68f9b6a2/src/concrete_solve.jl#L31
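For reference, the Cassette pattern being alluded to can be sketched roughly as follows (a hypothetical example, not the actual SciML code; it uses Cassette's documented `@context`/`overdub`/`prehook` API):

```julia
using Cassette

Cassette.@context CallCountCtx

# prehook fires before every call made under overdub, so we can inspect
# (here: count) the calls without modifying the user's function.
Cassette.prehook(ctx::CallCountCtx, f, args...) = (ctx.metadata[] += 1)

counter = Ref(0)
Cassette.overdub(CallCountCtx(metadata = counter), x -> sum(abs2, x), [1.0, 2.0])
@show counter[]
```

An automatic `is_primitive` could, in principle, be derived from this kind of analysis, e.g. by marking a call as non-primitive whenever a `produce` is reachable from it.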
@torfjelde any idea what the test failures are about? I can't tell whether it's something that I've done, or something unrelated.
The test failures in Benchmarks and Microintegration are known and reproducible on the master branch -- they are probably not related to this PR.
I'm trying to fix the benchmarks stuff so that I can check that this PR hasn't caused regressions. I've fixed the first issue, but appear to have encountered another one:

```
ERROR: LoadError: KeyError: key :__trace not found
Stacktrace:
  [1] getindex(d::IdDict{Any, Any}, key::Any)
    @ Base ./iddict.jl:108
  [2] current_trace()
    @ AdvancedPS ~/.julia/packages/AdvancedPS/Vox9w/src/model.jl:55
  [3] assume(rng::MersenneTwister, spl::Sampler{SMC{(), AdvancedPS.ResampleWithESSThreshold{typeof(AdvancedPS.resample_systematic), Int64}}}, dist::InverseGamma{Float64}, vn::VarName{:σ, Setfield.IdentityLens}, __vi__::UntypedVarInfo{DynamicPPL.Metadata{Dict{VarName, Int64}, Vector{Distribution}, Vector{VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64})
    @ Turing.Inference ~/work/Libtask.jl/Libtask.jl/downstream/src/inference/AdvancedSMC.jl:335
  [4] tilde_assume
    @ ~/.julia/packages/DynamicPPL/oJMmE/src/context_implementations.jl:49 [inlined]
  [5] tilde_assume
    @ ~/.julia/packages/DynamicPPL/oJMmE/src/context_implementations.jl:46 [inlined]
  [6] tilde_assume
    @ ~/.julia/packages/DynamicPPL/oJMmE/src/context_implementations.jl:31 [inlined]
  [7] tilde_assume!!(context::SamplingContext{Sampler{SMC{(), AdvancedPS.ResampleWithESSThreshold{typeof(AdvancedPS.resample_systematic), Int64}}}, DefaultContext, MersenneTwister}, right::InverseGamma{Float64}, vn::VarName{:σ, Setfield.IdentityLens}, vi::UntypedVarInfo{DynamicPPL.Metadata{Dict{VarName, Int64}, Vector{Distribution}, Vector{VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/oJMmE/src/context_implementations.jl:117
  [8] gdemo(__model__::Model{typeof(gdemo), (:x, :y), (), (), Tuple{Float64, Float64}, Tuple{}, DefaultContext}, __varinfo__::UntypedVarInfo{DynamicPPL.Metadata{Dict{VarName, Int64}, Vector{Distribution}, Vector{VarName}, Vector{Real}, Vector{Set{DynamicPPL.Selector}}}, Float64}, __context__::SamplingContext{Sampler{SMC{(), AdvancedPS.ResampleWithESSThreshold{typeof(AdvancedPS.resample_systematic), Int64}}}, DefaultContext, MersenneTwister}, x::Float64, y::Float64)
    @ Main ~/work/Libtask.jl/Libtask.jl/perf/p0.jl:7
  [9] var"##core#457"()
    @ Main ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:489
 [10] var"##sample#458"(::Tuple{}, __params::BenchmarkTools.Parameters)
    @ Main ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:495
 [11] _run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; verbose::Bool, pad::String, kwargs::Base.Pairs{Symbol, Integer, NTuple{4, Symbol}, NamedTuple{(:samples, :evals, :gctrial, :gcsample), Tuple{Int64, Int64, Bool, Bool}}})
    @ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:99
 [12] #invokelatest#2
    @ ./essentials.jl:818 [inlined]
 [13] invokelatest
    @ ./essentials.jl:813 [inlined]
 [14] #run_result#45
    @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:34 [inlined]
 [15] run_result
    @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:34 [inlined]
 [16] run(b::BenchmarkTools.Benchmark, p::BenchmarkTools.Parameters; progressid::Nothing, nleaves::Float64, ndone::Float64, kwargs::Base.Pairs{Symbol, Integer, NTuple{5, Symbol}, NamedTuple{(:verbose, :samples, :evals, :gctrial, :gcsample), Tuple{Bool, Int64, Int64, Bool, Bool}}})
    @ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:117
 [17] run (repeats 2 times)
    @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:117 [inlined]
 [18] #warmup#54
    @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:169 [inlined]
 [19] warmup(item::BenchmarkTools.Benchmark)
    @ BenchmarkTools ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:168
 [20] top-level scope
    @ ~/.julia/packages/BenchmarkTools/0owsb/src/execution.jl:575
 [21] include(fname::String)
    @ Base.MainInclude ./client.jl:478
 [22] top-level scope
    @ ~/work/Libtask.jl/Libtask.jl/perf/runtests.jl:2
 [23] include(fname::String)
    @ Base.MainInclude ./client.jl:478
 [24] top-level scope
    @ ~/work/_temp/07fbdc08-5b3e-446e-b3f5-fa8632e46022:7
in expression starting at /home/runner/work/Libtask.jl/Libtask.jl/perf/p0.jl:38
in expression starting at /home/runner/work/Libtask.jl/Libtask.jl/perf/runtests.jl:2
in expression starting at /home/runner/work/_temp/07fbdc08-5b3e-446e-b3f5-fa8632e46022:2
```

I don't really understand the context for this error, as I've not worked with
@willtebbutt It's a leftover issue in Turing. TuringLang/Turing.jl#2057 should fix it.
The performance profiling code should be working now with
@yebai looks like CI is now passing. Thanks for sorting that. Do you have thoughts on the kinds of additional benchmarks that we should add?
Do we want to benchmark the overheads of sub-tapes?
Benchmark results for subtape:

```
benchmarking neural_net...
Run Original Function: 571.228 ns (4 allocations: 576 bytes)
Run TapedFunction: 1.530 μs (6 allocations: 608 bytes)
Run TapedFunction (compiled): 1.170 μs (14 allocations: 864 bytes)
Run TapedTask: #produce=1; 328.152 μs (99 allocations: 9.56 KiB)
benchmarking neural_net...
Run Original Function: 554.124 ns (4 allocations: 576 bytes)
Run TapedFunction: 12.464 μs (52 allocations: 2.66 KiB)
Run TapedFunction (compiled): 11.418 μs (58 allocations: 2.84 KiB)
Run TapedTask: #produce=1; 321.877 μs (99 allocations: 9.56 KiB)
```

Both functions do the same thing; the latter just makes use of the sub-tape mechanism (see the update to the benchmarks for details). There's a moderate amount of overhead introduced via the primitive mechanism. I don't know how much we care. I also don't know if there's a good way to reduce it.
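For context, this kind of comparison can be reproduced with BenchmarkTools along the following lines (a sketch only: `layer` and `neural_net` here are stand-ins for the benchmark's model, and `Libtask.TapedFunction(f, args...)` is assumed to be the tracing entry point):

```julia
using BenchmarkTools, Libtask

layer(x) = tanh.(x)
neural_net(x) = layer(layer(layer(x)))

# Force layer to be traced via sub-tapes rather than as a single primitive.
Libtask.is_primitive(::typeof(layer), x) = false

x = rand(16)
@btime neural_net($x)                      # baseline
tf = Libtask.TapedFunction(neural_net, x)  # assumed constructor
@btime $tf($x)                             # taped, exercising sub-tapes
```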
We can come back to optimise it if it causes pain. I suggest that we merge this branch into
A basic implementation of recursion against which tweaks to Turing's submodel interface can be tested.
TODO:

- add context variable to enable different primitives in different situations (`is_primitive`)
- cache each `TapedFunction`, based on `func` and argument types, to avoid re-tracing each time

Maybe TODO:
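The caching item could be sketched as below (hypothetical names throughout; the real implementation would live inside Libtask, and `Libtask.TapedFunction(f, args...)` is assumed to be the tracing entry point):

```julia
# Hypothetical memoisation of traced functions, keyed on the function and
# its argument types, so repeated construction avoids re-tracing.
const TAPE_CACHE = IdDict{Any,Any}()

function cached_taped_function(f, args...)
    key = (f, typeof.(args)...)
    get!(TAPE_CACHE, key) do
        Libtask.TapedFunction(f, args...)  # assumed constructor
    end
end
```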
@torfjelde when you get a minute, could you please check to see whether this does what you need it to do?