Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Turning around escape in model macro #311

Closed
wants to merge 12 commits into from
88 changes: 53 additions & 35 deletions src/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
const INTERNALNAMES = (:__model__, :__context__, :__varinfo__)

for name in INTERNALNAMES
@eval $(Symbol(uppercase(string(name)))) = $(Meta.quot(name))
end
phipsgabler marked this conversation as resolved.
Show resolved Hide resolved

# macro _id(expr)
# return expr
# end

# macro hygienize(expr)
# return Meta.quot(macroexpand(__module__, :(@_id $expr)))
# end

"""
isassumption(expr)
isassumption(expr, vn)

Return an expression that can be evaluated to check if `expr` is an assumption in the
model.
Expand All @@ -14,38 +26,37 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases:

When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`.
"""
function isassumption(expr::Union{Symbol,Expr})
vn = gensym(:vn)

function isassumption(expr::Union{Symbol,Expr}, vn)
return quote
let $vn = $(varname(expr))
if $(DynamicPPL.contextual_isassumption)(__context__, $vn)
# Considered an assumption by `__context__` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
# the model arguments, hence we need to check this.
# 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments,
# i.e. we're trying to condition one of the latent variables.
# In this case, the below will return `true` since the first branch
# will be hit.
# 3. We are working with a `ConditionContext` _and_ it's in the model arguments,
# i.e. we're trying to override the value. This is currently NOT supported.
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
if !($(DynamicPPL.inargnames)($vn, __model__)) ||
$(DynamicPPL.inmissings)($vn, __model__)
true
else
$(maybe_view(expr)) === missing
end
if $(DynamicPPL.contextual_isassumption)(__context__, $vn)
# Considered an assumption by `__context__` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
# the model arguments, hence we need to check this.
# 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments,
# i.e. we're trying to condition one of the latent variables.
# In this case, the below will return `true` since the first branch
# will be hit.
# 3. We are working with a `ConditionContext` _and_ it's in the model arguments,
# i.e. we're trying to override the value. This is currently NOT supported.
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
if !($(DynamicPPL.inargnames)($vn, __model__)) ||
$(DynamicPPL.inmissings)($vn, __model__)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
$(DynamicPPL.inmissings)($vn, __model__)
$(DynamicPPL.inmissings)($vn, __model__)

true
else
false
$(maybe_view(expr)) === missing
end
else
false
end
end
end

# failsafe: a literal is never an assumption
isassumption(expr, vn) = :(false)

"""
contextual_isassumption(context, vn)

Expand Down Expand Up @@ -79,9 +90,6 @@ function contextual_isassumption(context::PrefixContext, vn)
return contextual_isassumption(childcontext(context), prefix(context, vn))
end

# failsafe: a literal is never an assumption
isassumption(expr) = :(false)

# If we're working with, say, a `Symbol`, then we're not going to `view`.
maybe_view(x) = x
maybe_view(x::Expr) = :(@views($x))
Expand Down Expand Up @@ -306,6 +314,15 @@ generate_mainbody(mod, expr, warn) = generate_mainbody!(mod, Symbol[], expr, war

generate_mainbody!(mod, found, x, warn) = x
function generate_mainbody!(mod, found, sym::Symbol, warn)
if sym in DEPRECATED_INTERNALNAMES
newsym = Symbol(:_, sym, :__)
Base.depwarn(
"internal variable `$sym` is deprecated, use `$newsym` instead.",
:generate_mainbody!,
)
return generate_mainbody!(mod, found, newsym, warn)
end

phipsgabler marked this conversation as resolved.
Show resolved Hide resolved
if warn && sym in INTERNALNAMES && sym ∉ found
@warn "you are using the internal variable `$sym`"
push!(found, sym)
Expand Down Expand Up @@ -371,7 +388,7 @@ function generate_tilde(left, right)
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $isassumption
$left = $(DynamicPPL.tilde_assume!)(
__context__,
Expand Down Expand Up @@ -420,7 +437,7 @@ function generate_dot_tilde(left, right)
return quote
$vn = $(varname(left))
$inds = $(vinds(left))
$isassumption = $(DynamicPPL.isassumption(left))
$isassumption = $(DynamicPPL.isassumption(left, vn))
if $isassumption
$left .= $(DynamicPPL.dot_tilde_assume!)(
__context__,
Expand Down Expand Up @@ -465,15 +482,16 @@ function build_output(modelinfo, linenumbernode)
# Add the internal arguments to the user-specified arguments (positional + keywords).
evaluatordef[:args] = vcat(
[
:(__model__::$(DynamicPPL.Model)),
:(__varinfo__::$(DynamicPPL.AbstractVarInfo)),
:(__context__::$(DynamicPPL.AbstractContext)),
:($__MODEL__::$(DynamicPPL.Model)),
:($__VARINFO__::$(DynamicPPL.AbstractVarInfo)),
:($__CONTEXT__::$(DynamicPPL.AbstractContext)),
],
modelinfo[:allargs_exprs],
)

# Delete the keyword arguments.
evaluatordef[:kwargs] = []
evaluatordef[:name] = esc(evaluatordef[:name])

# Replace the user-provided function body with the version created by DynamicPPL.
# We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure
Expand All @@ -485,7 +503,6 @@ function build_output(modelinfo, linenumbernode)
end

## Build the model function.

# Extract the named tuple expression of all arguments and the default values.
allargs_namedtuple = modelinfo[:allargs_namedtuple]
defaults_namedtuple = modelinfo[:defaults_namedtuple]
Expand All @@ -495,6 +512,7 @@ function build_output(modelinfo, linenumbernode)
# that no new `LineNumberNode`s are added apart from the reference `linenumbernode`
# to the call site
modeldef = modelinfo[:modeldef]
modeldef[:name] = esc(modeldef[:name])
modeldef[:body] = MacroTools.@q begin
$(linenumbernode)
return $(DynamicPPL.Model)(
Expand Down