Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Turning around escape in model macro #311

Closed
wants to merge 12 commits into from
102 changes: 56 additions & 46 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
const INTERNALNAMES = (:__model__, :__sampler__, :__context__, :__varinfo__, :__rng__)
const DEPRECATED_INTERNALNAMES = (:_model, :_sampler, :_context, :_varinfo, :_rng)

for name in INTERNALNAMES
@eval $(Symbol(uppercase(string(name)))) = $(Meta.quot((name)))
end


phipsgabler marked this conversation as resolved.
Show resolved Hide resolved
# macro _id(expr)
# return expr
# end

# macro hygienize(expr)
# return Meta.quot(macroexpand(__module__, :(@_id $expr)))
# end

"""
isassumption(expr)

Expand All @@ -15,34 +28,30 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases:

When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
"""
function isassumption(expr::Union{Symbol,Expr})
vn = gensym(:vn)

function isassumption(expr::Union{Symbol,Expr}, vn)
return quote
let $vn = $(varname(expr))
if $(DynamicPPL.contextual_isassumption)(__context__, $vn)
# Considered an assumption by `__context__` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
# the model arguments, hence we need to check this.
# 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments,
# i.e. we're trying to condition one of the latent variables.
# In this case, the below will return `true` since the first branch
# will be hit.
# 3. We are working with a `ConditionContext` _and_ it's in the model arguments,
# i.e. we're trying to override the value. This is currently NOT supported.
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
if !($(DynamicPPL.inargnames)($vn, __model__)) ||
$(DynamicPPL.inmissings)($vn, __model__)
true
else
$(maybe_view(expr)) === missing
end
if $(DynamicPPL.contextual_isassumption)($__CONTEXT__, $vn)
# Considered an assumption by `__context__` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
# the model arguments, hence we need to check this.
# 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments,
# i.e. we're trying to condition one of the latent variables.
# In this case, the below will return `true` since the first branch
# will be hit.
# 3. We are working with a `ConditionContext` _and_ it's in the model arguments,
# i.e. we're trying to override the value. This is currently NOT supported.
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
if !($(DynamicPPL.inargnames)($vn, $__MODEL__)) ||
$(DynamicPPL.inmissings)($vn, $__MODEL__)
phipsgabler marked this conversation as resolved.
Show resolved Hide resolved
true
else
false
$(maybe_view(expr)) === missing
end
else
false
end
end
end
Expand Down Expand Up @@ -201,7 +210,7 @@ To generate a `Model`, call `model(xvalue)` or `model(xvalue, yvalue)`.
macro model(expr, warn=false)
# include `LineNumberNode` with information about the call site in the
# generated function for easier debugging and interpretation of error messages
return esc(model(__module__, __source__, expr, warn))
return model(__module__, __source__, expr, warn)
end

function model(mod, linenumbernode, expr, warn)
Expand Down Expand Up @@ -325,7 +334,7 @@ function generate_mainbody!(mod, found, sym::Symbol, warn)
end
function generate_mainbody!(mod, found, expr::Expr, warn)
# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]
Meta.isexpr(expr, :$) && return esc(expr.args[1])

# If it's a macro, we expand it
if Meta.isexpr(expr, :macrocall)
Expand Down Expand Up @@ -370,7 +379,7 @@ function generate_tilde(left, right)
if isliteral(left)
return quote
$(DynamicPPL.tilde_observe!)(
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
$__CONTEXT__, $(DynamicPPL.check_tilde_rhs)($right), $left, $__VARINFO__
phipsgabler marked this conversation as resolved.
Show resolved Hide resolved
)
end
end
Expand All @@ -381,29 +390,29 @@ function generate_tilde(left, right)
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $isassumption
$left = $(DynamicPPL.tilde_assume!)(
__context__,
$__CONTEXT__,
$(DynamicPPL.unwrap_right_vn)(
$(DynamicPPL.check_tilde_rhs)($right), $vn
)...,
$inds,
__varinfo__,
$__VARINFO__,
)
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left = $(DynamicPPL.getvalue_nested)(__context__, $vn)
if !$(DynamicPPL.inargnames)($vn, $__MODEL__)
$left = $(DynamicPPL.getvalue_nested)($__CONTEXT__, $vn)
end

$(DynamicPPL.tilde_observe!)(
__context__,
$__CONTEXT__,
$(DynamicPPL.check_tilde_rhs)($right),
$(maybe_view(left)),
$vn,
$inds,
__varinfo__,
$__VARINFO__,
)
end
end
Expand All @@ -419,7 +428,7 @@ function generate_dot_tilde(left, right)
if isliteral(left)
return quote
$(DynamicPPL.dot_tilde_observe!)(
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
$__CONTEXT__, $(DynamicPPL.check_tilde_rhs)($right), $left, $__VARINFO__
)
end
end
Expand All @@ -430,29 +439,29 @@ function generate_dot_tilde(left, right)
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume!)(
__context__,
$__CONTEXT__,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
)...,
$inds,
__varinfo__,
$__VARINFO__,
)
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left .= $(DynamicPPL.getvalue_nested)(__context__, $vn)
if !$(DynamicPPL.inargnames)(vn, $__MODEL__)
$left .= $(DynamicPPL.getvalue_nested)($__CONTEXT__, $vn)
end

$(DynamicPPL.dot_tilde_observe!)(
__context__,
$__CONTEXT__,
$(DynamicPPL.check_tilde_rhs)($right),
$(maybe_view(left)),
$vn,
$inds,
__varinfo__,
$__VARINFO__,
)
end
end
Expand All @@ -478,9 +487,9 @@ function build_output(modelinfo, linenumbernode)
# Add the internal arguments to the user-specified arguments (positional + keywords).
evaluatordef[:args] = vcat(
[
:(__model__::$(DynamicPPL.Model)),
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
:(__context__::$(DynamicPPL.AbstractContext)),
:($__MODEL__::$(DynamicPPL.Model)),
:($__VARINFO__::$(DynamicPPL.AbstractVarInfo)),
:($__CONTEXT__::$(DynamicPPL.AbstractContext)),
],
modelinfo[:allargs_exprs],
)
Expand All @@ -500,6 +509,7 @@ function build_output(modelinfo, linenumbernode)
# Update the function body of the user-specified model.
# We use a name for the anonymous evaluator that does not conflict with other variables.
modeldef = modelinfo[:modeldef]
modeldef[:name] = esc(modeldef[:name])
@gensym evaluator
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
Expand Down