Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(Ready for review): Switch combinator #334

Merged
merged 31 commits into from
Dec 8, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
3e4f695
Initial work on a Switch combinator.
femtomc Nov 17, 2020
bd4f830
Initial implementation of propose and generate.
femtomc Nov 17, 2020
374a7b0
Added implementaton of simulate.
femtomc Nov 17, 2020
5872593
Corrected some bugs with Bernoulli vs bernoulli.
femtomc Nov 17, 2020
9c0a9f2
Added assess implementation.
femtomc Nov 17, 2020
95baf07
Split into two combinators: Switch and WithProbability implementations.
femtomc Nov 18, 2020
29b7797
Working on Switch update and regenerate.
femtomc Nov 18, 2020
3e6e307
Added Switch update and regenerate.
femtomc Nov 18, 2020
7929b86
Added Switch update and regenerate - working out kinks in update.
femtomc Nov 18, 2020
73618a1
update and regenerate appear to be computing the correct ratios. To c…
femtomc Nov 18, 2020
252413f
Fixed generate index type bug.
femtomc Nov 18, 2020
ac3528e
Branch dispatch done using diff types.
femtomc Nov 18, 2020
eaf3327
Branch dispatch done using diff types.
femtomc Nov 18, 2020
6d58aac
Branch dispatch done using diff types.
femtomc Nov 18, 2020
e413e9c
Added custom methods in update for Switch which allow the merging of …
femtomc Nov 18, 2020
435493f
Added custom methods in update for Switch which allow the merging of …
femtomc Nov 18, 2020
32fec4f
Idiomatic check for EmptyChoiceMap.
femtomc Nov 18, 2020
bb767e7
Working on backprop - seems simple? Could it really be?
femtomc Nov 18, 2020
a35e2e7
Extracting WithProb combinator into another PR.
femtomc Nov 18, 2020
562667e
Testing backprop.
femtomc Nov 19, 2020
b74a071
Fixed backprop - was thinking in Zygote lang. Gradients appear to be …
femtomc Nov 19, 2020
915811d
Merge branch 'master' of https://github.com/probcomp/Gen.jl into 2020…
femtomc Nov 19, 2020
849d61e
Added docstring and docs example.
femtomc Nov 19, 2020
adf73a5
Fixed numerous bugs uncovered while constructing test suite. One seri…
femtomc Nov 19, 2020
dfe0125
Fixed numerous bugs uncovered while constructing test suite. One seri…
femtomc Nov 20, 2020
3717d65
Tests for everything but gradients - working on gradients now.
femtomc Nov 20, 2020
cb62fb5
Last tests I need to write: accumulate_param_gradients!
femtomc Nov 20, 2020
97473d0
Added accumulate_param_gradients! tests.
femtomc Nov 20, 2020
176b9e9
Reverted particle filter fix - will be handled in another issue.
femtomc Nov 20, 2020
0465965
Renamed mix field of Switch generative function to branches to more a…
femtomc Nov 22, 2020
43c7274
Addressed review comments. Added docstrings where necessary. Correcte…
femtomc Dec 5, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions scratch/switch_comb.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
module SwitchComb

include("../src/Gen.jl")
using .Gen

@gen (grad) function foo((grad)(x::Float64), (grad)(y::Float64))
std::Float64 = 3.0
z = @trace(normal(x + y, std), :z)
return z
end

@gen (grad) function baz((grad)(x::Float64), (grad)(y::Float64))
std::Float64 = 3.0
z = @trace(normal(x + 2 * y, std), :z)
return z
end

sc = Switch(foo, baz)
chm, _, _ = propose(sc, (0.3, 5.0, 3.0))
display(chm)

tr = simulate(sc, (0.3, 5.0, 3.0))
display(get_choices(tr))

chm = choicemap()
chm[:cond] = true
tr, _ = generate(sc, (0.3, 5.0, 3.0), chm)
display(get_choices(tr))

end # module
4 changes: 4 additions & 0 deletions src/modeling_library/modeling_library.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,16 @@ include("dist_dsl/dist_dsl.jl")
# code shared by vector-shaped combinators
include("vector.jl")

# trace for switch combinator
include("switch/trace.jl")

# built-in generative function combinators
include("choice_at/choice_at.jl")
include("call_at/call_at.jl")
include("map/map.jl")
include("unfold/unfold.jl")
include("recurse/recurse.jl")
include("switch/switch.jl")

#############################################################
# abstractions for constructing custom generative functions #
Expand Down
26 changes: 26 additions & 0 deletions src/modeling_library/switch/assess.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
mutable struct SwitchAssessState{T}
weight::Float64
retval::T
end

function process_new!(gen_fn::Switch{T1, T2, Tr},
branch_p::Float64,
args::Tuple,
choices::ChoiceMap,
state::SwitchAssessState{Union{T1, T2}}) where {T1, T2, Tr}
flip = get_value(choices, :cond)
state.weight += logpdf(Bernoulli(), flip, branch_p)
submap = get_submap(choices, :branch)
(weight, retval) = assess(gen_fn.kernel, kernel_args, submap)
state.weight += weight
state.retval = retval
end

function assess(gen_fn::Switch{T1, T2, Tr},
args::Tuple,
choices::ChoiceMap) where {T1, T2, Tr}
branch_p = args[1]
state = SwitchAssessState{Union{T1, T2}}(0.0)
process_new!(gen_fn, branch_p, args[2 : end], choices, state)
(state.weight, state.retval)
end
47 changes: 47 additions & 0 deletions src/modeling_library/switch/generate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
mutable struct SwitchGenerateState{T1, T2, Tr}
score::Float64
noise::Float64
weight::Float64
cond::Bool
subtrace::Tr
retval::Union{T1, T2}
SwitchGenerateState{T1, T2, Tr}(score::Float64, noise::Float64, weight::Float64) where {T1, T2, Tr} = new{T1, T2, Tr}(score, noise, weight)
end

function process!(gen_fn::Switch{T1, T2, Tr},
branch_p::Float64,
args::Tuple,
choices::ChoiceMap,
state::SwitchGenerateState{T1, T2, Tr}) where {T1, T2, Tr}

# create flip distribution
flip_d = bernoulli(branch_p)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@femtomc This samples from a Bernoulli distribution with probability branch_p and returns a Bool.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies - that comment needs to be removed. I figured that bit out and corrected it elsewhere, but the comments indicate otherwise.


# check for constraints at :cond
constrained = has_value(choices, :cond)
!constrained && check_no_submap(choices, :cond)

# get/constrain flip value
constrained ? (flip = get_value(choices, :cond); state.weight += logpdf(Bernoulli(), flip, branch_p)) : flip = rand(flip_d)
state.cond = flip

# generate subtrace
constraints = get_submap(choices, :branch)
(subtrace, weight) = generate(flip ? gen_fn.a : gen_fn.b, args, constraints)
state.subtrace = subtrace
state.weight += weight

# return from branch
state.retval = get_retval(subtrace)
end

function generate(gen_fn::Switch{T1, T2, Tr},
args::Tuple,
choices::ChoiceMap) where {T1, T2, Tr}

branch_p = args[1]
state = SwitchGenerateState{T1, T2, Tr}(0.0, 0.0, 0.0)
process!(gen_fn, branch_p, args[2 : end], choices, state)
trace = SwitchTrace{T1, T2, Tr}(gen_fn, branch_p, state.cond, state.subtrace, state.retval, args[2 : end], state.score, state.noise)
(trace, state.weight)
end
30 changes: 30 additions & 0 deletions src/modeling_library/switch/propose.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
mutable struct SwitchProposeState{T}
choices::DynamicChoiceMap
weight::Float64
retval::T
SwitchProposeState{T}(choices, weight) where T = new{T}(choices, weight)
end

function process_new!(gen_fn::Switch{T1, T2, Tr},
branch_p::Float64,
args::Tuple,
state::SwitchProposeState{Union{T1, T2}}) where {T1, T2, Tr}

flip = bernoulli(branch_p)
(submap, weight, retval) = propose(flip ? gen_fn.a : gen_fn.b, args)
set_value!(state.choices, :cond, flip)
state.weight += logpdf(Bernoulli(), flip, branch_p)
set_submap!(state.choices, :branch, submap)
state.weight += weight
state.retval = retval
end

function propose(gen_fn::Switch{T1, T2, Tr},
args::Tuple) where {T1, T2, Tr}

branch_p = args[1]
choices = choicemap()
state = SwitchProposeState{Union{T1, T2}}(choices, 0.0)
process_new!(gen_fn, branch_p, args[2:end], state)
(state.choices, state.weight, state.retval)
end
34 changes: 34 additions & 0 deletions src/modeling_library/switch/simulate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
mutable struct SwitchSimulateState{T1, T2, Tr}
score::Float64
noise::Float64
cond::Bool
subtrace::Tr
retval::Union{T1, T2}
SwitchSimulateState{T1, T2, Tr}(score::Float64, noise::Float64) where {T1, T2, Tr} = new{T1, T2, Tr}(score, noise)
end

function process!(gen_fn::Switch{T1, T2, Tr},
branch_p::Float64,
args::Tuple,
state::SwitchSimulateState{T1, T2, Tr}) where {T1, T2, Tr}
local subtrace::Tr
local retval::Union{T1, T2}
flip = bernoulli(branch_p)
state.score += logpdf(Bernoulli(), flip, branch_p)
state.cond = flip
subtrace = simulate(flip ? gen_fn.a : gen_fn.b, args)
state.noise += project(subtrace, EmptySelection())
state.subtrace = subtrace
state.score += get_score(subtrace)
state.retval = get_retval(subtrace)
end

function simulate(gen_fn::Switch{T1, T2, Tr},
args::Tuple) where {T1, T2, Tr}

branch_p = args[1]
state = SwitchSimulateState{T1, T2, Tr}(0.0, 0.0)
process!(gen_fn, branch_p, args[2 : end], state)
trace = SwitchTrace{T1, T2, Tr}(gen_fn, branch_p, state.cond, state.subtrace, state.retval, args[2 : end], state.score, state.noise)
trace
end
19 changes: 19 additions & 0 deletions src/modeling_library/switch/switch.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
struct Switch{T1, T2, Tr} <: GenerativeFunction{Union{T1, T2}, Tr}
a::GenerativeFunction{T1, Tr}
b::GenerativeFunction{T2, Tr}
end

export Switch

has_argument_grads(switch_fn::Switch) = has_argument_grads(switch_fn.a) && has_argument_grads(switch_fn.b)
accepts_output_grad(switch_fn::Switch) = accepts_output_grad(switch_fn.a) && accepts_output_grad(switch_fn.b)

function (gen_fn::Switch)(flip_p::Float64, args...)
(_, _, retval) = propose(gen_fn, (flip_p, args...))
retval
end

include("assess.jl")
include("propose.jl")
include("simulate.jl")
include("generate.jl")
38 changes: 38 additions & 0 deletions src/modeling_library/switch/trace.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
struct SwitchTrace{T1, T2, Tr} <: Trace
kernel::GenerativeFunction{Union{T1, T2}, Tr}
p::Float64
cond::Bool
branch::Tr
retval::Union{T1, T2}
args::Tuple
score::Float64
noise::Float64
end

@inline function get_choices(tr::SwitchTrace)
choices = choicemap()
set_submap!(choices, :branch, get_choices(tr.branch))
set_value!(choices, :cond, tr.cond)
choices
end
@inline get_retval(tr::SwitchTrace) = tr.retval
@inline get_args(tr::SwitchTrace) = tr.args
@inline get_score(tr::SwitchTrace) = tr.score
@inline get_gen_fn(tr::SwitchTrace) = tr.kernel

@inline function Base.getindex(tr::SwitchTrace, addr::Pair)
(first, rest) = addr
subtr = getfield(trace, first)
subtrace[rest]
end
@inline Base.getindex(tr::SwitchTrace, addr::Symbol) = getfield(trace, addr)

function project(tr::SwitchTrace, selection::Selection)
weight = 0.
for k in [:cond, :branch]
subselection = selection[k]
weight += project(getindex(tr, k), subselection)
end
weight
end
project(tr::SwitchTrace, ::EmptySelection) = tr.noise
44 changes: 44 additions & 0 deletions test/modeling_library/switch.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
@testset "switch combinator" begin

@gen (grad) function foo((grad)(x::Float64), (grad)(y::Float64))
@param std::Float64
z = @trace(normal(x + y, std), :z)
return z
end

@gen (grad) function baz((grad)(x::Float64), (grad)(y::Float64))
@param std::Float64
z = @trace(normal(x + 2 * y, std), :z)
return z
end

set_param!(foo, :std, 1.)
set_param!(baz, :std, 1.)

bar = Switch(foo, baz)
args = (1.0, 3.0)

@testset "simulate" begin
end

@testset "generate" begin
end

@testset "propose" begin
end

@testset "assess" begin
end

@testset "update" begin
end

@testset "regenerate" begin
end

@testset "choice_gradients" begin
end

@testset "accumulate_param_gradients!" begin
end
end