diff --git a/src/compiler.jl b/src/compiler.jl index 8ad248622..79ea63c9a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1,7 +1,12 @@ const INTERNALNAMES = (:__model__, :__context__, :__varinfo__) +for name in INTERNALNAMES + @eval const $(Symbol(uppercase(string(name)))) = $(Meta.quot(name)) +end + + """ - isassumption(expr) + isassumption(expr, vn) Return an expression that can be evaluated to check if `expr` is an assumption in the model. @@ -14,38 +19,37 @@ Let `expr` be `:(x[1])`. It is an assumption in the following cases: When `expr` is not an expression or symbol (i.e., a literal), this expands to `false`. """ -function isassumption(expr::Union{Symbol,Expr}) - vn = gensym(:vn) - +function isassumption(expr::Union{Symbol,Expr}, vn) return quote - let $vn = $(AbstractPPL.drop_escape(varname(expr))) - if $(DynamicPPL.contextual_isassumption)(__context__, $vn) - # Considered an assumption by `__context__` which means either: - # 1. We hit the default implementation, e.g. using `DefaultContext`, - # which in turn means that we haven't considered if it's one of - # the model arguments, hence we need to check this. - # 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments, - # i.e. we're trying to condition one of the latent variables. - # In this case, the below will return `true` since the first branch - # will be hit. - # 3. We are working with a `ConditionContext` _and_ it's in the model arguments, - # i.e. we're trying to override the value. This is currently NOT supported. - # TODO: Support by adding context to model, and use `model.args` - # as the default conditioning. Then we no longer need to check `inargnames` - # since it will all be handled by `contextual_isassumption`. - if !($(DynamicPPL.inargnames)($vn, __model__)) || - $(DynamicPPL.inmissings)($vn, __model__) - true - else - $(maybe_view(expr)) === missing - end + if $(DynamicPPL.contextual_isassumption)($__CONTEXT__, $vn) + # Considered an assumption by `__context__` which means either: + # 1. We hit the default implementation, e.g. using `DefaultContext`, + # which in turn means that we haven't considered if it's one of + # the model arguments, hence we need to check this. + # 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments, + # i.e. we're trying to condition one of the latent variables. + # In this case, the below will return `true` since the first branch + # will be hit. + # 3. We are working with a `ConditionContext` _and_ it's in the model arguments, + # i.e. we're trying to override the value. This is currently NOT supported. + # TODO: Support by adding context to model, and use `model.args` + # as the default conditioning. Then we no longer need to check `inargnames` + # since it will all be handled by `contextual_isassumption`. + if !($(DynamicPPL.inargnames)($vn, $__MODEL__)) || + $(DynamicPPL.inmissings)($vn, $__MODEL__) + true else - false + $(maybe_view(expr)) === missing end + else + false end end end +# failsafe: a literal is never an assumption +isassumption(expr, vn) = :(false) + """ contextual_isassumption(context, vn) @@ -79,9 +83,6 @@ function contextual_isassumption(context::PrefixContext, vn) return contextual_isassumption(childcontext(context), prefix(context, vn)) end -# failsafe: a literal is never an assumption -isassumption(expr) = :(false) - # If we're working with, say, a `Symbol`, then we're not going to `view`. maybe_view(x) = x maybe_view(x::Expr) = :(@views($x)) @@ -314,15 +315,13 @@ function generate_mainbody!(mod, found, sym::Symbol, warn) return sym end function generate_mainbody!(mod, found, expr::Expr, warn) - # Do not touch interpolated expressions - expr.head === :$ && return expr.args[1] - - # Do we don't want escaped expressions because we unfortunately - # escape the entire body afterwards. - Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn) - - # If it's a macro, we expand it - if Meta.isexpr(expr, :macrocall) + if Meta.isexpr(expr, :$) + # Do not touch interpolated expressions + return expr.args[1] + elseif Meta.isexpr(expr, :escape) + return generate_mainbody(mod, found, expr.args[1], warn) + elseif Meta.isexpr(expr, :macrocall) + # If it's a macro, we expand it (recursively) return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn) end @@ -357,7 +356,7 @@ function generate_tilde_literal(left, right) # If the LHS is a literal, it is always an observation return quote $(DynamicPPL.tilde_observe!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ + $__CONTEXT__, $(DynamicPPL.check_tilde_rhs)($right), $left, $__VARINFO__ ) end end @@ -375,26 +374,23 @@ function generate_tilde(left, right) # if the LHS represents an observation @gensym vn isassumption - # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact - # that in DynamicPPL we the entire function body. Instead we should be - # more selective with our escape. Until that's the case, we remove them all. return quote $vn = $(AbstractPPL.drop_escape(varname(left))) - $isassumption = $(DynamicPPL.isassumption(left)) + $isassumption = $(DynamicPPL.isassumption(left, vn)) if $isassumption $(generate_tilde_assume(left, right, vn)) else # If `vn` is not in `argnames`, we need to make sure that the variable is defined. - if !$(DynamicPPL.inargnames)($vn, __model__) - $left = $(DynamicPPL.getvalue_nested)(__context__, $vn) + if !$(DynamicPPL.inargnames)($vn, $__MODEL__) + $left = $(DynamicPPL.getvalue_nested)($__CONTEXT__, $vn) end $(DynamicPPL.tilde_observe!)( - __context__, + $__CONTEXT__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn, - __varinfo__, + $__VARINFO__, ) end end @@ -403,14 +399,14 @@ end function generate_tilde_assume(left, right, vn) expr = :( $left = $(DynamicPPL.tilde_assume!)( - __context__, + $__CONTEXT__, $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., - __varinfo__, + $__VARINFO__, ) ) - return if left isa Expr - AbstractPPL.drop_escape( + if left isa Expr + return AbstractPPL.drop_escape( Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true) ) else @@ -431,21 +427,21 @@ function generate_dot_tilde(left, right) @gensym vn isassumption return quote $vn = $(AbstractPPL.drop_escape(varname(left))) - $isassumption = $(DynamicPPL.isassumption(left)) + $isassumption = $(DynamicPPL.isassumption(left, vn)) if $isassumption $(generate_dot_tilde_assume(left, right, vn)) else # If `vn` is not in `argnames`, we need to make sure that the variable is defined. - if !$(DynamicPPL.inargnames)($vn, __model__) - $left .= $(DynamicPPL.getvalue_nested)(__context__, $vn) + if !$(DynamicPPL.inargnames)($vn, $__MODEL__) + $left .= $(DynamicPPL.getvalue_nested)($__CONTEXT__, $vn) end $(DynamicPPL.dot_tilde_observe!)( - __context__, + $__CONTEXT__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn, - __varinfo__, + $__VARINFO__, ) end end @@ -457,11 +453,11 @@ function generate_dot_tilde_assume(left, right, vn) # be something that supports `.=`. return :( $left .= $(DynamicPPL.dot_tilde_assume!)( - __context__, + $__CONTEXT__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn )..., - __varinfo__, + $__VARINFO__, ) ) end @@ -479,15 +475,14 @@ Builds the output expression. function build_output(modelinfo, linenumbernode) ## Build the anonymous evaluator from the user-provided model definition. evaluatordef = deepcopy(modelinfo[:modeldef]) + original_arguments = modelinfo[:allargs_exprs] # Add the internal arguments to the user-specified arguments (positional + keywords). evaluatordef[:args] = vcat( - [ - :(__model__::$(DynamicPPL.Model)), - :(__varinfo__::$(DynamicPPL.AbstractVarInfo)), - :(__context__::$(DynamicPPL.AbstractContext)), - ], - modelinfo[:allargs_exprs], + :($__MODEL__::$(DynamicPPL.Model)), + :($__VARINFO__::$(DynamicPPL.AbstractVarInfo)), + :($__CONTEXT__::$(DynamicPPL.AbstractContext)), + original_arguments, ) # Delete the keyword arguments. @@ -513,10 +508,11 @@ function build_output(modelinfo, linenumbernode) # that no new `LineNumberNode`s are added apart from the reference `linenumbernode` # to the call site modeldef = modelinfo[:modeldef] + modelname_symbol = Meta.quot(modeldef[:name]) modeldef[:body] = MacroTools.@q begin $(linenumbernode) return $(DynamicPPL.Model)( - $(QuoteNode(modeldef[:name])), + $modelname_symbol, $(modeldef[:name]), $allargs_namedtuple, $defaults_namedtuple,