From 748823833478d7a7e87a52d8626c7b2ead352b12 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Nov 2024 14:58:14 +0100 Subject: [PATCH 1/7] fixed incorrect implementation of `dot_tilde_assume` for `PrefixContext` --- src/context_implementations.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f3c5171b0..d7c24fafb 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -402,12 +402,12 @@ end # `PrefixContext` function dot_tilde_assume(context::PrefixContext, right, left, vn, vi) - return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), vi) + return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi) end function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, vi) return dot_tilde_assume( - rng, context.context, sampler, right, prefix.(Ref(context), vn), vi + rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi ) end From 1d211c5c8ef5bda879cb9a09b2838b1faa2545c8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Nov 2024 14:58:33 +0100 Subject: [PATCH 2/7] removed `vars` field from `PriorContext` and `LikelihoodContext` as it's no longer used functionality (was dropped when we dropped the logprob-macro) --- src/context_implementations.jl | 101 --------------------------------- src/contexts.jl | 29 +++------- 2 files changed, 8 insertions(+), 122 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index d7c24fafb..50919e77e 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -77,44 +77,6 @@ function tilde_assume( return tilde_assume(rng, childcontext(context), args...) end -function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi) - if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) - settrans!!(vi, false, vn) - end - return tilde_assume(PriorContext(), right, vn, vi) -end -function tilde_assume( - rng::Random.AbstractRNG, context::PriorContext{<:NamedTuple}, sampler, right, vn, vi -) - if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) - settrans!!(vi, false, vn) - end - return tilde_assume(rng, PriorContext(), sampler, right, vn, vi) -end - -function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi) - if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) - settrans!!(vi, false, vn) - end - return tilde_assume(LikelihoodContext(), right, vn, vi) -end -function tilde_assume( - rng::Random.AbstractRNG, - context::LikelihoodContext{<:NamedTuple}, - sampler, - right, - vn, - vi, -) - if haskey(context.vars, getsym(vn)) - vi = setindex!!(vi, tovec(get(context.vars, vn)), vn) - settrans!!(vi, false, vn) - end - return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi) -end function tilde_assume(::LikelihoodContext, right, vn, vi) return assume(NoDist(right), vn, vi) end @@ -328,37 +290,6 @@ function dot_tilde_assume( end # `LikelihoodContext` -function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, vn, vi) - return if haskey(context.vars, getsym(vn)) - var = get(context.vars, vn) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!!.((vi,), false, _vns) - dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi) - else - dot_tilde_assume(LikelihoodContext(), right, left, vn, vi) - end -end -function dot_tilde_assume( - rng::Random.AbstractRNG, - context::LikelihoodContext{<:NamedTuple}, - sampler, - right, - left, - vn, - vi, -) - return if haskey(context.vars, getsym(vn)) - var = get(context.vars, vn) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!!.((vi,), false, _vns) - dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi) - else - dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi) - end -end - function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) return dot_assume(nodist(right), left, vn, vi) end @@ -368,38 +299,6 @@ function dot_tilde_assume( return dot_assume(rng, sampler, nodist(right), vn, left, vi) end -# `PriorContext` -function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, vi) - return if haskey(context.vars, getsym(vn)) - var = get(context.vars, vn) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!!.((vi,), false, _vns) - dot_tilde_assume(PriorContext(), _right, _left, _vns, vi) - else - dot_tilde_assume(PriorContext(), right, left, vn, vi) - end -end -function dot_tilde_assume( - rng::Random.AbstractRNG, - context::PriorContext{<:NamedTuple}, - sampler, - right, - left, - vn, - vi, -) - return if haskey(context.vars, getsym(vn)) - var = get(context.vars, vn) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!!.((vi,), false, _vns) - dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi) - else - dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi) - end -end - # `PrefixContext` function dot_tilde_assume(context::PrefixContext, right, left, vn, vi) return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi) diff --git a/src/contexts.jl b/src/contexts.jl index 53b454df6..5da4208b5 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -53,7 +53,7 @@ DefaultContext() julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior julia> DynamicPPL.childcontext(ctx_prior) -PriorContext{Nothing}(nothing) +PriorContext() ``` """ setchildcontext @@ -97,7 +97,7 @@ ParentContext(ParentContext(DefaultContext())) julia> # Replace the leaf context with another leaf. leafcontext(setleafcontext(ctx, PriorContext())) -PriorContext{Nothing}(nothing) +PriorContext() julia> # Append another parent context. setleafcontext(ctx, ParentContext(DefaultContext())) @@ -195,32 +195,19 @@ struct DefaultContext <: AbstractContext end NodeTrait(context::DefaultContext) = IsLeaf() """ - struct PriorContext{Tvars} <: AbstractContext - vars::Tvars - end + PriorContext <: AbstractContext -The `PriorContext` enables the computation of the log prior of the parameters `vars` when -running the model. +A leaf context resulting in the exclusion of likelihood terms when running the model. """ -struct PriorContext{Tvars} <: AbstractContext - vars::Tvars -end -PriorContext() = PriorContext(nothing) +struct PriorContext <: AbstractContext end NodeTrait(context::PriorContext) = IsLeaf() """ - struct LikelihoodContext{Tvars} <: AbstractContext - vars::Tvars - end + LikelihoodContext <: AbstractContext -The `LikelihoodContext` enables the computation of the log likelihood of the parameters when -running the model. `vars` can be used to evaluate the log likelihood for specific values -of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default. +A leaf context resulting in the exclusion of prior terms when running the model. """ -struct LikelihoodContext{Tvars} <: AbstractContext - vars::Tvars -end -LikelihoodContext() = LikelihoodContext(nothing) +struct LikelihoodContext <: AbstractContext end NodeTrait(context::LikelihoodContext) = IsLeaf() """ From 8e7d164644b000a85b86df23a9a6a755e17f6ecd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Nov 2024 16:02:38 +0100 Subject: [PATCH 3/7] added `dot_tilde_assume` overloads for `FixedContext` to handle the cases where current `fix` is failiing --- src/context_implementations.jl | 53 +++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 50919e77e..0d6e4bdf4 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -304,12 +304,63 @@ function dot_tilde_assume(context::PrefixContext, right, left, vn, vi) return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi) end -function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, vi) +function dot_tilde_assume(rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi) return dot_tilde_assume( rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi ) end +# `FixedContext` +function dot_tilde_assume(context::FixedContext, right, left, vns, vi) + # If we're reached here, then we didn't hit the initial `getfixed` call in the model body. + # So we need to check each of the vns. + logp = 0 + # TODO(torfjelde): Add a check to see if the `Symbol` of `vns` exists in `FixedContext`. + # If the `Symbol` is not present, we can just skip this check completely. Such a check can + # then be compiled away in cases where the `Symbol` is not present. + left_bc = Broadcast.broadcastable(left) + right_bc = Broadcast.broadcastable(right) + for I_left in Iterators.product(Broadcast.broadcast_axes(left_bc)...) + for I_right in Iterators.product(Broadcast.broadcast_axes(right_bc)...) + vn = vns[I_left...] + if hasfixed(context, vn) + left[I_left...] = getfixed(context, vn) + else + # Defer to `tilde_assume`. + left[I_left...], logp_inner, vi = tilde_assume(context, right_bc[I_right...], vn, vi) + logp += logp_inner + end + end + end + + return left, logp, vi +end + +function dot_tilde_assume(rng::Random.AbstractRNG, context::FixedContext, sampler, right, left, vns, vi) + # If we're reached here, then we didn't hit the initial `getfixed` call in the model body. + # So we need to check each of the vns. + logp = 0 + # TODO(torfjelde): Add a check to see if the `Symbol` of `vns` exists in `FixedContext`. + # If the `Symbol` is not present, we can just skip this check completely. Such a check can + # then be compiled away in cases where the `Symbol` is not present. + left_bc = Broadcast.broadcastable(left) + right_bc = Broadcast.broadcastable(right) + for I_left in Iterators.product(Broadcast.broadcast_axes(left_bc)...) + for I_right in Iterators.product(Broadcast.broadcast_axes(right_bc)...) + vn = vns[I_left...] + if hasfixed(context, vn) + left[I_left...] = getfixed(context, vn) + else + # Defer to `tilde_assume`. + left[I_left...], logp_inner, vi = tilde_assume(rng, context, sampler, right_bc[I_right...], vn, vi) + logp += logp_inner + end + end + end + + return left, logp, vi +end + """ dot_tilde_assume!!(context, right, left, vn, vi) From 86fe1c6e94e72105c13711249ef1592d4a0fba50 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Nov 2024 16:03:28 +0100 Subject: [PATCH 4/7] added error-handling of invalid broadcasting statements --- src/compiler.jl | 11 +++++++++++ test/compiler.jl | 9 +++++++++ 2 files changed, 20 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 90220cbf5..0bff04774 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -235,6 +235,17 @@ function unwrap_right_left_vns( left::AbstractArray, vn::VarName, ) + # Need to check that we don't end up double-counting log-probabilities. + combined_axes = Broadcast.combine_axes(left, right) + if prod(length, combined_axes) > length(left) + throw( + ArgumentError( + "a `.~` statement cannot result in a broadcasted expression with more elements than the left-hand side", + ), + ) + end + + # Extract the sub-varnames. vns = map(CartesianIndices(left)) do i return Accessors.IndexLens(Tuple(i)) ∘ vn end diff --git a/test/compiler.jl b/test/compiler.jl index f2d7e5852..d4f38f5a6 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -729,4 +729,13 @@ module Issue537 end res = model() @test res == (a=1, b=1, c=2, d=2, t=DynamicPPL.TypeWrap{Int}()) end + + @testset "invalid .~ expressions" begin + @model function demo_with_invalid_dot_tilde() + m = Matrix{Float64}(undef, 1, 2) + m .~ [Normal(); Normal()] + end + + @test_throws ArgumentError demo_with_invalid_dot_tilde()() + end end From 0b7ba4ba2708cdfaf0e0809ff5bbd6d1051f8ae4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Nov 2024 16:05:54 +0100 Subject: [PATCH 5/7] added static checking to avoid the slow fixed branches unless we really need to --- src/context_implementations.jl | 24 ++++++++++++++++++++---- src/contexts.jl | 7 +++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 0d6e4bdf4..bf7291959 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -312,8 +312,13 @@ end # `FixedContext` function dot_tilde_assume(context::FixedContext, right, left, vns, vi) + if !has_fixed_symbol(context, first(vns)) + # Defer to `childcontext`. + return tilde_assume(childcontext(context), right, left, vns, vi) + end + # If we're reached here, then we didn't hit the initial `getfixed` call in the model body. - # So we need to check each of the vns. + # We _might_ also have some of the variables fixed, but not all. logp = 0 # TODO(torfjelde): Add a check to see if the `Symbol` of `vns` exists in `FixedContext`. # If the `Symbol` is not present, we can just skip this check completely. Such a check can @@ -327,7 +332,9 @@ function dot_tilde_assume(context::FixedContext, right, left, vns, vi) left[I_left...] = getfixed(context, vn) else # Defer to `tilde_assume`. - left[I_left...], logp_inner, vi = tilde_assume(context, right_bc[I_right...], vn, vi) + left[I_left...], logp_inner, vi = tilde_assume( + childcontext(context), right_bc[I_right...], vn, vi + ) logp += logp_inner end end @@ -336,7 +343,14 @@ function dot_tilde_assume(context::FixedContext, right, left, vns, vi) return left, logp, vi end -function dot_tilde_assume(rng::Random.AbstractRNG, context::FixedContext, sampler, right, left, vns, vi) +function dot_tilde_assume( + rng::Random.AbstractRNG, context::FixedContext, sampler, right, left, vns, vi +) + + if !has_fixed_symbol(context, first(vns)) + # Defer to `childcontext`. + return tilde_assume(rng, childcontext(context), sampler, right, left, vns, vi) + end # If we're reached here, then we didn't hit the initial `getfixed` call in the model body. # So we need to check each of the vns. logp = 0 @@ -352,7 +366,9 @@ function dot_tilde_assume(rng::Random.AbstractRNG, context::FixedContext, sample left[I_left...] = getfixed(context, vn) else # Defer to `tilde_assume`. - left[I_left...], logp_inner, vi = tilde_assume(rng, context, sampler, right_bc[I_right...], vn, vi) + left[I_left...], logp_inner, vi = tilde_assume( + rng, childcontext(context), sampler, right_bc[I_right...], vn, vi + ) logp += logp_inner end end diff --git a/src/contexts.jl b/src/contexts.jl index 5da4208b5..11ab2b5a8 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -501,6 +501,13 @@ NodeTrait(::FixedContext) = IsParent() childcontext(context::FixedContext) = context.context setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child) +has_fixed_symbol(context::FixedContext, vn::VarName) = has_symbol(context.values, vn) + +has_symbol(d::AbstractDict, vn::VarName) = haskey(d, vn) +@generated function has_symbol(::NamedTuple{names}, ::VarName{sym}) where {names, sym} + return sym in names +end + """ hasfixed(context::AbstractContext, vn::VarName) From e66e6ae75d03ec4bc05b13f9adafa8801b585a24 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Nov 2024 16:56:51 +0100 Subject: [PATCH 6/7] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/context_implementations.jl | 5 +++-- src/contexts.jl | 2 +- test/compiler.jl | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index bf7291959..c19858816 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -304,7 +304,9 @@ function dot_tilde_assume(context::PrefixContext, right, left, vn, vi) return dot_tilde_assume(context.context, right, left, prefix.(Ref(context), vn), vi) end -function dot_tilde_assume(rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi) +function dot_tilde_assume( + rng::Random.AbstractRNG, context::PrefixContext, sampler, right, left, vn, vi +) return dot_tilde_assume( rng, context.context, sampler, right, left, prefix.(Ref(context), vn), vi ) @@ -346,7 +348,6 @@ end function dot_tilde_assume( rng::Random.AbstractRNG, context::FixedContext, sampler, right, left, vns, vi ) - if !has_fixed_symbol(context, first(vns)) # Defer to `childcontext`. return tilde_assume(rng, childcontext(context), sampler, right, left, vns, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 11ab2b5a8..bd04d8783 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -504,7 +504,7 @@ setchildcontext(parent::FixedContext, child) = FixedContext(parent.values, child has_fixed_symbol(context::FixedContext, vn::VarName) = has_symbol(context.values, vn) has_symbol(d::AbstractDict, vn::VarName) = haskey(d, vn) -@generated function has_symbol(::NamedTuple{names}, ::VarName{sym}) where {names, sym} +@generated function has_symbol(::NamedTuple{names}, ::VarName{sym}) where {names,sym} return sym in names end diff --git a/test/compiler.jl b/test/compiler.jl index d4f38f5a6..8a11021fe 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -733,7 +733,7 @@ module Issue537 end @testset "invalid .~ expressions" begin @model function demo_with_invalid_dot_tilde() m = Matrix{Float64}(undef, 1, 2) - m .~ [Normal(); Normal()] + return m .~ [Normal(); Normal()] end @test_throws ArgumentError demo_with_invalid_dot_tilde()() From da6f9a015b4cabb6b5fec0696a66fdfa3e4cf71d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 4 Nov 2024 11:11:14 +0100 Subject: [PATCH 7/7] Apply suggestions from code review Co-authored-by: Penelope Yong --- src/context_implementations.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index c19858816..fcb8add88 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -316,7 +316,7 @@ end function dot_tilde_assume(context::FixedContext, right, left, vns, vi) if !has_fixed_symbol(context, first(vns)) # Defer to `childcontext`. - return tilde_assume(childcontext(context), right, left, vns, vi) + return dot_tilde_assume(childcontext(context), right, left, vns, vi) end # If we're reached here, then we didn't hit the initial `getfixed` call in the model body. @@ -350,7 +350,7 @@ function dot_tilde_assume( ) if !has_fixed_symbol(context, first(vns)) # Defer to `childcontext`. - return tilde_assume(rng, childcontext(context), sampler, right, left, vns, vi) + return dot_tilde_assume(rng, childcontext(context), sampler, right, left, vns, vi) end # If we're reached here, then we didn't hit the initial `getfixed` call in the model body. # So we need to check each of the vns.