From 1388502dd372476d808f1689a4773a0d5b854175 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 14 Jul 2021 01:12:12 +0100 Subject: [PATCH 01/32] show full timings for evaluation rather than just min --- benchmarks/benchmark_body.jmd | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_body.jmd b/benchmarks/benchmark_body.jmd index f9c994dc9..9d9810dc2 100644 --- a/benchmarks/benchmark_body.jmd +++ b/benchmarks/benchmark_body.jmd @@ -8,8 +8,15 @@ m = time_model_def(model_def, data); ```julia suite = make_suite(m); -results = run(suite) -results +results = run(suite); +``` + +```julia +results["evaluation_untyped"] +``` + +```julia +results["evaluation_typed"] ``` ```julia; echo=false; results="hidden"; From 678ef1dfb3422cf3c3b4a49b699f7dc656ccf458 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 27 Jul 2021 21:10:14 +0100 Subject: [PATCH 02/32] initial work on allowing more than just real and array variables --- Project.toml | 1 + src/DynamicPPL.jl | 2 + src/compiler.jl | 152 ++++++++++++++++++++++++--------- src/context_implementations.jl | 136 ++++++++++++++--------------- 4 files changed, 182 insertions(+), 109 deletions(-) diff --git a/Project.toml b/Project.toml index 5678c050a..adb1b7043 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..344a92594 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -10,6 +10,8 @@ using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules +using Setfield: Setfield + using Random: Random import Base: diff --git a/src/compiler.jl b/src/compiler.jl index 91fe78e2b..f2b84cfbd 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -38,7 +38,7 @@ isassumption(expr) = :(false) # If we're working with, say, a `Symbol`, then we're not going to `view`. maybe_view(x) = x -maybe_view(x::Expr) = :($(DynamicPPL.maybe_unwrap_view)(@view($x))) +maybe_view(x::Expr) = :($(DynamicPPL.maybe_unwrap_view)(@views($x))) # If the result of a `view` is a zero-dim array then it's just a # single element. Likely the rest is expecting type `eltype(x)`, hence @@ -267,6 +267,10 @@ function generate_mainbody!(mod, found, expr::Expr, warn) # Do not touch interpolated expressions expr.head === :$ && return expr.args[1] + # Do we don't want escaped expressions because we unfortunately + # escape the entire body afterwards. + Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn) + # If it's a macro, we expand it if Meta.isexpr(expr, :macrocall) return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn) @@ -299,6 +303,15 @@ function generate_mainbody!(mod, found, expr::Expr, warn) return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) end +function generate_tilde_literal(left, right) + # If the LHS is a literal, it is always an observation + return quote + $(DynamicPPL.tilde_observe!)( + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ + ) + end +end + """ generate_tilde(left, right) @@ -306,88 +319,145 @@ Generate an `observe` expression for data variables and `assume` expression for variables. """ function generate_tilde(left, right) - # If the LHS is a literal, it is always an observation - if isliteral(left) - return quote - $(DynamicPPL.tilde_observe!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ - ) - end - end + isliteral(left) && return generate_tilde_literal(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn inds isassumption + @gensym vn isassumption + return quote - $vn = $(varname(left)) - $inds = $(vinds(left)) - $isassumption = $(DynamicPPL.isassumption(left)) + $vn = $(remove_escape(varname(left))) + $isassumption = $(remove_escape(DynamicPPL.isassumption(left))) if $isassumption - $left = $(DynamicPPL.tilde_assume!)( - __context__, - $(DynamicPPL.unwrap_right_vn)( - $(DynamicPPL.check_tilde_rhs)($right), $vn - )..., - $inds, - __varinfo__, - ) + $(generate_tilde_assume(left, right, vn)) else $(DynamicPPL.tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn, - $inds, __varinfo__, ) end end end +function generate_tilde_assume(left::Symbol, right, vn) + return quote + $left = $(DynamicPPL.tilde_assume!)( + __context__, + $(DynamicPPL.unwrap_right_vn)( + $(DynamicPPL.check_tilde_rhs)($right), $vn + )..., + __varinfo__, + ) + end +end + +function generate_tilde_assume(left::Expr, right, vn) + expr = :( + $left = $(DynamicPPL.tilde_assume!)( + __context__, + $(DynamicPPL.unwrap_right_vn)( + $(DynamicPPL.check_tilde_rhs)($right), $vn + )..., + __varinfo__, + ) + ) + + return remove_escape(setmacro(identity, expr, overwrite=true)) +end + """ generate_dot_tilde(left, right) Generate the expression that replaces `left .~ right` in the model body. """ function generate_dot_tilde(left, right) - # If the LHS is a literal, it is always an observation - if isliteral(left) - return quote - $(DynamicPPL.dot_tilde_observe!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ - ) - end - end + isliteral(left) && return generate_tilde_literal(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn inds isassumption + @gensym vn isassumption return quote $vn = $(varname(left)) - $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left .= $(DynamicPPL.dot_tilde_assume!)( - __context__, - $(DynamicPPL.unwrap_right_left_vns)( - $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn - )..., - $inds, - __varinfo__, - ) + $(generate_dot_tilde_assume(left, right, vn)) else $(DynamicPPL.dot_tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn, - $inds, __varinfo__, ) end end end +function generate_dot_tilde_assume(left::Symbol, right, vn) + return :( + $left .= $(DynamicPPL.dot_tilde_assume!)( + __context__, + $(DynamicPPL.unwrap_right_left_vns)( + $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn + )..., + __varinfo__, + ) + ) +end + +function generate_dot_tilde_assume(left::Expr, right, vn) + expr = :( + $left .= $(DynamicPPL.dot_tilde_assume!)( + __context__, + $(DynamicPPL.unwrap_right_left_vns)( + $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn + )..., + __varinfo__, + ) + ) + + return remove_escape(setmacro(identity, expr, overwrite=true)) +end + +# HACK: This is unfortunate. It's a consequence of the fact that in +# DynamicPPL we the entire function body. Instead we should be +# more selective with our escape. Until that's the case, we remove them all. +remove_escape(x) = x +function remove_escape(expr::Expr) + Meta.isexpr(expr, :escape) && return remove_escape(expr.args[1]) + return Expr(expr.head, map(x -> remove_escape(x), expr.args)...) +end + +# TODO: Make PR to Setfield.jl to use `gensym` for the `lens` variable. +# This seems like it should be the case anyways since it allows multiple +# calls to `setmacro` without any cost to the current functionality. +function setmacro(lenstransform, ex::Expr; overwrite::Bool=false) + @assert ex.head isa Symbol + @assert length(ex.args) == 2 + ref, val = ex.args + obj, lens = Setfield.parse_obj_lens(ref) + lens_var = gensym("lens") + dst = overwrite ? obj : gensym("_") + val = esc(val) + ret = if ex.head == :(=) + quote + $lens_var = ($lenstransform)($lens) + $dst = $(Setfield.set)($obj, $lens_var, $val) + end + else + op = get_update_op(ex.head) + f = :($(Setfield._UpdateOp)($op,$val)) + quote + $lens_var = ($lenstransform)($lens) + $dst = $(Setfield.modify)($f, $obj, $lens_var) + end + end + ret +end + const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}} hasmissing(T::Type{<:AbstractArray{TA}}) where {TA<:AbstractArray} = hasmissing(TA) hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3d492f5b1..d65035644 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -19,7 +19,7 @@ _getindex(x, inds::Tuple{}) = x # assume """ - tilde_assume(context::SamplingContext, right, vn, inds, vi) + tilde_assume(context::SamplingContext, right, vn, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value with a context associated @@ -27,27 +27,27 @@ with a sampler. Falls back to ```julia -tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) +tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) ``` """ -function tilde_assume(context::SamplingContext, right, vn, inds, vi) - return tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) +function tilde_assume(context::SamplingContext, right, vn, vi) + return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) end # Leaf contexts -tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, vi) +tilde_assume(::DefaultContext, right, vn, vi) = assume(right, vn, vi) function tilde_assume( - rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi + rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi ) return assume(rng, sampler, right, vn, vi) end -function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) +function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return tilde_assume(PriorContext(), right, vn, inds, vi) + return tilde_assume(PriorContext(), right, vn, vi) end function tilde_assume( rng::Random.AbstractRNG, @@ -62,21 +62,21 @@ function tilde_assume( vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) + return tilde_assume(rng, PriorContext(), sampler, right, vn, vi) end -function tilde_assume(::PriorContext, right, vn, inds, vi) +function tilde_assume(::PriorContext, right, vn, vi) return assume(right, vn, vi) end -function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, inds, vi) +function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, vi) return assume(rng, sampler, right, vn, vi) end -function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) +function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return tilde_assume(LikelihoodContext(), right, vn, inds, vi) + return tilde_assume(LikelihoodContext(), right, vn, vi) end function tilde_assume( rng::Random.AbstractRNG, @@ -91,62 +91,62 @@ function tilde_assume( vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) + return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi) end -function tilde_assume(::LikelihoodContext, right, vn, inds, vi) +function tilde_assume(::LikelihoodContext, right, vn, vi) return assume(NoDist(right), vn, vi) end function tilde_assume( - rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, inds, vi + rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi ) return assume(rng, sampler, NoDist(right), vn, vi) end -function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) - return tilde_assume(context.context, right, vn, inds, vi) +function tilde_assume(context::MiniBatchContext, right, vn, vi) + return tilde_assume(context.context, right, vn, vi) end -function tilde_assume(rng, context::MiniBatchContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) +function tilde_assume(rng, context::MiniBatchContext, sampler, right, vn, vi) + return tilde_assume(rng, context.context, sampler, right, vn, vi) end -function tilde_assume(context::PrefixContext, right, vn, inds, vi) - return tilde_assume(context.context, right, prefix(context, vn), inds, vi) +function tilde_assume(context::PrefixContext, right, vn, vi) + return tilde_assume(context.context, right, prefix(context, vn), vi) end -function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) +function tilde_assume(rng, context::PrefixContext, sampler, right, vn, vi) + return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), vi) end """ - tilde_assume!(context, right, vn, inds, vi) + tilde_assume!(context, right, vn, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -By default, calls `tilde_assume(context, right, vn, inds, vi)` and accumulates the log +By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log probability of `vi` with the returned value. """ -function tilde_assume!(context, right, vn, inds, vi) - value, logp = tilde_assume(context, right, vn, inds, vi) +function tilde_assume!(context, right, vn, vi) + value, logp = tilde_assume(context, right, vn, vi) acclogp!(vi, logp) return value end # observe """ - tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) + tilde_observe(context::SamplingContext, right, left, vname, vi) Handle observed variables with a `context` associated with a sampler. Falls back to ```julia -tilde_observe(context.rng, context.context, context.sampler, right, left, vname, vinds, vi) +tilde_observe(context.rng, context.context, context.sampler, right, left, vname, vi) ``` """ -function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) +function tilde_observe(context::SamplingContext, right, left, vname, vi) return tilde_observe( - context.rng, context.context, context.sampler, right, left, vname, vinds, vi + context.rng, context.context, context.sampler, right, left, vname, vi ) end @@ -190,7 +190,7 @@ function tilde_observe(context::PrefixContext, right, left, vi) end """ - tilde_observe!(context, right, left, vname, vinds, vi) + tilde_observe!(context, right, left, vname, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. @@ -198,7 +198,7 @@ accumulate the log probability, and return the observed value. Falls back to `tilde_observe!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(context, right, left, vname, vinds, vi) +function tilde_observe!(context, right, left, vname, vi) return tilde_observe!(context, right, left, vi) end @@ -273,7 +273,7 @@ end # assume """ - dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + dot_tilde_assume(context::SamplingContext, right, left, vn, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value for a context @@ -281,36 +281,36 @@ associated with a sampler. Falls back to ```julia -dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, inds, vi) +dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, vi) ``` """ -function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) +function dot_tilde_assume(context::SamplingContext, right, left, vn, vi) return dot_tilde_assume( - context.rng, context.context, context.sampler, right, left, vn, inds, vi + context.rng, context.context, context.sampler, right, left, vn, vi ) end # `DefaultContext` -function dot_tilde_assume(::DefaultContext, right, left, vns, inds, vi) +function dot_tilde_assume(::DefaultContext, right, left, vns, vi) return dot_assume(right, left, vns, vi) end -function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi) +function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, vi) return dot_assume(rng, sampler, right, vns, left, vi) end # `LikelihoodContext` function dot_tilde_assume( - context::LikelihoodContext{<:NamedTuple}, right, left, vn, inds, vi + context::LikelihoodContext{<:NamedTuple}, right, left, vn, vi ) return if haskey(context.vars, getsym(vn)) var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, inds, vi) + dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi) else - dot_tilde_assume(LikelihoodContext(), right, left, vn, inds, vi) + dot_tilde_assume(LikelihoodContext(), right, left, vn, vi) end end function dot_tilde_assume( @@ -328,30 +328,30 @@ function dot_tilde_assume( _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, inds, vi) + dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi) else - dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, inds, vi) + dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi) end end -function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) +function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) return dot_assume(NoDist.(right), left, vn, vi) end function dot_tilde_assume( - rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi + rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi ) return dot_assume(rng, sampler, NoDist.(right), vn, left, vi) end # `PriorContext` -function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, inds, vi) +function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, vi) return if haskey(context.vars, getsym(vn)) var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(PriorContext(), _right, _left, _vns, inds, vi) + dot_tilde_assume(PriorContext(), _right, _left, _vns, vi) else - dot_tilde_assume(PriorContext(), right, left, vn, inds, vi) + dot_tilde_assume(PriorContext(), right, left, vn, vi) end end function dot_tilde_assume( @@ -369,52 +369,52 @@ function dot_tilde_assume( _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, inds, vi) + dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi) else - dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) + dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi) end end -function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) +function dot_tilde_assume(context::PriorContext, right, left, vn, vi) return dot_assume(right, left, vn, vi) end function dot_tilde_assume( - rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi + rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, vi ) return dot_assume(rng, sampler, right, vn, left, vi) end # `MiniBatchContext` -function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.context, right, left, vn, inds, vi) +function dot_tilde_assume(context::MiniBatchContext, right, left, vn, vi) + return dot_tilde_assume(context.context, right, left, vn, vi) end function dot_tilde_assume( - rng, context::MiniBatchContext, sampler, right, left, vn, inds, vi + rng, context::MiniBatchContext, sampler, right, left, vn, vi ) - return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) + return dot_tilde_assume(rng, context.context, sampler, right, left, vn, vi) end # `PrefixContext` -function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), inds, vi) +function dot_tilde_assume(context::PrefixContext, right, left, vn, vi) + return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), vi) end -function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, inds, vi) +function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, vi) return dot_tilde_assume( - rng, context.context, sampler, right, prefix.(Ref(context), vn), inds, vi + rng, context.context, sampler, right, prefix.(Ref(context), vn), vi ) end """ - dot_tilde_assume!(context, right, left, vn, inds, vi) + dot_tilde_assume!(context, right, left, vn, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(context, right, left, vn, vi)`. """ -function dot_tilde_assume!(context, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) +function dot_tilde_assume!(context, right, left, vn, vi) + value, logp = dot_tilde_assume(context, right, left, vn, vi) acclogp!(vi, logp) return value end @@ -612,7 +612,7 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) end """ - dot_tilde_observe!(context, right, left, vname, vinds, vi) + dot_tilde_observe!(context, right, left, vname, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. @@ -620,7 +620,7 @@ accumulate the log probability, and return the observed value. Falls back to `dot_tilde_observe!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(context, right, left, vn, inds, vi) +function dot_tilde_observe!(context, right, left, vn, vi) return dot_tilde_observe!(context, right, left, vi) end From b867d00786a6a3aa3f4239b9d67c3932bfede765 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 23:11:56 +0100 Subject: [PATCH 03/32] ensure that varname uses concretize --- src/compiler.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index c1ab85951..2cce7b3a0 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -19,7 +19,7 @@ function isassumption(expr::Union{Symbol,Expr}) vn = gensym(:vn) return quote - let $vn = $(varname(expr)) + let $vn = $(varname(expr, true)) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` if !$(DynamicPPL.inargnames)($vn, __model__) || @@ -330,7 +330,7 @@ function generate_tilde(left, right) @gensym vn isassumption return quote - $vn = $(remove_escape(varname(left))) + $vn = $(remove_escape(varname(left, true))) $isassumption = $(remove_escape(DynamicPPL.isassumption(left))) if $isassumption $(generate_tilde_assume(left, right, vn)) @@ -384,7 +384,7 @@ function generate_dot_tilde(left, right) # if the LHS represents an observation @gensym vn isassumption return quote - $vn = $(varname(left)) + $vn = $(varname(left, true)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $(generate_dot_tilde_assume(left, right, vn)) From 6ad5d9501d4df276657c18e140fd7472f69362a1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 23:12:12 +0100 Subject: [PATCH 04/32] update PointwiseLikelihoodContext --- src/loglikelihoods.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 2901432d1..c0c2cc913 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -61,19 +61,19 @@ function Base.push!( return context.loglikelihoods[vn] = logp end -function tilde_assume(context::PointwiseLikelihoodContext, right, vn, inds, vi) - return tilde_assume(context.context, right, vn, inds, vi) +function tilde_assume(context::PointwiseLikelihoodContext, right, vn, vi) + return tilde_assume(context.context, right, vn, vi) end -function dot_tilde_assume(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.context, right, left, vn, inds, vi) +function dot_tilde_assume(context::PointwiseLikelihoodContext, right, left, vn, vi) + return dot_tilde_assume(context.context, right, left, vn, vi) end function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) # Defer literal `observe` to child-context. return tilde_observe!(context.context, right, left, vi) end -function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vinds, vi) +function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `tilde_observe!`. logp = tilde_observe(context.context, right, left, vi) @@ -89,7 +89,7 @@ function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi # Defer literal `observe` to child-context. return dot_tilde_observe!(context.context, right, left, vi) end -function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) +function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `dot_tilde_observe!`. From 4bf663fbfba9c1b3e28099e33c17cafcd6e42614 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 23:12:30 +0100 Subject: [PATCH 05/32] update unwrap_right_left_vns and fix --- src/compiler.jl | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 2cce7b3a0..c600f91d4 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -90,6 +90,28 @@ left-hand side of a `.~` expression such as `x .~ Normal()`. This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the variables. + +# Examples +```jldoctest +julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(1, 1.0), randn(1, 2), @varname(x)); vns +2-element Vector{VarName{:x, Setfield.IndexLens{Tuple{Colon, Int64}}}}: + x[:,1] + x[:,2] + +julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns +1×2 Matrix{VarName{:x, Setfield.IndexLens{Tuple{Int64, Int64}}}}: + x[1,1] x[1,2] + +julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); vns +1×2 Matrix{VarName{:x, Setfield.ComposedLens{Setfield.IndexLens{Tuple{Colon}}, Setfield.IndexLens{Tuple{Int64, Int64}}}}}: + x[:][1,1] x[:][1,2] + +julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); vns +3-element Vector{VarName{:x, Setfield.ComposedLens{Setfield.IndexLens{Tuple{Int64}}, Setfield.IndexLens{Tuple{Int64}}}}}: + x[1][1] + x[1][2] + x[1][3] +``` """ unwrap_right_left_vns(right, left, vns) = right, left, vns function unwrap_right_left_vns(right::NamedDist, left, vns) @@ -103,7 +125,7 @@ function unwrap_right_left_vns( # for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`, # and we therefore add the `Colon()` below. vns = map(axes(left, 2)) do i - return VarName(vn, (vn.indexing..., Colon(), Tuple(i))) + return vn ∘ Setfield.IndexLens((Colon(), i)) end return unwrap_right_left_vns(right, left, vns) end @@ -113,7 +135,7 @@ function unwrap_right_left_vns( vn::VarName, ) vns = map(CartesianIndices(left)) do i - return VarName(vn, (vn.indexing..., Tuple(i))) + return vn ∘ Setfield.IndexLens(Tuple(i)) end return unwrap_right_left_vns(right, left, vns) end From e4922c970cffe49b85457e6606b4e13d459c7bb7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 23:24:30 +0100 Subject: [PATCH 06/32] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/compiler.jl | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index c600f91d4..e4e38be27 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -372,9 +372,7 @@ function generate_tilde_assume(left::Symbol, right, vn) return quote $left = $(DynamicPPL.tilde_assume!)( __context__, - $(DynamicPPL.unwrap_right_vn)( - $(DynamicPPL.check_tilde_rhs)($right), $vn - )..., + $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., __varinfo__, ) end @@ -384,14 +382,12 @@ function generate_tilde_assume(left::Expr, right, vn) expr = :( $left = $(DynamicPPL.tilde_assume!)( __context__, - $(DynamicPPL.unwrap_right_vn)( - $(DynamicPPL.check_tilde_rhs)($right), $vn - )..., + $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., __varinfo__, ) ) - return remove_escape(setmacro(identity, expr, overwrite=true)) + return remove_escape(setmacro(identity, expr; overwrite=true)) end """ @@ -445,7 +441,7 @@ function generate_dot_tilde_assume(left::Expr, right, vn) ) ) - return remove_escape(setmacro(identity, expr, overwrite=true)) + return remove_escape(setmacro(identity, expr; overwrite=true)) end # HACK: This is unfortunate. It's a consequence of the fact that in @@ -475,13 +471,13 @@ function setmacro(lenstransform, ex::Expr; overwrite::Bool=false) end else op = get_update_op(ex.head) - f = :($(Setfield._UpdateOp)($op,$val)) + f = :($(Setfield._UpdateOp)($op, $val)) quote $lens_var = ($lenstransform)($lens) $dst = $(Setfield.modify)($f, $obj, $lens_var) end end - ret + return ret end const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}} From ab4b3849c80fe0f471e6377675914485936ad011 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 30 Jul 2021 14:21:48 +0100 Subject: [PATCH 07/32] fixed doctest --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index c600f91d4..15bc676c4 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -92,7 +92,7 @@ This is used mainly to unwrap `NamedDist` distributions and adjust the indices o variables. # Examples -```jldoctest +```jldoctest; setup=:(using Distributions) julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(1, 1.0), randn(1, 2), @varname(x)); vns 2-element Vector{VarName{:x, Setfield.IndexLens{Tuple{Colon, Int64}}}}: x[:,1] From 26216e382e2fe87d3acfe14e79c7971a77691bf3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 31 Jul 2021 11:36:49 +0100 Subject: [PATCH 08/32] forgot to remove escaping some places --- src/compiler.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index e3e4e2a9d..dc3fc3a19 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -19,7 +19,7 @@ function isassumption(expr::Union{Symbol,Expr}) vn = gensym(:vn) return quote - let $vn = $(varname(expr, true)) + let $vn = $(AbstractPPL.drop_escape(varname(expr, true))) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` if !$(DynamicPPL.inargnames)($vn, __model__) || @@ -351,9 +351,12 @@ function generate_tilde(left, right) # if the LHS represents an observation @gensym vn isassumption + # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact + # that in DynamicPPL we the entire function body. Instead we should be + # more selective with our escape. Until that's the case, we remove them all. return quote - $vn = $(remove_escape(varname(left, true))) - $isassumption = $(remove_escape(DynamicPPL.isassumption(left))) + $vn = $(AbstractPPL.drop_escape(varname(left, true))) + $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $(generate_tilde_assume(left, right, vn)) else From 405d52ca38da23a6715715d4fc336753e2d9d8d0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 31 Jul 2021 11:37:29 +0100 Subject: [PATCH 09/32] removed usage of Setfield.set for .= and some other niceties --- src/compiler.jl | 54 +++++++++++++------------------------------------ 1 file changed, 14 insertions(+), 40 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index dc3fc3a19..b967b74f0 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -371,17 +371,7 @@ function generate_tilde(left, right) end end -function generate_tilde_assume(left::Symbol, right, vn) - return quote - $left = $(DynamicPPL.tilde_assume!)( - __context__, - $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., - __varinfo__, - ) - end -end - -function generate_tilde_assume(left::Expr, right, vn) +function generate_tilde_assume(left, right, vn) expr = :( $left = $(DynamicPPL.tilde_assume!)( __context__, @@ -390,7 +380,11 @@ function generate_tilde_assume(left::Expr, right, vn) ) ) - return remove_escape(setmacro(identity, expr; overwrite=true)) + return if left isa Expr + AbstractPPL.drop_escape(make_set(identity, expr; overwrite=true)) + else + return expr + end end """ @@ -405,7 +399,7 @@ function generate_dot_tilde(left, right) # if the LHS represents an observation @gensym vn isassumption return quote - $vn = $(varname(left, true)) + $vn = $(AbstractPPL.drop_escape(varname(left, true))) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $(generate_dot_tilde_assume(left, right, vn)) @@ -421,7 +415,10 @@ function generate_dot_tilde(left, right) end end -function generate_dot_tilde_assume(left::Symbol, right, vn) +function generate_dot_tilde_assume(left, right, vn) + # We don't need to use `Setfield.@set` here since + # `.=` is always going to be inplace + needs `left` to + # be something that supports `.=`. return :( $left .= $(DynamicPPL.dot_tilde_assume!)( __context__, @@ -433,33 +430,10 @@ function generate_dot_tilde_assume(left::Symbol, right, vn) ) end -function generate_dot_tilde_assume(left::Expr, right, vn) - expr = :( - $left .= $(DynamicPPL.dot_tilde_assume!)( - __context__, - $(DynamicPPL.unwrap_right_left_vns)( - $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn - )..., - __varinfo__, - ) - ) - - return remove_escape(setmacro(identity, expr; overwrite=true)) -end - -# HACK: This is unfortunate. It's a consequence of the fact that in -# DynamicPPL we the entire function body. Instead we should be -# more selective with our escape. Until that's the case, we remove them all. -remove_escape(x) = x -function remove_escape(expr::Expr) - Meta.isexpr(expr, :escape) && return remove_escape(expr.args[1]) - return Expr(expr.head, map(x -> remove_escape(x), expr.args)...) -end - # TODO: Make PR to Setfield.jl to use `gensym` for the `lens` variable. # This seems like it should be the case anyways since it allows multiple -# calls to `setmacro` without any cost to the current functionality. -function setmacro(lenstransform, ex::Expr; overwrite::Bool=false) +# calls to `make_set` without any cost to the current functionality. +function make_set(lenstransform, ex::Expr; overwrite::Bool=false) @assert ex.head isa Symbol @assert length(ex.args) == 2 ref, val = ex.args @@ -473,7 +447,7 @@ function setmacro(lenstransform, ex::Expr; overwrite::Bool=false) $dst = $(Setfield.set)($obj, $lens_var, $val) end else - op = get_update_op(ex.head) + op = Setfield.get_update_op(ex.head) f = :($(Setfield._UpdateOp)($op, $val)) quote $lens_var = ($lenstransform)($lens) From 505d6902e2dd799b6a6b029c1e36033f05f4d664 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 31 Jul 2021 11:38:58 +0100 Subject: [PATCH 10/32] fixed a doctests that will inevitably fail on Julia 1.3 --- src/compiler.jl | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index b967b74f0..3c78ad944 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -93,24 +93,17 @@ variables. # Examples ```jldoctest; setup=:(using Distributions) -julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(1, 1.0), randn(1, 2), @varname(x)); vns -2-element Vector{VarName{:x, Setfield.IndexLens{Tuple{Colon, Int64}}}}: - x[:,1] - x[:,2] - -julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns -1×2 Matrix{VarName{:x, Setfield.IndexLens{Tuple{Int64, Int64}}}}: - x[1,1] x[1,2] - -julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); vns -1×2 Matrix{VarName{:x, Setfield.ComposedLens{Setfield.IndexLens{Tuple{Colon}}, Setfield.IndexLens{Tuple{Int64, Int64}}}}}: - x[:][1,1] x[:][1,2] - -julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); vns -3-element Vector{VarName{:x, Setfield.ComposedLens{Setfield.IndexLens{Tuple{Int64}}, Setfield.IndexLens{Tuple{Int64}}}}}: - x[1][1] - x[1][2] - x[1][3] +julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(1, 1.0), randn(1, 2), @varname(x)); vns[end] +x[:,2] + +julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns[end] +x[1,2] + +julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); vns[end] +x[:][1,2] + +julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); vns[end] +x[1][3] ``` """ unwrap_right_left_vns(right, left, vns) = right, left, vns From 9f8c47b364234de991a3ba07c0a168f1afe56323 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 31 Jul 2021 11:41:00 +0100 Subject: [PATCH 11/32] updated a comment --- src/compiler.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 3c78ad944..709c49f5a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -423,9 +423,8 @@ function generate_dot_tilde_assume(left, right, vn) ) end -# TODO: Make PR to Setfield.jl to use `gensym` for the `lens` variable. -# This seems like it should be the case anyways since it allows multiple -# calls to `make_set` without any cost to the current functionality. +# TODO: Replace with `setmacro` once https://github.com/jw3126/Setfield.jl/pull/156 +# has been merged. function make_set(lenstransform, ex::Expr; overwrite::Bool=false) @assert ex.head isa Symbol @assert length(ex.args) == 2 From 0a4795306dd24d364d3057fe04c8e7b40ee3cfe3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 31 Jul 2021 11:41:22 +0100 Subject: [PATCH 12/32] added deprecations of the tildes --- src/DynamicPPL.jl | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 344a92594..b2cd80b9e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -132,4 +132,36 @@ include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") +# `tilde_assume` +@deprecate( + tilde_assume(rng, context, sampler, right, left, vn, inds, vi), + tilde_assume(rng, context, sampler, right, left, vn, vi) +) +@deprecate( + tilde_assume(context, right, left, vn, inds, vi), + tilde_assume(context, right, left, vn, vi) +) + +# `tilde_observe` +@deprecate( + tilde_observe(context, right, left, vn, inds, vi), + tilde_observe(context, right, left, vn, vi) +) +# Need to specify the `sampler` type here to avoid clashing with +# the deprecation above since they have the same number of arguments. +@deprecate( + tilde_observe(context, sampler::AbstractSampler, right, left, inds, vi), + tilde_observe(context, sampler::AbstractSampler, right, left, vi) +) + +# `dot_tilde_assume` +@deprecate( + dot_tilde_assume(context, right, left, vns, inds, vi), + dot_tilde_assume(context, right, left, vns, vi) +) +@deprecate( + dot_tilde_assume(rng, context, sampler, right, left, vns, inds, vi), + dot_tilde_assume(rng, context, sampler, right, left, vns, vi) +) + end # module From 0c51329f10d043e211dea904aee567e1081b2fe1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 31 Jul 2021 11:57:55 +0100 Subject: [PATCH 13/32] Update src/DynamicPPL.jl --- src/DynamicPPL.jl | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b2cd80b9e..344a92594 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -132,36 +132,4 @@ include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") -# `tilde_assume` -@deprecate( - tilde_assume(rng, context, sampler, right, left, vn, inds, vi), - tilde_assume(rng, context, sampler, right, left, vn, vi) -) -@deprecate( - tilde_assume(context, right, left, vn, inds, vi), - tilde_assume(context, right, left, vn, vi) -) - -# `tilde_observe` -@deprecate( - tilde_observe(context, right, left, vn, inds, vi), - tilde_observe(context, right, left, vn, vi) -) -# Need to specify the `sampler` type here to avoid clashing with -# the deprecation above since they have the same number of arguments. -@deprecate( - tilde_observe(context, sampler::AbstractSampler, right, left, inds, vi), - tilde_observe(context, sampler::AbstractSampler, right, left, vi) -) - -# `dot_tilde_assume` -@deprecate( - dot_tilde_assume(context, right, left, vns, inds, vi), - dot_tilde_assume(context, right, left, vns, vi) -) -@deprecate( - dot_tilde_assume(rng, context, sampler, right, left, vns, inds, vi), - dot_tilde_assume(rng, context, sampler, right, left, vns, vi) -) - end # module From 8134a1a591985663329cfbc500c4acc6ab526cd9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 01:28:58 +0100 Subject: [PATCH 14/32] use impl of get for VarName instead of the hacky stuff we currently have --- src/context_implementations.jl | 43 ++++++++++------------------------ test/runtests.jl | 2 +- 2 files changed, 14 insertions(+), 31 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 868f13b84..824be77d0 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -36,30 +36,22 @@ end # Leaf contexts tilde_assume(::DefaultContext, right, vn, vi) = assume(right, vn, vi) -function tilde_assume( - rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi -) +function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, vi) return assume(rng, sampler, right, vn, vi) end function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + vi[vn] = vectorize(right, get(context.vars, vn)) settrans!(vi, false, vn) end return tilde_assume(PriorContext(), right, vn, vi) end function tilde_assume( - rng::Random.AbstractRNG, - context::PriorContext{<:NamedTuple}, - sampler, - right, - vn, - inds, - vi, + rng::Random.AbstractRNG, context::PriorContext{<:NamedTuple}, sampler, right, vn, vi ) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + vi[vn] = vectorize(right, get(context.vars, vn)) settrans!(vi, false, vn) end return tilde_assume(rng, PriorContext(), sampler, right, vn, vi) @@ -73,7 +65,7 @@ end function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + vi[vn] = vectorize(right, get(context.vars, vn)) settrans!(vi, false, vn) end return tilde_assume(LikelihoodContext(), right, vn, vi) @@ -84,11 +76,10 @@ function tilde_assume( sampler, right, vn, - inds, vi, ) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + vi[vn] = vectorize(right, get(getfield(context.vars, getsym(vn)), vn.indexing)) settrans!(vi, false, vn) end return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi) @@ -96,9 +87,7 @@ end function tilde_assume(::LikelihoodContext, right, vn, vi) return assume(NoDist(right), vn, vi) end -function tilde_assume( - rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi -) +function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi) return assume(rng, sampler, NoDist(right), vn, vi) end @@ -297,11 +286,9 @@ function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, vi) end # `LikelihoodContext` -function dot_tilde_assume( - context::LikelihoodContext{<:NamedTuple}, right, left, vn, vi -) +function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, vn, vi) return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) + var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) @@ -317,11 +304,10 @@ function dot_tilde_assume( right, left, vn, - inds, vi, ) return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) + var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) @@ -342,7 +328,7 @@ end # `PriorContext` function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, vi) return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) + var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) @@ -358,11 +344,10 @@ function dot_tilde_assume( right, left, vn, - inds, vi, ) return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) + var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) @@ -385,9 +370,7 @@ function dot_tilde_assume(context::MiniBatchContext, right, left, vn, vi) return dot_tilde_assume(context.context, right, left, vn, vi) end -function dot_tilde_assume( - rng, context::MiniBatchContext, sampler, right, left, vn, vi -) +function dot_tilde_assume(rng, context::MiniBatchContext, sampler, right, left, vn, vi) return dot_tilde_assume(rng, context.context, sampler, right, left, vn, vi) end diff --git a/test/runtests.jl b/test/runtests.jl index d83be0eea..68f4facaf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,7 +44,7 @@ include("test_util.jl") include("threadsafe.jl") - include("serialization.jl") + # include("serialization.jl") include("loglikelihoods.jl") end From 9d3c1dd32f3504f4a7e4938f3187e4fb18ae6ef7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 01:29:37 +0100 Subject: [PATCH 15/32] uncomment commented out test suite --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 68f4facaf..d83be0eea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,7 +44,7 @@ include("test_util.jl") include("threadsafe.jl") - # include("serialization.jl") + include("serialization.jl") include("loglikelihoods.jl") end From 4121b0e7990da382d994961f0cf8c09525c2458b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 02:39:04 +0100 Subject: [PATCH 16/32] use BangBang to prefer mutation when using set --- Project.toml | 3 +++ src/DynamicPPL.jl | 1 + src/compiler.jl | 2 +- test/compiler.jl | 50 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 55 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 553420689..11f7c334a 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.13.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -16,9 +17,11 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractMCMC = "2, 3.0" AbstractPPL = "0.2" +BangBang = "0.3" Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9" ChainRulesCore = "0.9.7, 0.10" Distributions = "0.23.8, 0.24, 0.25" MacroTools = "0.5.6" +Setfield = "0.7.0" ZygoteRules = "0.2" julia = "1.3" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 344a92594..1d227303a 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -11,6 +11,7 @@ using MacroTools: MacroTools using ZygoteRules: ZygoteRules using Setfield: Setfield +using BangBang: BangBang using Random: Random diff --git a/src/compiler.jl b/src/compiler.jl index 709c49f5a..60e2ddf72 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -374,7 +374,7 @@ function generate_tilde_assume(left, right, vn) ) return if left isa Expr - AbstractPPL.drop_escape(make_set(identity, expr; overwrite=true)) + AbstractPPL.drop_escape(make_set(BangBang.prefermutation, expr; overwrite=true)) else return expr end diff --git a/test/compiler.jl b/test/compiler.jl index 6f85e9453..370476f7f 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -28,6 +28,11 @@ macro mymodel2(ex) end end +# Used to test sampling of immutable types. +struct MyCoolStruct{T} + a::T +end + @testset "compiler.jl" begin @testset "model macro" begin @model function testmodel_comp(x, y) @@ -229,6 +234,51 @@ end @test haskey(vi.metadata, :x) vi = VarInfo(gdemo(x)) @test haskey(vi.metadata, :x) + + # Non-array variables + @model function testmodel_nonarray(x, y) + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + for i in 2:length(x.a) - 1 + x.a[i] ~ Normal(m, √s) + end + + # Dynamic indexing + x.a[begin] ~ Normal(-100.0, 1.0) + x.a[end] ~ Normal(100.0, 1.0) + + # Immutable set + y.a ~ Normal() + + # Dotted + z = Vector{Float64}(undef, 3) + z[1:2] .~ Normal() + z[end:end] .~ Normal() + + return (; s, m, x, y, z) + end + + m_nonarray = testmodel_nonarray(MyCoolStruct([missing, missing]), MyCoolStruct(missing)); + result = m_nonarray() + @test !any(ismissing, result.x.a) + @test result.y.a !== missing + @test result.x.a[begin] < -10 + @test result.x.a[end] > 10 + + # Ensure that we can work with `Vector{Real}(undef, N)` which is the + # reason why we're using `BangBang.prefermutation` in `src/compiler.jl` + # rather than the default from Setfield.jl. + # Related: https://github.com/jw3126/Setfield.jl/issues/157 + @model function vdemo() + x = Vector{Real}(undef, 10) + for i in eachindex(x) + x[i] ~ Normal(0, sqrt(4)) + end + + return x + end + x = vdemo()() + @test all((isassigned(x, i) for i in eachindex(x))) end @testset "nested model" begin function makemodel(p) From a38168ad8c7c932ac9fc4831d2e3d39a902e9ae9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 02:40:58 +0100 Subject: [PATCH 17/32] remove redundant and outdated tests for VarInfo in integration tests --- test/turing/varinfo.jl | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/test/turing/varinfo.jl b/test/turing/varinfo.jl index b72832b78..e81e99006 100644 --- a/test/turing/varinfo.jl +++ b/test/turing/varinfo.jl @@ -184,29 +184,6 @@ chain = sample(priorsinarray(xs), HMC(0.01, 10), 10) end @testset "varname" begin - i, j, k = 1, 2, 3 - - vn1 = @varname x[1] - @test vn1 == VarName(:x, ((1,),)) - - # Symbol - v_sym = string(:x) - @test v_sym == "x" - - # Array - v_arr = @varname x[i] - @test v_arr.indexing == ((1,),) - - # Matrix - v_mat = @varname x[i, j] - @test v_mat.indexing == ((1, 2),) - - v_mat = @varname x[i, j, k] - @test v_mat.indexing == ((1, 2, 3),) - - v_mat = @varname x[1, 2][1 + 5][45][3][i] - @test v_mat.indexing == ((1, 2), (6,), (45,), (3,), (1,)) - @model function mat_name_test() p = Array{Any}(undef, 2, 2) for i in 1:2, j in 1:2 @@ -217,10 +194,6 @@ chain = sample(mat_name_test(), HMC(0.2, 4), 1000) check_numerical(chain, ["p[1,1]"], [0]; atol=0.25) - # Multi array - v_arrarr = @varname x[i][j] - @test v_arrarr.indexing == ((1,), (2,)) - @model function marr_name_test() p = Array{Array{Any}}(undef, 2) p[1] = Array{Any}(undef, 2) From 51e7426d8c00708cded89254329265290661356b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 02:41:48 +0100 Subject: [PATCH 18/32] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/compiler.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 370476f7f..ab6d42665 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -239,7 +239,7 @@ end @model function testmodel_nonarray(x, y) s ~ InverseGamma(2, 3) m ~ Normal(0, √s) - for i in 2:length(x.a) - 1 + for i in 2:(length(x.a) - 1) x.a[i] ~ Normal(m, √s) end @@ -258,12 +258,14 @@ end return (; s, m, x, y, z) end - m_nonarray = testmodel_nonarray(MyCoolStruct([missing, missing]), MyCoolStruct(missing)); + m_nonarray = testmodel_nonarray( + MyCoolStruct([missing, missing]), MyCoolStruct(missing) + ) result = m_nonarray() - @test !any(ismissing, result.x.a) - @test result.y.a !== missing - @test result.x.a[begin] < -10 - @test result.x.a[end] > 10 + @test !any(ismissing, result.x.a) + @test result.y.a !== missing + @test result.x.a[begin] < -10 + @test result.x.a[end] > 10 # Ensure that we can work with `Vector{Real}(undef, N)` which is the # reason why we're using `BangBang.prefermutation` in `src/compiler.jl` From 0a3655dc701ff16b2c7cacff2b6c30da2638ae2e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 06:04:17 +0100 Subject: [PATCH 19/32] no longer need the custom make_set method after Setfield v0.7.1 --- Project.toml | 2 +- src/compiler.jl | 28 +--------------------------- 2 files changed, 2 insertions(+), 28 deletions(-) diff --git a/Project.toml b/Project.toml index 11f7c334a..07a6c4db6 100644 --- a/Project.toml +++ b/Project.toml @@ -22,6 +22,6 @@ Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9" ChainRulesCore = "0.9.7, 0.10" Distributions = "0.23.8, 0.24, 0.25" MacroTools = "0.5.6" -Setfield = "0.7.0" +Setfield = "0.7.1" ZygoteRules = "0.2" julia = "1.3" diff --git a/src/compiler.jl b/src/compiler.jl index 60e2ddf72..55d93d858 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -374,7 +374,7 @@ function generate_tilde_assume(left, right, vn) ) return if left isa Expr - AbstractPPL.drop_escape(make_set(BangBang.prefermutation, expr; overwrite=true)) + AbstractPPL.drop_escape(Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true)) else return expr end @@ -423,32 +423,6 @@ function generate_dot_tilde_assume(left, right, vn) ) end -# TODO: Replace with `setmacro` once https://github.com/jw3126/Setfield.jl/pull/156 -# has been merged. -function make_set(lenstransform, ex::Expr; overwrite::Bool=false) - @assert ex.head isa Symbol - @assert length(ex.args) == 2 - ref, val = ex.args - obj, lens = Setfield.parse_obj_lens(ref) - lens_var = gensym("lens") - dst = overwrite ? obj : gensym("_") - val = esc(val) - ret = if ex.head == :(=) - quote - $lens_var = ($lenstransform)($lens) - $dst = $(Setfield.set)($obj, $lens_var, $val) - end - else - op = Setfield.get_update_op(ex.head) - f = :($(Setfield._UpdateOp)($op, $val)) - quote - $lens_var = ($lenstransform)($lens) - $dst = $(Setfield.modify)($f, $obj, $lens_var) - end - end - return ret -end - const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}} hasmissing(T::Type{<:AbstractArray{TA}}) where {TA<:AbstractArray} = hasmissing(TA) hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true From f82579e4619c60f073611fd1c1ce9f6c90a411b1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 06:06:35 +0100 Subject: [PATCH 20/32] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/compiler.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 55d93d858..ded0db7f1 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -374,7 +374,9 @@ function generate_tilde_assume(left, right, vn) ) return if left isa Expr - AbstractPPL.drop_escape(Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true)) + AbstractPPL.drop_escape( + Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true) + ) else return expr end From 9aa7298aa78010bacd749c2e7ad30f1e6072efde Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 17:10:30 +0100 Subject: [PATCH 21/32] drop concretize argument to varname --- src/compiler.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index ded0db7f1..9ad28a526 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -19,7 +19,7 @@ function isassumption(expr::Union{Symbol,Expr}) vn = gensym(:vn) return quote - let $vn = $(AbstractPPL.drop_escape(varname(expr, true))) + let $vn = $(AbstractPPL.drop_escape(varname(expr))) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` if !$(DynamicPPL.inargnames)($vn, __model__) || @@ -348,7 +348,7 @@ function generate_tilde(left, right) # that in DynamicPPL we the entire function body. Instead we should be # more selective with our escape. Until that's the case, we remove them all. return quote - $vn = $(AbstractPPL.drop_escape(varname(left, true))) + $vn = $(AbstractPPL.drop_escape(varname(left))) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $(generate_tilde_assume(left, right, vn)) @@ -394,7 +394,7 @@ function generate_dot_tilde(left, right) # if the LHS represents an observation @gensym vn isassumption return quote - $vn = $(AbstractPPL.drop_escape(varname(left, true))) + $vn = $(AbstractPPL.drop_escape(varname(left))) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $(generate_dot_tilde_assume(left, right, vn)) From b130db92dfb9ee75791cb5e4b906cb338114aaff Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 17:17:03 +0100 Subject: [PATCH 22/32] added a couple of additional benchmarks --- benchmarks/benchmarks.jmd | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/benchmarks/benchmarks.jmd b/benchmarks/benchmarks.jmd index 614afb2e9..5b86b261e 100644 --- a/benchmarks/benchmarks.jmd +++ b/benchmarks/benchmarks.jmd @@ -94,3 +94,37 @@ data = mapreduce(c -> rand(MvNormal([μs[c], μs[c]], 1.), N), hcat, 1:2); ```julia; echo=false weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) ``` + +### `demo4`: loads of indexing + +```julia +@model function demo4(n, ::Type{TV}=Vector{Float64}) where {TV} + m ~ Normal() + x = TV(undef, n) + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end +end + +model_def = demo4 +data = (100_000, ); +``` + +```julia; echo=false +weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) +``` + +```julia +@model function demo4_dotted(n, ::Type{TV}=Vector{Float64}) where {TV} + m ~ Normal() + x = TV(undef, n) + x .~ Normal(m, 1.0) +end + +model_def = demo4_dotted +data = (100_000, ); +``` + +```julia; echo=false +weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) +``` From 475da88cbf5ef23119efa10f74b1a19fc51b6592 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 17 Aug 2021 05:33:37 +0100 Subject: [PATCH 23/32] fixed tests --- src/compiler.jl | 2 +- src/context_implementations.jl | 16 ++++++++-------- src/contexts.jl | 2 +- test/Project.toml | 2 ++ test/contexts.jl | 10 +++++----- 5 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 87f5a4dba..25530696d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -19,7 +19,7 @@ function isassumption(expr::Union{Symbol,Expr}) vn = gensym(:vn) return quote - let $vn = $(varname(expr)) + let $vn = $(AbstractPPL.drop_escape(varname(expr))) if $(DynamicPPL.contextual_isassumption)(__context__, $vn) # Considered an assumption by `__context__` which means either: # 1. We hit the default implementation, e.g. using `DefaultContext`, diff --git a/src/context_implementations.jl b/src/context_implementations.jl index b258f4480..e70456dd9 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -47,7 +47,7 @@ end function tilde_assume(context::AbstractContext, args...) return tilde_assume(NodeTrait(tilde_assume, context), context, args...) end -function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vinds, vi) +function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vi) return assume(right, vn, vi) end function tilde_assume(::IsParent, context::AbstractContext, args...) @@ -58,7 +58,7 @@ function tilde_assume(rng, context::AbstractContext, args...) return tilde_assume(NodeTrait(tilde_assume, context), rng, context, args...) end function tilde_assume( - ::IsLeaf, rng, context::AbstractContext, sampler, right, vn, vinds, vi + ::IsLeaf, rng, context::AbstractContext, sampler, right, vn, vi ) return assume(rng, sampler, right, vn, vi) end @@ -111,11 +111,11 @@ function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, rig return assume(rng, sampler, NoDist(right), vn, vi) end -function tilde_assume(context::PrefixContext, right, vn, inds, vi) - return tilde_assume(context.context, right, prefix(context, vn), inds, vi) +function tilde_assume(context::PrefixContext, right, vn, vi) + return tilde_assume(context.context, right, prefix(context, vn), vi) end -function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) +function tilde_assume(rng, context::PrefixContext, sampler, right, vn, vi) + return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), vi) end """ @@ -276,11 +276,11 @@ function dot_tilde_assume(rng, context::AbstractContext, args...) return dot_tilde_assume(rng, NodeTrait(dot_tilde_assume, context), context, args...) end -function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, inds, vi) +function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, vi) return dot_assume(right, left, vns, vi) end function dot_tilde_assume( - ::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, inds, vi + ::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, vi ) return dot_assume(rng, sampler, right, vns, left, vi) end diff --git a/src/contexts.jl b/src/contexts.jl index 98eb4b85d..64887f032 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -311,7 +311,7 @@ Return value of `vn` in `context`. function getvalue(context::AbstractContext, vn) return error("context $(context) does not contain value for $vn") end -getvalue(context::ConditionContext, vn) = _getvalue(context.values, vn) +getvalue(context::ConditionContext, vn) = get(context.values, vn) """ hasvalue_nested(context, vn) diff --git a/test/Project.toml b/test/Project.toml index c523f6092..4ff5f9d19 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,6 +13,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -28,6 +29,7 @@ Documenter = "0.26.1, 0.27" ForwardDiff = "0.10.12" MCMCChains = "4.0.4" MacroTools = "0.5.5" +Setfield = "0.7.1" StableRNGs = "1" Tracker = "0.2.11" Zygote = "0.5.4, 0.6" diff --git a/test/contexts.jl b/test/contexts.jl index 79b8c75fa..647dab7c1 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,4 +1,4 @@ -using Test, DynamicPPL +using Test, DynamicPPL, Setfield using DynamicPPL: leafcontext, setleafcontext, @@ -65,11 +65,11 @@ e.g. `varnames(@varname(x), rand(2))` results in an iterator over `[@varname(x[1 """ varnames(vn::VarName, val::Real) = [vn] function varnames(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) - return (VarName(vn, (vn.indexing..., Tuple(I))) for I in CartesianIndices(val)) + return (VarName(vn, vn.indexing ∘ Setfield.IndexLens(Tuple(I))) for I in CartesianIndices(val)) end function varnames(vn::VarName, val::AbstractArray) return Iterators.flatten( - varnames(VarName(vn, (vn.indexing..., Tuple(I))), val[I]) for + varnames(VarName(vn, vn.indexing ∘ Setfield.IndexLens(Tuple(I))), val[I]) for I in CartesianIndices(val) ) end @@ -183,7 +183,7 @@ end # Let's check elementwise. for vn_child in varnames(vn_without_prefix, val) - if DynamicPPL._getindex(val, vn_child.indexing) === missing + if get(val, vn_child.indexing) === missing @test contextual_isassumption(context, vn_child) else @test !contextual_isassumption(context, vn_child) @@ -219,7 +219,7 @@ end @test hasvalue_nested(context, vn_child) # Value should be the same as extracted above. @test getvalue_nested(context, vn_child) === - DynamicPPL._getindex(val, vn_child.indexing) + get(val, vn_child.indexing) end end end From 4af7e3061161ba936455d292fa8aef021b3bc8ec Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 17 Aug 2021 05:41:40 +0100 Subject: [PATCH 24/32] formatting --- src/context_implementations.jl | 10 +++------- test/contexts.jl | 5 ++++- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e70456dd9..1d6116580 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -57,9 +57,7 @@ end function tilde_assume(rng, context::AbstractContext, args...) return tilde_assume(NodeTrait(tilde_assume, context), rng, context, args...) end -function tilde_assume( - ::IsLeaf, rng, context::AbstractContext, sampler, right, vn, vi -) +function tilde_assume(::IsLeaf, rng, context::AbstractContext, sampler, right, vn, vi) return assume(rng, sampler, right, vn, vi) end function tilde_assume(::IsParent, rng, context::AbstractContext, args...) @@ -99,7 +97,7 @@ function tilde_assume( vi, ) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, get(getfield(context.vars, getsym(vn)), vn.indexing)) + vi[vn] = vectorize(right, get(context.vars, vn)) settrans!(vi, false, vn) end return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi) @@ -279,9 +277,7 @@ end function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, vi) return dot_assume(right, left, vns, vi) end -function dot_tilde_assume( - ::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, vi -) +function dot_tilde_assume(::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, vi) return dot_assume(rng, sampler, right, vns, left, vi) end diff --git a/test/contexts.jl b/test/contexts.jl index 647dab7c1..7144d455b 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -65,7 +65,10 @@ e.g. `varnames(@varname(x), rand(2))` results in an iterator over `[@varname(x[1 """ varnames(vn::VarName, val::Real) = [vn] function varnames(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) - return (VarName(vn, vn.indexing ∘ Setfield.IndexLens(Tuple(I))) for I in CartesianIndices(val)) + return ( + VarName(vn, vn.indexing ∘ Setfield.IndexLens(Tuple(I))) for + I in CartesianIndices(val) + ) end function varnames(vn::VarName, val::AbstractArray) return Iterators.flatten( From e03ef4ef51c3968904df39e907b821365bd371e7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 17 Aug 2021 05:44:29 +0100 Subject: [PATCH 25/32] Update src/context_implementations.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 1d6116580..26dfb4ac1 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -97,7 +97,7 @@ function tilde_assume( vi, ) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, get(context.vars, vn)) + vi[vn] = vectorize(right, get(context.vars, vn)) settrans!(vi, false, vn) end return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi) From 6dd6de943d173d4fd9d7a8cb364c96bd62240ac0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 9 Sep 2021 00:11:12 +0100 Subject: [PATCH 26/32] bumped APPL compat bound --- Project.toml | 2 +- test/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 6ba224a38..1bb1ba7ea 100644 --- a/Project.toml +++ b/Project.toml @@ -16,7 +16,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractMCMC = "2, 3.0" -AbstractPPL = "0.2" +AbstractPPL = "0.3" BangBang = "0.3" Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9" ChainRulesCore = "0.9.7, 0.10, 1" diff --git a/test/Project.toml b/test/Project.toml index b05100045..3af6ef22d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -21,7 +21,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractMCMC = "2.1, 3.0" -AbstractPPL = "0.2" +AbstractPPL = "0.3" Bijectors = "0.9.5" Distributions = "0.25" DistributionsAD = "0.6.3" From 4c7e8828745b3e42076d4812be3bdc6bb1ddfd79 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 9 Sep 2021 16:28:22 +0100 Subject: [PATCH 27/32] some bugfixes --- Project.toml | 2 +- src/context_implementations.jl | 12 ------------ src/contexts.jl | 4 ++-- 3 files changed, 3 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 1bb1ba7ea..c6c776135 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.15.1" +version = "0.16.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 26dfb4ac1..19b5ce061 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,18 +14,6 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false -_getindex(x, inds::Tuple) = _getindex(Base.maybeview(x, first(inds)...), Base.tail(inds)) -_getindex(x, inds::Tuple{}) = x -_getvalue(x, vn::VarName{sym}) where {sym} = _getindex(getproperty(x, sym), vn.indexing) -function _getvalue(x, vns::AbstractVector{<:VarName{sym}}) where {sym} - val = getproperty(x, sym) - - # This should work with both cartesian and linear indexing. - return map(vns) do vn - _getindex(val, vn) - end -end - # assume """ tilde_assume(context::SamplingContext, right, vn, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 64887f032..03fc26245 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -246,9 +246,9 @@ end function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} if @generated - return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(vn.indexing)) + return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(getlens(vn))) else - VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) + VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getlens(vn)) end end From 4c325c339fc3a4381e23cb6a284cd659a5addee6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 9 Sep 2021 16:32:23 +0100 Subject: [PATCH 28/32] updated tests --- test/contexts.jl | 14 +++++++------- test/turing/varinfo.jl | 23 ----------------------- 2 files changed, 7 insertions(+), 30 deletions(-) diff --git a/test/contexts.jl b/test/contexts.jl index 7144d455b..edf581d4d 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -53,7 +53,7 @@ Return `vn` but now with the prefix removed. """ function remove_prefix(vn::VarName) return VarName{Symbol(split(string(vn), string(DynamicPPL.PREFIX_SEPARATOR))[end])}( - vn.indexing + getlens(vn) ) end @@ -66,13 +66,13 @@ e.g. `varnames(@varname(x), rand(2))` results in an iterator over `[@varname(x[1 varnames(vn::VarName, val::Real) = [vn] function varnames(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) return ( - VarName(vn, vn.indexing ∘ Setfield.IndexLens(Tuple(I))) for + VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for I in CartesianIndices(val) ) end function varnames(vn::VarName, val::AbstractArray) return Iterators.flatten( - varnames(VarName(vn, vn.indexing ∘ Setfield.IndexLens(Tuple(I))), val[I]) for + varnames(VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I]) for I in CartesianIndices(val) ) end @@ -186,7 +186,7 @@ end # Let's check elementwise. for vn_child in varnames(vn_without_prefix, val) - if get(val, vn_child.indexing) === missing + if get(val, getlens(vn_child)) === missing @test contextual_isassumption(context, vn_child) else @test !contextual_isassumption(context, vn_child) @@ -222,7 +222,7 @@ end @test hasvalue_nested(context, vn_child) # Value should be the same as extracted above. @test getvalue_nested(context, vn_child) === - get(val, vn_child.indexing) + get(val, getlens(vn_child)) end end end @@ -249,11 +249,11 @@ end vn = VarName{:x}() vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test vn_prefixed.indexing === vn.indexing + @test getlens(vn_prefixed) === getlens(vn) vn = VarName{:x}(((1,),)) vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test vn_prefixed.indexing === vn.indexing + @test getlens(vn_prefixed) === getlens(vn) end end diff --git a/test/turing/varinfo.jl b/test/turing/varinfo.jl index 68b1acd49..892433779 100644 --- a/test/turing/varinfo.jl +++ b/test/turing/varinfo.jl @@ -184,29 +184,6 @@ chain = sample(priorsinarray(xs), HMC(0.01, 10), 10) end @testset "varname" begin - i, j, k = 1, 2, 3 - - vn1 = @varname x[1] - @test vn1 == VarName{:x}(((1,),)) - - # Symbol - v_sym = string(:x) - @test v_sym == "x" - - # Array - v_arr = @varname x[i] - @test v_arr.indexing == ((1,),) - - # Matrix - v_mat = @varname x[i, j] - @test v_mat.indexing == ((1, 2),) - - v_mat = @varname x[i, j, k] - @test v_mat.indexing == ((1, 2, 3),) - - v_mat = @varname x[1, 2][1 + 5][45][3][i] - @test v_mat.indexing == ((1, 2), (6,), (45,), (3,), (1,)) - @model function mat_name_test() p = Array{Any}(undef, 2, 2) for i in 1:2, j in 1:2 From 553ae9b71fd12d6e5c40359b9e63aa7c885059fb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 9 Sep 2021 22:50:25 +0100 Subject: [PATCH 29/32] drop testing begin indexing since incomp with Julia 1.3 --- test/compiler.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 231f11f4c..2aba7df0d 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -250,7 +250,6 @@ end end # Dynamic indexing - x.a[begin] ~ Normal(-100.0, 1.0) x.a[end] ~ Normal(100.0, 1.0) # Immutable set @@ -270,7 +269,6 @@ end result = m_nonarray() @test !any(ismissing, result.x.a) @test result.y.a !== missing - @test result.x.a[begin] < -10 @test result.x.a[end] > 10 # Ensure that we can work with `Vector{Real}(undef, N)` which is the From 00d841124c211d23a5e24826d119bd17ddf6834c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 9 Sep 2021 23:15:33 +0100 Subject: [PATCH 30/32] fixed tests I think --- test/compiler.jl | 4 ++-- test/turing/Project.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 2aba7df0d..dd1764561 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -183,7 +183,7 @@ end @model function testmodel_missing3(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __context__.sampler + global sampler_ = __context__.s2ampler global model_ = __model__ global context_ = __context__ global rng_ = __context__.rng @@ -260,7 +260,7 @@ end z[1:2] .~ Normal() z[end:end] .~ Normal() - return (; s, m, x, y, z) + return (; s=s, m=m, x=x, y=y, z=z) end m_nonarray = testmodel_nonarray( diff --git a/test/turing/Project.toml b/test/turing/Project.toml index fe186816f..9d75e2dcb 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -DynamicPPL = "0.15" +DynamicPPL = "0.16" Turing = "0.18" julia = "1.3" From a99a2b1822ba4095d6916b757bf74fab8d855218 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Sep 2021 00:10:52 +0100 Subject: [PATCH 31/32] we try again --- test/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/compiler.jl b/test/compiler.jl index dd1764561..3140cf5b7 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -183,7 +183,7 @@ end @model function testmodel_missing3(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __context__.s2ampler + global sampler_ = __context__.sampler global model_ = __model__ global context_ = __context__ global rng_ = __context__.rng From 472629dcbae01dfa7eee8fccfdee89ef06455193 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 10 Sep 2021 03:13:30 +0100 Subject: [PATCH 32/32] fixed test --- test/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/compiler.jl b/test/compiler.jl index 3140cf5b7..9f6c0163f 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -245,7 +245,7 @@ end @model function testmodel_nonarray(x, y) s ~ InverseGamma(2, 3) m ~ Normal(0, √s) - for i in 2:(length(x.a) - 1) + for i in 1:(length(x.a) - 1) x.a[i] ~ Normal(m, √s) end