Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Turning around escape in model macro #311

Closed
wants to merge 12 commits into from
126 changes: 61 additions & 65 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
const INTERNALNAMES = (:__model__, :__context__, :__varinfo__)

for name in INTERNALNAMES
@eval const $(Symbol(uppercase(string(name)))) = $(Meta.quot(name))
end


phipsgabler marked this conversation as resolved.
Show resolved Hide resolved
"""
isassumption(expr)
isassumption(expr, vn)

Return an expression that can be evaluated to check if `expr` is an assumption in the
model.
Expand All @@ -14,38 +19,37 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases:

When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
"""
function isassumption(expr::Union{Symbol,Expr})
vn = gensym(:vn)

function isassumption(expr::Union{Symbol,Expr}, vn)
return quote
let $vn = $(AbstractPPL.drop_escape(varname(expr)))
if $(DynamicPPL.contextual_isassumption)(__context__, $vn)
# Considered an assumption by `__context__` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
# the model arguments, hence we need to check this.
# 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments,
# i.e. we're trying to condition one of the latent variables.
# In this case, the below will return `true` since the first branch
# will be hit.
# 3. We are working with a `ConditionContext` _and_ it's in the model arguments,
# i.e. we're trying to override the value. This is currently NOT supported.
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
if !($(DynamicPPL.inargnames)($vn, __model__)) ||
$(DynamicPPL.inmissings)($vn, __model__)
true
else
$(maybe_view(expr)) === missing
end
if $(DynamicPPL.contextual_isassumption)($__CONTEXT__, $vn)
# Considered an assumption by `__context__` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
# the model arguments, hence we need to check this.
# 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments,
# i.e. we're trying to condition one of the latent variables.
# In this case, the below will return `true` since the first branch
# will be hit.
# 3. We are working with a `ConditionContext` _and_ it's in the model arguments,
# i.e. we're trying to override the value. This is currently NOT supported.
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
if !($(DynamicPPL.inargnames)($vn, $__MODEL__)) ||
$(DynamicPPL.inmissings)($vn, $__MODEL__)
phipsgabler marked this conversation as resolved.
Show resolved Hide resolved
true
else
false
$(maybe_view(expr)) === missing
end
else
false
end
end
end

# failsafe: a literal is never an assumption
isassumption(expr, vn) = :(false)

"""
contextual_isassumption(context, vn)

Expand Down Expand Up @@ -79,9 +83,6 @@ function contextual_isassumption(context::PrefixContext, vn)
return contextual_isassumption(childcontext(context), prefix(context, vn))
end

# failsafe: a literal is never an assumption
isassumption(expr) = :(false)

# If we're working with, say, a `Symbol`, then we're not going to `view`.
maybe_view(x) = x
maybe_view(x::Expr) = :(@views($x))
Expand Down Expand Up @@ -314,15 +315,13 @@ function generate_mainbody!(mod, found, sym::Symbol, warn)
return sym
end
function generate_mainbody!(mod, found, expr::Expr, warn)
# Do not touch interpolated expressions
expr.head === :$ && return expr.args[1]

# Do we don't want escaped expressions because we unfortunately
# escape the entire body afterwards.
Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn)

# If it's a macro, we expand it
if Meta.isexpr(expr, :macrocall)
if Meta.isexpr(expr, :$)
# Do not touch interpolated expressions
return expr.args[1]
elseif Meta.isexpr(expr, :escape)
return generate_mainbody(mod, found, expr.args[1], warn)
elseif Meta.isexpr(expr, :macrocall)
# If it's a macro, we expand it (recursively)
return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn)
end

Expand Down Expand Up @@ -357,7 +356,7 @@ function generate_tilde_literal(left, right)
# If the LHS is a literal, it is always an observation
return quote
$(DynamicPPL.tilde_observe!)(
__context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__
$__CONTEXT__, $(DynamicPPL.check_tilde_rhs)($right), $left, $__VARINFO__
)
end
end
Expand All @@ -375,26 +374,23 @@ function generate_tilde(left, right)
# if the LHS represents an observation
@gensym vn isassumption

# HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact
# that in DynamicPPL we the entire function body. Instead we should be
# more selective with our escape. Until that's the case, we remove them all.
return quote
$vn = $(AbstractPPL.drop_escape(varname(left)))
$isassumption = $(DynamicPPL.isassumption(left))
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $isassumption
$(generate_tilde_assume(left, right, vn))
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left = $(DynamicPPL.getvalue_nested)(__context__, $vn)
if !$(DynamicPPL.inargnames)($vn, $__MODEL__)
$left = $(DynamicPPL.getvalue_nested)($__CONTEXT__, $vn)
end

$(DynamicPPL.tilde_observe!)(
__context__,
$__CONTEXT__,
$(DynamicPPL.check_tilde_rhs)($right),
$(maybe_view(left)),
$vn,
__varinfo__,
$__VARINFO__,
)
end
end
Expand All @@ -403,14 +399,14 @@ end
function generate_tilde_assume(left, right, vn)
expr = :(
$left = $(DynamicPPL.tilde_assume!)(
__context__,
$__CONTEXT__,
$(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)...,
__varinfo__,
$__VARINFO__,
)
)

return if left isa Expr
AbstractPPL.drop_escape(
if left isa Expr
return AbstractPPL.drop_escape(
Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true)
)
else
Expand All @@ -431,21 +427,21 @@ function generate_dot_tilde(left, right)
@gensym vn isassumption
return quote
$vn = $(AbstractPPL.drop_escape(varname(left)))
$isassumption = $(DynamicPPL.isassumption(left))
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $isassumption
$(generate_dot_tilde_assume(left, right, vn))
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left .= $(DynamicPPL.getvalue_nested)(__context__, $vn)
if !$(DynamicPPL.inargnames)($vn, $__MODEL__)
$left .= $(DynamicPPL.getvalue_nested)($__CONTEXT__, $vn)
end

$(DynamicPPL.dot_tilde_observe!)(
__context__,
$__CONTEXT__,
$(DynamicPPL.check_tilde_rhs)($right),
$(maybe_view(left)),
$vn,
__varinfo__,
$__VARINFO__,
)
end
end
Expand All @@ -457,11 +453,11 @@ function generate_dot_tilde_assume(left, right, vn)
# be something that supports `.=`.
return :(
$left .= $(DynamicPPL.dot_tilde_assume!)(
__context__,
$__CONTEXT__,
$(DynamicPPL.unwrap_right_left_vns)(
$(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn
)...,
__varinfo__,
$__VARINFO__,
)
)
end
Expand All @@ -479,15 +475,14 @@ Builds the output expression.
function build_output(modelinfo, linenumbernode)
## Build the anonymous evaluator from the user-provided model definition.
evaluatordef = deepcopy(modelinfo[:modeldef])
original_arguments = modelinfo[:allargs_exprs]

# Add the internal arguments to the user-specified arguments (positional + keywords).
evaluatordef[:args] = vcat(
[
:(__model__::$(DynamicPPL.Model)),
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
:(__context__::$(DynamicPPL.AbstractContext)),
],
modelinfo[:allargs_exprs],
:($__MODEL__::$(DynamicPPL.Model)),
:($__VARINFO__::$(DynamicPPL.AbstractVarInfo)),
:($__CONTEXT__::$(DynamicPPL.AbstractContext)),
original_arguments,
)

# Delete the keyword arguments.
Expand All @@ -513,10 +508,11 @@ function build_output(modelinfo, linenumbernode)
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
# to the call site
modeldef = modelinfo[:modeldef]
modelname_symbol = Meta.quot(modeldef[:name])
modeldef[:body] = MacroTools.@q begin
$(linenumbernode)
return $(DynamicPPL.Model)(
$(QuoteNode(modeldef[:name])),
$modelname_symbol,
$(modeldef[:name]),
$allargs_namedtuple,
$defaults_namedtuple,
phipsgabler marked this conversation as resolved.
Show resolved Hide resolved
Expand Down