Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prototype implementation of recursion #167

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
license = "MIT"
desc = "Tape based task copying in Turing"
repo = "https://github.com/TuringLang/Libtask.jl.git"
version = "0.8.6"
version = "0.8.7"

[deps]
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
Expand Down
16 changes: 16 additions & 0 deletions perf/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,22 @@ benchmark_driver!(neural_net, xs...)

####################################################################

# One dense layer: affine map `w * x` followed by a ReLU nonlinearity.
function layer(w, x)
    return relu(w * x)
end

# Declare `layer` non-primitive so Libtask traces into it instead of
# recording the call as a single opaque instruction (exercises recursion).
function Libtask.is_primitive(::typeof(layer), w, x)
    return false
end

"""
    neural_net(w1, w2, w3, x1, callback=nothing)

Run a small three-layer network: two `layer` applications followed by a
`sigmoid` of the dot product with `w3`. If `callback` is given (non-`nothing`),
it is invoked with the scalar result before it is returned.
"""
function neural_net(w1, w2, w3, x1, callback=nothing)
    hidden1 = layer(w1, x1)
    hidden2 = layer(w2, hidden1)
    result = sigmoid(LinearAlgebra.dot(w3, hidden2))
    if callback !== nothing
        callback(result)
    end
    return result
end

# Inputs for the recursive-tracing benchmark: two 10×10 hidden-layer weight
# matrices, a length-10 output weight vector, and a random length-10 input.
xs = (randn(10,10), randn(10,10), randn(10), rand(10))
benchmark_driver!(neural_net, xs...)

####################################################################

println("======= breakdown benchmark =======")

x = rand(100000)
Expand Down
10 changes: 7 additions & 3 deletions perf/p0.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Libtask
using Turing, DynamicPPL, AdvancedPS
using Turing, DynamicPPL, AdvancedPS, Random
using BenchmarkTools

@model gdemo(x, y) = begin
Expand All @@ -13,7 +13,9 @@ end


# Case 1: Sample from the prior.
m = Turing.Core.TracedModel(gdemo(1.5, 2.), SampleFromPrior(), VarInfo())
m = Turing.Core.TracedModel(
gdemo(1.5, 2.), SampleFromPrior(), VarInfo(), MersenneTwister(123456)
);
f = m.evaluator[1];
args = m.evaluator[2:end];

Expand All @@ -26,7 +28,9 @@ println("Run a tape...")
@btime t.tf(args...)

# Case 2: SMC sampler
m = Turing.Core.TracedModel(gdemo(1.5, 2.), Sampler(SMC(50)), VarInfo());
m = Turing.Core.TracedModel(
gdemo(1.5, 2.), Sampler(SMC(50)), VarInfo(), MersenneTwister(123456)
);
f = m.evaluator[1];
args = m.evaluator[2:end];

Expand Down
3 changes: 2 additions & 1 deletion perf/p2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ Random.seed!(2)
iterations = 500
model_fun = infiniteGMM(data)

m = Turing.Core.TracedModel(model_fun, Sampler(SMC(50)), VarInfo())
rng = MersenneTwister(123456)
m = Turing.Core.TracedModel(model_fun, Sampler(SMC(50)), VarInfo(), rng)
f = m.evaluator[1]
args = m.evaluator[2:end]

Expand Down
31 changes: 24 additions & 7 deletions src/tapedfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ function (tf::TapedFunction)(args...; callback=nothing, continuation=false)
# run the raw tape
while true
ins = tf.tape[tf.counter]
ins(tf)
ins(tf, callback)
callback !== nothing && callback()
tf.retval_binding_slot != 0 && break
end
Expand Down Expand Up @@ -211,12 +211,17 @@ function Base.show(io::IO, instr::CondGotoInstruction)
println(io, "CondGotoInstruction(", instr.condition, ", dest=", instr.dest, ")")
end

function (instr::Instruction{F})(tf::TapedFunction) where F
function (instr::Instruction{F})(tf::TapedFunction, callback=nothing) where F
# catch run-time exceptions / errors.
try
func = F === Int ? _lookup(tf, instr.func) : instr.func
inputs = map(x -> _lookup(tf, x), instr.input)
output = func(inputs...)
output = if is_primitive(func, inputs...)
func(inputs...)
else
tf_inner = TapedFunction(func, inputs..., cache=true)
tf_inner(inputs...; callback=callback)
end
_update_var!(tf, instr.output, output)
tf.counter += 1
catch e
Expand All @@ -227,11 +232,11 @@ function (instr::Instruction{F})(tf::TapedFunction) where F
end
end

"""
    (instr::GotoInstruction)(tf::TapedFunction, callback=nothing)

Perform an unconditional jump by pointing the tape counter of `tf` at the
instruction's destination slot. `callback` is accepted only so all
instruction functors share one calling convention; it is ignored here.
"""
function (instr::GotoInstruction)(tf::TapedFunction, callback=nothing)
    tf.counter = instr.dest
end

function (instr::CondGotoInstruction)(tf::TapedFunction)
function (instr::CondGotoInstruction)(tf::TapedFunction, callback=nothing)
cond = _lookup(tf, instr.condition)
if cond
tf.counter += 1
Expand All @@ -240,11 +245,11 @@ function (instr::CondGotoInstruction)(tf::TapedFunction)
end
end

"""
    (instr::ReturnInstruction)(tf::TapedFunction, callback=nothing)

Record the slot holding the return value in `tf.retval_binding_slot`.
A nonzero slot also signals the run loop to stop executing the tape.
`callback` is ignored; it exists for interface uniformity.
"""
function (instr::ReturnInstruction)(tf::TapedFunction, callback=nothing)
    tf.retval_binding_slot = instr.arg
end

"""
    (instr::NOOPInstruction)(tf::TapedFunction, callback=nothing)

Do nothing except advance the tape counter to the next instruction.
`callback` is ignored; it exists for interface uniformity.
"""
function (instr::NOOPInstruction)(tf::TapedFunction, callback=nothing)
    tf.counter = tf.counter + 1
end

Expand Down Expand Up @@ -450,6 +455,18 @@ function translate!!(var, line, bindings, isconst, ir)
throw(ErrorException("Unknown IR code"))
end

## primitives.

"""
is_primitive(f, args...)

Should a function be recursed into, or should it be treated as a single instruction, when
encountered inside of a `TapedFunction`. If `is_primitive(f, args...)` is `true`, then
the instruction will not be traced into. Conversely, if `is_primitive(f, args...)` is
`false`, a `TapedFunction` is constructed.
"""
is_primitive(f, args...) = true

## copy Bindings, TapedFunction

"""
Expand Down
11 changes: 11 additions & 0 deletions test/tf.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
using Libtask

# Helper callables for the recursion tests below.
function foo(x)
    return sin(cos(x))
end

function bar(x)
    return foo(foo(x))
end

# Mark `foo` as non-primitive so tracing `bar` recurses into each `foo` call.
function Libtask.is_primitive(::typeof(foo), args...)
    return false
end

@testset "tapedfunction" begin
# Test case 1: stack allocated objects are deep copied.
@testset "Instruction{typeof(__new__)}" begin
Expand Down Expand Up @@ -31,4 +36,10 @@ using Libtask

@test typeof(r) === Float64
end
@testset "recurse into function" begin
    taped = Libtask.TapedFunction(bar, 5.0)
    n_calls = 0
    taped(4.0; callback=() -> (n_calls += 1))
    # bar(x) = foo(foo(x)) with foo marked non-primitive: tracing into the
    # nested foo calls yields 9 callback invocations in total.
    @test n_calls == 9
end
end
Loading