From 1d07f382bd9ccfa5b702ecef4542032dfe5e7178 Mon Sep 17 00:00:00 2001
From: Will Tebbutt
Date: Wed, 2 Aug 2023 18:53:11 +0100
Subject: [PATCH] Prototype implementation of recursion (#164)

* Prototype implementation

* Bump patch version

* Improve docstring

* Use rng

* Benchmark + cache
---
 Project.toml         |  2 +-
 perf/benchmark.jl    | 16 ++++++++++++++++
 perf/p0.jl           | 10 +++++++---
 perf/p2.jl           |  3 ++-
 src/tapedfunction.jl | 31 ++++++++++++++++++++++++-------
 test/tf.jl           | 11 +++++++++++
 6 files changed, 61 insertions(+), 12 deletions(-)

diff --git a/Project.toml b/Project.toml
index cd87c256..bc778669 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
 license = "MIT"
 desc = "Tape based task copying in Turing"
 repo = "https://github.com/TuringLang/Libtask.jl.git"
-version = "0.8.6"
+version = "0.8.7"
 
 [deps]
 FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
diff --git a/perf/benchmark.jl b/perf/benchmark.jl
index 2b2312c0..25df7e93 100644
--- a/perf/benchmark.jl
+++ b/perf/benchmark.jl
@@ -109,6 +109,22 @@ benchmark_driver!(neural_net, xs...)
 
 ####################################################################
 
+layer(w, x) = relu(w * x)
+Libtask.is_primitive(::typeof(layer), w, x) = false
+
+function neural_net(w1, w2, w3, x1, callback=nothing)
+    x2 = layer(w1, x1)
+    x3 = layer(w2, x2)
+    ret = sigmoid(LinearAlgebra.dot(w3, x3))
+    callback !== nothing && callback(ret)
+    return ret
+end
+
+xs = (randn(10,10), randn(10,10), randn(10), rand(10))
+benchmark_driver!(neural_net, xs...)
+
+####################################################################
+
 println("======= breakdown benchmark =======")
 
 x = rand(100000)
diff --git a/perf/p0.jl b/perf/p0.jl
index 81ce7fd2..21bfb535 100644
--- a/perf/p0.jl
+++ b/perf/p0.jl
@@ -1,5 +1,5 @@
 using Libtask
-using Turing, DynamicPPL, AdvancedPS
+using Turing, DynamicPPL, AdvancedPS, Random
 using BenchmarkTools
 
 @model gdemo(x, y) = begin
@@ -13,7 +13,9 @@ end
 
 # Case 1: Sample from the prior.
 
-m = Turing.Core.TracedModel(gdemo(1.5, 2.), SampleFromPrior(), VarInfo())
+m = Turing.Core.TracedModel(
+    gdemo(1.5, 2.), SampleFromPrior(), VarInfo(), MersenneTwister(123456)
+);
 f = m.evaluator[1];
 args = m.evaluator[2:end];
 
@@ -26,7 +28,9 @@ println("Run a tape...")
 @btime t.tf(args...)
 
 # Case 2: SMC sampler
-m = Turing.Core.TracedModel(gdemo(1.5, 2.), Sampler(SMC(50)), VarInfo());
+m = Turing.Core.TracedModel(
+    gdemo(1.5, 2.), Sampler(SMC(50)), VarInfo(), MersenneTwister(123456)
+);
 f = m.evaluator[1];
 args = m.evaluator[2:end];
diff --git a/perf/p2.jl b/perf/p2.jl
index a95b3d3c..f3d18493 100644
--- a/perf/p2.jl
+++ b/perf/p2.jl
@@ -52,7 +52,8 @@ Random.seed!(2)
 iterations = 500
 model_fun = infiniteGMM(data)
 
-m = Turing.Core.TracedModel(model_fun, Sampler(SMC(50)), VarInfo())
+rng = MersenneTwister(123456)
+m = Turing.Core.TracedModel(model_fun, Sampler(SMC(50)), VarInfo(), rng)
 
 f = m.evaluator[1]
 args = m.evaluator[2:end]
diff --git a/src/tapedfunction.jl b/src/tapedfunction.jl
index 2d947aaf..7660b90b 100644
--- a/src/tapedfunction.jl
+++ b/src/tapedfunction.jl
@@ -162,7 +162,7 @@ function (tf::TapedFunction)(args...; callback=nothing, continuation=false)
     # run the raw tape
     while true
         ins = tf.tape[tf.counter]
-        ins(tf)
+        ins(tf, callback)
         callback !== nothing && callback()
         tf.retval_binding_slot != 0 && break
     end
@@ -211,12 +211,17 @@ function Base.show(io::IO, instr::CondGotoInstruction)
     println(io, "CondGotoInstruction(", instr.condition, ", dest=", instr.dest, ")")
 end
 
-function (instr::Instruction{F})(tf::TapedFunction) where F
+function (instr::Instruction{F})(tf::TapedFunction, callback=nothing) where F
     # catch run-time exceptions / errors.
     try
         func = F === Int ? _lookup(tf, instr.func) : instr.func
         inputs = map(x -> _lookup(tf, x), instr.input)
-        output = func(inputs...)
+        output = if is_primitive(func, inputs...)
+            func(inputs...)
+        else
+            tf_inner = TapedFunction(func, inputs..., cache=true)
+            tf_inner(inputs...; callback=callback)
+        end
         _update_var!(tf, instr.output, output)
         tf.counter += 1
     catch e
@@ -227,11 +232,11 @@
     end
 end
 
-function (instr::GotoInstruction)(tf::TapedFunction)
+function (instr::GotoInstruction)(tf::TapedFunction, callback=nothing)
     tf.counter = instr.dest
 end
 
-function (instr::CondGotoInstruction)(tf::TapedFunction)
+function (instr::CondGotoInstruction)(tf::TapedFunction, callback=nothing)
     cond = _lookup(tf, instr.condition)
     if cond
         tf.counter += 1
@@ -240,11 +245,11 @@
     else
         tf.counter = instr.dest
     end
 end
 
-function (instr::ReturnInstruction)(tf::TapedFunction)
+function (instr::ReturnInstruction)(tf::TapedFunction, callback=nothing)
     tf.retval_binding_slot = instr.arg
 end
-function (instr::NOOPInstruction)(tf::TapedFunction)
+function (instr::NOOPInstruction)(tf::TapedFunction, callback=nothing)
     tf.counter += 1
 end
@@ -450,6 +455,18 @@
     throw(ErrorException("Unknown IR code"))
 end
 
+## primitives.
+
+"""
+    is_primitive(f, args...)
+
+Determine whether a function should be treated as a single instruction or be recursed
+into when encountered inside a `TapedFunction`. If `is_primitive(f, args...)` is `true`,
+the call is recorded as a single instruction and not traced into. Conversely, if it is
+`false`, a nested `TapedFunction` is constructed for the call and run in its place.
+"""
+is_primitive(f, args...) = true
+
 ## copy Bindings, TapedFunction
 
 """
diff --git a/test/tf.jl b/test/tf.jl
index f1fd87c5..28811cf3 100644
--- a/test/tf.jl
+++ b/test/tf.jl
@@ -1,5 +1,10 @@
 using Libtask
 
+foo(x) = sin(cos(x))
+bar(x) = foo(foo(x))
+
+Libtask.is_primitive(::typeof(foo), args...) = false
+
 @testset "tapedfunction" begin
     # Test case 1: stack allocated objects are deep copied.
@testset "Instruction{typeof(__new__)}" begin @@ -31,4 +36,10 @@ using Libtask @test typeof(r) === Float64 end + @testset "recurse into function" begin + tf = Libtask.TapedFunction(bar, 5.0) + count = 0 + tf(4.0; callback=() -> (count += 1)) + @test count == 9 + end end