-
Notifications
You must be signed in to change notification settings - Fork 162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
(Ready for review): Switch combinator #334
Merged
marcoct
merged 31 commits into
probcomp:master
from
femtomc:20201116_mrb_switch_combinator
Dec 8, 2020
Merged
Changes from 5 commits
Commits
Show all changes
31 commits
Select commit
Hold shift + click to select a range
3e4f695
Initial work on a Switch combinator.
femtomc bd4f830
Initial implementation of propose and generate.
femtomc 374a7b0
Added implementaton of simulate.
femtomc 5872593
Corrected some bugs with Bernoulli vs bernoulli.
femtomc 9c0a9f2
Added assess implementation.
femtomc 95baf07
Split into two combinators: Switch and WithProbability implementations.
femtomc 29b7797
Working on Switch update and regenerate.
femtomc 3e6e307
Added Switch update and regenerate.
femtomc 7929b86
Added Switch update and regenerate - working out kinks in update.
femtomc 73618a1
update and regenerate appear to be computing the correct ratios. To c…
femtomc 252413f
Fixed generate index type bug.
femtomc ac3528e
Branch dispatch done using diff types.
femtomc eaf3327
Branch dispatch done using diff types.
femtomc 6d58aac
Branch dispatch done using diff types.
femtomc e413e9c
Added custom methods in update for Switch which allow the merging of …
femtomc 435493f
Added custom methods in update for Switch which allow the merging of …
femtomc 32fec4f
Idiomatic check for EmptyChoiceMap.
femtomc bb767e7
Working on backprop - seems simple? Could it really be?
femtomc a35e2e7
Extracting WithProb combinator into another PR.
femtomc 562667e
Testing backprop.
femtomc b74a071
Fixed backprop - was thinking in Zygote lang. Gradients appear to be …
femtomc 915811d
Merge branch 'master' of https://github.com/probcomp/Gen.jl into 2020…
femtomc 849d61e
Added docstring and docs example.
femtomc adf73a5
Fixed numerous bugs uncovered while constructing test suite. One seri…
femtomc dfe0125
Fixed numerous bugs uncovered while constructing test suite. One seri…
femtomc 3717d65
Tests for everything but gradients - working on gradients now.
femtomc cb62fb5
Last tests I need to write: accumulate_param_gradients!
femtomc 97473d0
Added accumulate_param_gradients! tests.
femtomc 176b9e9
Reverted particle filter fix - will be handled in another issue.
femtomc 0465965
Renamed mix field of Switch generative function to branches to more a…
femtomc 43c7274
Addressed review comments. Added docstrings where necessary. Correcte…
femtomc File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
module SwitchComb | ||
|
||
include("../src/Gen.jl") | ||
using .Gen | ||
|
||
@gen (grad) function foo((grad)(x::Float64), (grad)(y::Float64)) | ||
std::Float64 = 3.0 | ||
z = @trace(normal(x + y, std), :z) | ||
return z | ||
end | ||
|
||
@gen (grad) function baz((grad)(x::Float64), (grad)(y::Float64)) | ||
std::Float64 = 3.0 | ||
z = @trace(normal(x + 2 * y, std), :z) | ||
return z | ||
end | ||
|
||
sc = Switch(foo, baz) | ||
chm, _, _ = propose(sc, (0.3, 5.0, 3.0)) | ||
display(chm) | ||
|
||
tr = simulate(sc, (0.3, 5.0, 3.0)) | ||
display(get_choices(tr)) | ||
|
||
chm = choicemap() | ||
chm[:cond] = true | ||
tr, _ = generate(sc, (0.3, 5.0, 3.0), chm) | ||
display(get_choices(tr)) | ||
|
||
end # module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
mutable struct SwitchAssessState{T} | ||
weight::Float64 | ||
retval::T | ||
end | ||
|
||
function process_new!(gen_fn::Switch{T1, T2, Tr}, | ||
branch_p::Float64, | ||
args::Tuple, | ||
choices::ChoiceMap, | ||
state::SwitchAssessState{Union{T1, T2}}) where {T1, T2, Tr} | ||
flip = get_value(choices, :cond) | ||
state.weight += logpdf(Bernoulli(), flip, branch_p) | ||
submap = get_submap(choices, :branch) | ||
(weight, retval) = assess(gen_fn.kernel, kernel_args, submap) | ||
state.weight += weight | ||
state.retval = retval | ||
end | ||
|
||
function assess(gen_fn::Switch{T1, T2, Tr}, | ||
args::Tuple, | ||
choices::ChoiceMap) where {T1, T2, Tr} | ||
branch_p = args[1] | ||
state = SwitchAssessState{Union{T1, T2}}(0.0) | ||
process_new!(gen_fn, branch_p, args[2 : end], choices, state) | ||
(state.weight, state.retval) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
mutable struct SwitchGenerateState{T1, T2, Tr} | ||
score::Float64 | ||
noise::Float64 | ||
weight::Float64 | ||
cond::Bool | ||
subtrace::Tr | ||
retval::Union{T1, T2} | ||
SwitchGenerateState{T1, T2, Tr}(score::Float64, noise::Float64, weight::Float64) where {T1, T2, Tr} = new{T1, T2, Tr}(score, noise, weight) | ||
end | ||
|
||
function process!(gen_fn::Switch{T1, T2, Tr}, | ||
branch_p::Float64, | ||
args::Tuple, | ||
choices::ChoiceMap, | ||
state::SwitchGenerateState{T1, T2, Tr}) where {T1, T2, Tr} | ||
|
||
# create flip distribution | ||
flip_d = bernoulli(branch_p) | ||
|
||
# check for constraints at :cond | ||
constrained = has_value(choices, :cond) | ||
!constrained && check_no_submap(choices, :cond) | ||
|
||
# get/constrain flip value | ||
constrained ? (flip = get_value(choices, :cond); state.weight += logpdf(Bernoulli(), flip, branch_p)) : flip = rand(flip_d) | ||
state.cond = flip | ||
|
||
# generate subtrace | ||
constraints = get_submap(choices, :branch) | ||
(subtrace, weight) = generate(flip ? gen_fn.a : gen_fn.b, args, constraints) | ||
state.subtrace = subtrace | ||
state.weight += weight | ||
|
||
# return from branch | ||
state.retval = get_retval(subtrace) | ||
end | ||
|
||
function generate(gen_fn::Switch{T1, T2, Tr}, | ||
args::Tuple, | ||
choices::ChoiceMap) where {T1, T2, Tr} | ||
|
||
branch_p = args[1] | ||
state = SwitchGenerateState{T1, T2, Tr}(0.0, 0.0, 0.0) | ||
process!(gen_fn, branch_p, args[2 : end], choices, state) | ||
trace = SwitchTrace{T1, T2, Tr}(gen_fn, branch_p, state.cond, state.subtrace, state.retval, args[2 : end], state.score, state.noise) | ||
(trace, state.weight) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
mutable struct SwitchProposeState{T} | ||
choices::DynamicChoiceMap | ||
weight::Float64 | ||
retval::T | ||
SwitchProposeState{T}(choices, weight) where T = new{T}(choices, weight) | ||
end | ||
|
||
function process_new!(gen_fn::Switch{T1, T2, Tr}, | ||
branch_p::Float64, | ||
args::Tuple, | ||
state::SwitchProposeState{Union{T1, T2}}) where {T1, T2, Tr} | ||
|
||
flip = bernoulli(branch_p) | ||
(submap, weight, retval) = propose(flip ? gen_fn.a : gen_fn.b, args) | ||
set_value!(state.choices, :cond, flip) | ||
state.weight += logpdf(Bernoulli(), flip, branch_p) | ||
set_submap!(state.choices, :branch, submap) | ||
state.weight += weight | ||
state.retval = retval | ||
end | ||
|
||
function propose(gen_fn::Switch{T1, T2, Tr}, | ||
args::Tuple) where {T1, T2, Tr} | ||
|
||
branch_p = args[1] | ||
choices = choicemap() | ||
state = SwitchProposeState{Union{T1, T2}}(choices, 0.0) | ||
process_new!(gen_fn, branch_p, args[2:end], state) | ||
(state.choices, state.weight, state.retval) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
mutable struct SwitchSimulateState{T1, T2, Tr} | ||
score::Float64 | ||
noise::Float64 | ||
cond::Bool | ||
subtrace::Tr | ||
retval::Union{T1, T2} | ||
SwitchSimulateState{T1, T2, Tr}(score::Float64, noise::Float64) where {T1, T2, Tr} = new{T1, T2, Tr}(score, noise) | ||
end | ||
|
||
function process!(gen_fn::Switch{T1, T2, Tr}, | ||
branch_p::Float64, | ||
args::Tuple, | ||
state::SwitchSimulateState{T1, T2, Tr}) where {T1, T2, Tr} | ||
local subtrace::Tr | ||
local retval::Union{T1, T2} | ||
flip = bernoulli(branch_p) | ||
state.score += logpdf(Bernoulli(), flip, branch_p) | ||
state.cond = flip | ||
subtrace = simulate(flip ? gen_fn.a : gen_fn.b, args) | ||
state.noise += project(subtrace, EmptySelection()) | ||
state.subtrace = subtrace | ||
state.score += get_score(subtrace) | ||
state.retval = get_retval(subtrace) | ||
end | ||
|
||
function simulate(gen_fn::Switch{T1, T2, Tr}, | ||
args::Tuple) where {T1, T2, Tr} | ||
|
||
branch_p = args[1] | ||
state = SwitchSimulateState{T1, T2, Tr}(0.0, 0.0) | ||
process!(gen_fn, branch_p, args[2 : end], state) | ||
trace = SwitchTrace{T1, T2, Tr}(gen_fn, branch_p, state.cond, state.subtrace, state.retval, args[2 : end], state.score, state.noise) | ||
trace | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
struct Switch{T1, T2, Tr} <: GenerativeFunction{Union{T1, T2}, Tr} | ||
a::GenerativeFunction{T1, Tr} | ||
b::GenerativeFunction{T2, Tr} | ||
end | ||
|
||
export Switch | ||
|
||
has_argument_grads(switch_fn::Switch) = has_argument_grads(switch_fn.a) && has_argument_grads(switch_fn.b) | ||
accepts_output_grad(switch_fn::Switch) = accepts_output_grad(switch_fn.a) && accepts_output_grad(switch_fn.b) | ||
|
||
function (gen_fn::Switch)(flip_p::Float64, args...) | ||
(_, _, retval) = propose(gen_fn, (flip_p, args...)) | ||
retval | ||
end | ||
|
||
include("assess.jl") | ||
include("propose.jl") | ||
include("simulate.jl") | ||
include("generate.jl") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
struct SwitchTrace{T1, T2, Tr} <: Trace | ||
kernel::GenerativeFunction{Union{T1, T2}, Tr} | ||
p::Float64 | ||
cond::Bool | ||
branch::Tr | ||
retval::Union{T1, T2} | ||
args::Tuple | ||
score::Float64 | ||
noise::Float64 | ||
end | ||
|
||
@inline function get_choices(tr::SwitchTrace) | ||
choices = choicemap() | ||
set_submap!(choices, :branch, get_choices(tr.branch)) | ||
set_value!(choices, :cond, tr.cond) | ||
choices | ||
end | ||
@inline get_retval(tr::SwitchTrace) = tr.retval | ||
@inline get_args(tr::SwitchTrace) = tr.args | ||
@inline get_score(tr::SwitchTrace) = tr.score | ||
@inline get_gen_fn(tr::SwitchTrace) = tr.kernel | ||
|
||
@inline function Base.getindex(tr::SwitchTrace, addr::Pair) | ||
(first, rest) = addr | ||
subtr = getfield(trace, first) | ||
subtrace[rest] | ||
end | ||
@inline Base.getindex(tr::SwitchTrace, addr::Symbol) = getfield(trace, addr) | ||
|
||
function project(tr::SwitchTrace, selection::Selection) | ||
weight = 0. | ||
for k in [:cond, :branch] | ||
subselection = selection[k] | ||
weight += project(getindex(tr, k), subselection) | ||
end | ||
weight | ||
end | ||
project(tr::SwitchTrace, ::EmptySelection) = tr.noise |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
@testset "switch combinator" begin | ||
|
||
@gen (grad) function foo((grad)(x::Float64), (grad)(y::Float64)) | ||
@param std::Float64 | ||
z = @trace(normal(x + y, std), :z) | ||
return z | ||
end | ||
|
||
@gen (grad) function baz((grad)(x::Float64), (grad)(y::Float64)) | ||
@param std::Float64 | ||
z = @trace(normal(x + 2 * y, std), :z) | ||
return z | ||
end | ||
|
||
set_param!(foo, :std, 1.) | ||
set_param!(baz, :std, 1.) | ||
|
||
bar = Switch(foo, baz) | ||
args = (1.0, 3.0) | ||
|
||
@testset "simulate" begin | ||
end | ||
|
||
@testset "generate" begin | ||
end | ||
|
||
@testset "propose" begin | ||
end | ||
|
||
@testset "assess" begin | ||
end | ||
|
||
@testset "update" begin | ||
end | ||
|
||
@testset "regenerate" begin | ||
end | ||
|
||
@testset "choice_gradients" begin | ||
end | ||
|
||
@testset "accumulate_param_gradients!" begin | ||
end | ||
end |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@femtomc This samples from a Bernoulli distribution with probability branch_p and returns a Bool.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies - that comment needs to be removed. I figured that bit out and corrected it elsewhere, but the comments indicate otherwise.