Add support for optional positional arguments in the dynamic DSL. #195

Merged
merged 7 commits into from
Mar 2, 2020
16 changes: 12 additions & 4 deletions docs/src/ref/gfi.md
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@ GenerativeFunction
There are various kinds of generative functions, which are represented by concrete subtypes of [`GenerativeFunction`](@ref).
For example, the [Built-in Modeling Language](@ref) allows generative functions to be constructed using Julia function definition syntax:
```julia
@gen function foo(a, b)
@gen function foo(a, b=0)
if @trace(bernoulli(0.5), :z)
return a + b + 1
else
@@ -26,7 +26,7 @@ Users can also extend Gen by implementing their own [Custom generative function
Generative functions behave like Julia functions in some respects.
For example, we can call a generative function `foo` on arguments and get an output value using regular Julia call syntax:
```julia-repl
>julia foo(2, 4)
julia> foo(2, 4)
7
```
However, generative functions are distinct from Julia functions because they support additional behaviors, described in the remainder of this section.
@@ -103,7 +103,7 @@ Traces contain:

- the arguments to the generative function

- the choice map
- the choice map

- the return value

@@ -148,6 +148,13 @@ For example, to retrieve the value of random choice at address `:z`:
z = trace[:z]
```

When a generative function has default values specified for trailing arguments, those arguments can be left out when calling [`simulate`](@ref), [`generate`](@ref), and other functions provided by the generative function interface. The default values will automatically be filled in:
```julia
julia> trace = simulate(foo, (2,));
julia> get_args(trace)
(2, 0)
```
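
The filling-in of defaults can be pictured as padding the supplied argument tuple with the trailing default values. The following is a minimal sketch in plain Julia, not Gen's actual code path: `pad_with_defaults` is a hypothetical helper introduced here for illustration, and the representation of defaults as a vector of `Some`/`nothing` entries is an assumption.

```julia
# Hypothetical helper (not part of Gen's API): pad `args` with trailing defaults.
# Each entry of `arg_defaults` is `Some(value)` for an optional argument and
# `nothing` for a required one.
function pad_with_defaults(args::Tuple, arg_defaults::Vector{Union{Some{Any},Nothing}})
    n = length(args)
    n >= length(arg_defaults) && return args
    # `something` unwraps `Some(x)` to `x`; it throws if the entry is `nothing`,
    # which corresponds to a required argument having been omitted.
    trailing = map(something, arg_defaults[n+1:end])
    return (args..., trailing...)
end

# Defaults for a signature like `foo(a, b=0)`: `a` is required, `b` defaults to 0.
defaults = Union{Some{Any},Nothing}[nothing, Some{Any}(0)]

pad_with_defaults((2,), defaults)    # (2, 0)
pad_with_defaults((2, 4), defaults)  # (2, 4)
```

Arguments that are supplied explicitly are left untouched; only the missing trailing positions are filled.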

## Updating traces

It is often important to incrementally modify the trace of a generative function (e.g. within MCMC, numerical optimization, sequential Monte Carlo, etc.).
@@ -287,7 +294,7 @@ Then `get_choices(new_trace)` will be:
├── :a : true
├── :b : true
├── :b : true
├── :c : false
@@ -302,6 +309,7 @@ In addition to the input trace, and other arguments that indicate how to adjust
The args argument contains the new arguments to the generative function, which may differ from the previous arguments to the generative function (which can be retrieved by applying [`get_args`](@ref) to the previous trace).
In many cases, the adjustment to the execution specified by the other arguments to these methods is 'small' and only affects certain parts of the computation.
Therefore, it is often possible to generate the new trace and the appropriate log probability ratios required for these methods without revisiting every state of the computation of the generative function.

To enable this, the argdiffs argument provides additional information about the *difference* between each of the previous arguments to the generative function, and its new argument value.
This argdiff information permits the implementation of the update method to avoid inspecting the entire argument data structure to identify which parts were updated.
Note that the correctness of the argdiff is in general not verified by Gen---passing incorrect argdiff information may result in incorrect behavior.
14 changes: 12 additions & 2 deletions docs/src/ref/modeling.md
@@ -6,7 +6,7 @@ The language uses a syntax that extends Julia's syntax for defining regular Juli
Generative functions in the modeling language are identified using the `@gen` keyword in front of a Julia function definition.
Here is an example `@gen` function that samples two random choices:
```julia
@gen function foo(prob::Float64)
@gen function foo(prob::Float64=0.1)
z1 = @trace(bernoulli(prob), :a)
z2 = @trace(bernoulli(prob), :b)
return z1 || z2
@@ -17,6 +17,8 @@ After running this code, `foo` is a Julia value of type [`DynamicDSLFunction`](@
DynamicDSLFunction
```

Note that it is possible to provide default values for trailing positional arguments. However, keyword arguments are currently *not* supported.

We can call the resulting generative function like we would a regular Julia function:
```julia
retval::Bool = foo(0.5)
@@ -25,6 +27,12 @@ We can also trace its execution:
```julia
(trace, _) = generate(foo, (0.5,))
```
Optional arguments can be left out of the above operations, and default values will be filled in automatically:
```julia
julia> (trace, _) = generate(foo, ());
julia> get_args(trace)
(0.1,)
```
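
As a short illustration of both cases side by side, the sketch below assumes a Gen version that includes this change. The function name `coin_pair` is hypothetical and simply mirrors the `foo` example above.

```julia
using Gen

# Hypothetical model for illustration; same shape as `foo` above.
@gen function coin_pair(prob::Float64=0.1)
    z1 = @trace(bernoulli(prob), :a)
    z2 = @trace(bernoulli(prob), :b)
    return z1 || z2
end

# Supplying the argument explicitly overrides the default.
(trace_explicit, _) = generate(coin_pair, (0.5,))
get_args(trace_explicit)  # (0.5,)

# Omitting it (an empty argument tuple) falls back to the default.
(trace_default, _) = generate(coin_pair, ())
get_args(trace_default)   # (0.1,)
```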
See [Generative Functions](@ref) for the full set of operations supported by a generative function.
Note that the built-in modeling language described in this section is only one of many ways of defining a generative function -- generative functions can also be constructed using other embedded languages, or by directly implementing the methods of the generative function interface.
However, the built-in modeling language is intended to be flexible enough to cover a wide range of use cases.
@@ -393,7 +401,9 @@ The trace statement must use a literal Julia symbol for the first component in t
return z4
```

The functions must also satisfy the following rules:
The functions must also satisfy the following rules:

- Default argument values are not supported.

- `@trace` expressions cannot appear anywhere in the function body except for as the outer-most expression on the right-hand side of a statement.

11 changes: 9 additions & 2 deletions src/dsl/dsl.jl
@@ -10,9 +10,11 @@ struct Argument
name::Symbol
typ::Union{Symbol,Expr}
annotations::Set{Symbol}
default::Union{Some{Any}, Nothing}
end

Argument(name, typ) = Argument(name, typ, Set{Symbol}())
Argument(name, typ) = Argument(name, typ, Set{Symbol}(), nothing)
Argument(name, typ, annotations) = Argument(name, typ, annotations, nothing)

function parse_annotations(annotations_expr)
annotations = Set{Symbol}()
@@ -39,12 +41,17 @@ function parse_arg(expr)
elseif isa(expr, Expr) && expr.head == :(::)
# x::Int
arg = Argument(expr.args[1], expr.args[2])
elseif isa(expr, Expr) && expr.head == :kw
# x::Int=1
sub_arg = parse_arg(expr.args[1])
default = Some(expr.args[2])
arg = Argument(sub_arg.name, sub_arg.typ, Set{Symbol}(), default)
elseif isa(expr, Expr) && expr.head == :call
# (grad,foo)(x::Int)
annotations_expr = expr.args[1]
sub_arg = parse_arg(expr.args[2])
annotations = parse_annotations(annotations_expr)
arg = Argument(sub_arg.name, sub_arg.typ, annotations)
arg = Argument(sub_arg.name, sub_arg.typ, annotations, sub_arg.default)
else
dump(expr)
error("syntax error in gen function argument at $expr")
18 changes: 17 additions & 1 deletion src/dsl/dynamic.jl
@@ -1,14 +1,29 @@
const DYNAMIC_DSL_TRACE = Symbol("@trace")

function arg_to_ast(arg::Argument)
ast = esc(arg.name)
if (arg.default != nothing)
default = something(arg.default)
ast = Expr(:kw, ast, esc(default))
end
ast
end

function escape_default(arg)
(arg.default == nothing ? nothing :
Expr(:call, :Some, esc(something(arg.default))))
end

function make_dynamic_gen_function(name, args, body, return_type, annotations)
escaped_args = map((arg) -> esc(arg.name), args)
escaped_args = map(arg_to_ast, args)
gf_args = [esc(state), escaped_args...]

julia_fn_name = gensym(name)
julia_fn_defn = Expr(:function,
Expr(:call, esc(julia_fn_name), gf_args...),
esc(body))
arg_types = map((arg) -> esc(arg.typ), args)
arg_defaults = map(escape_default, args)
has_argument_grads = map(
(arg) -> (DSL_ARG_GRAD_ANNOTATION in arg.annotations), args)
accepts_output_grad = DSL_RET_GRAD_ANNOTATION in annotations
@@ -20,6 +35,7 @@ function make_dynamic_gen_function(name, args, body, return_type, annotations)
# now wrap it in a DynamicDSLFunction value
Core.@__doc__ $(esc(name)) = DynamicDSLFunction(
Type[$(arg_types...)],
Union{Some{Any},Nothing}[$(arg_defaults...)],
$(esc(julia_fn_name)),
$has_argument_grads,
$(esc(return_type)),
7 changes: 5 additions & 2 deletions src/dsl/static.jl
@@ -148,7 +148,7 @@ function parse_trace_expr!(stmts, bindings, name, addr_expr, typ)
args = (call.args[2:end]..., reverse(keys[2:end])...)
end
node = gensym()
if haskey(bindings, name)
if haskey(bindings, name)
static_dsl_syntax_error(addr_expr, "Symbol $name already bound")
end
bindings[name] = node
@@ -256,11 +256,14 @@ function make_static_gen_function(name, args, body, return_type, annotations)
push!(stmts, :(set_accepts_output_grad!(builder, $(QuoteNode(accepts_output_grad)))))
bindings = Dict{Symbol,Symbol}() # map from variable name to node name
for arg in args
if arg.default != nothing
error("Default argument values not supported in the static DSL.")
end
node = gensym()
push!(stmts, :($(esc(node)) = add_argument_node!(
builder, name=$(QuoteNode(arg.name)), typ=$(QuoteNode(arg.typ)),
compute_grad=$(QuoteNode(DSL_ARG_GRAD_ANNOTATION in arg.annotations)))))
bindings[arg.name] = node
bindings[arg.name] = node
end
parse_static_dsl_function_body!(stmts, bindings, body)
push!(stmts, :(ir = build_ir(builder)))
15 changes: 15 additions & 0 deletions src/dynamic/dynamic.jl
@@ -12,22 +12,37 @@ struct DynamicDSLFunction{T} <: GenerativeFunction{T,DynamicDSLTrace}
params_grad::Dict{Symbol,Any}
params::Dict{Symbol,Any}
arg_types::Vector{Type}
has_defaults::Bool
arg_defaults::Vector{Union{Some{Any},Nothing}}
julia_function::Function
has_argument_grads::Vector{Bool}
accepts_output_grad::Bool
end

function DynamicDSLFunction(arg_types::Vector{Type},
arg_defaults::Vector{Union{Some{Any},Nothing}},
julia_function::Function,
has_argument_grads, ::Type{T},
accepts_output_grad::Bool) where {T}
params_grad = Dict{Symbol,Any}()
params = Dict{Symbol,Any}()
has_defaults = any(arg -> arg != nothing, arg_defaults)
DynamicDSLFunction{T}(params_grad, params, arg_types,
has_defaults, arg_defaults,
julia_function,
has_argument_grads, accepts_output_grad)
end

function DynamicDSLTrace(gen_fn::T, args) where {T<:DynamicDSLFunction}
# pad args with default values, if available
if gen_fn.has_defaults && length(args) < length(gen_fn.arg_defaults)
defaults = gen_fn.arg_defaults[length(args)+1:end]
defaults = map(x -> something(x), defaults)
args = Tuple(vcat(collect(args), defaults))
end
DynamicDSLTrace{T}(gen_fn, args)
end

accepts_output_grad(gen_fn::DynamicDSLFunction) = gen_fn.accepts_output_grad

function (gen_fn::DynamicDSLFunction)(args...)
4 changes: 1 addition & 3 deletions src/dynamic/trace.jl
@@ -29,7 +29,7 @@ function CallRecord(record::ChoiceOrCallRecord)
end
CallRecord(record.subtrace_or_retval, record.score, record.noise)
end

mutable struct DynamicDSLTrace{T} <: Trace
gen_fn::T
trie::Trie{Any,ChoiceOrCallRecord}
@@ -45,8 +45,6 @@ mutable struct DynamicDSLTrace{T} <: Trace
end
end

DynamicDSLTrace(gen_fn::T, args) where {T} = DynamicDSLTrace{T}(gen_fn, args)

set_retval!(trace::DynamicDSLTrace, retval) = (trace.retval = retval)

function has_choice(trace::DynamicDSLTrace, addr)
42 changes: 38 additions & 4 deletions src/gen_fn_interface.jl
@@ -141,6 +141,10 @@ Execute the generative function and return the trace.

Given arguments (`args`), sample \$t \\sim p(\\cdot; x)\$ and \$r \\sim p(\\cdot; x,
t)\$, and return a trace with choice map \$t\$.

If `gen_fn` has optional trailing arguments (i.e., default values are provided),
the optional arguments can be omitted from the `args` tuple. The generated trace
will have default values filled in.
"""
function simulate(::GenerativeFunction, ::Tuple)
error("Not implemented")
@@ -158,12 +162,16 @@ Return a trace of a generative function that is consistent with the given
constraints on the random choices.

Given arguments \$x\$ (`args`) and assignment \$u\$ (`constraints`) (which is empty for the first form), sample \$t \\sim
q(\\cdot; u, x)\$ and \$r \\sim q(\\cdot; x, t)\$, and return the trace \$(x, t, r)\$ (`trace`).
q(\\cdot; u, x)\$ and \$r \\sim q(\\cdot; x, t)\$, and return the trace \$(x, t, r)\$ (`trace`).
Also return the weight (`weight`):
```math
\\log \\frac{p(t, r; x)}{q(t; u, x) q(r; x, t)}
```

If `gen_fn` has optional trailing arguments (i.e., default values are provided),
the optional arguments can be omitted from the `args` tuple. The generated trace
will have default values filled in.

Example without constraints:
```julia
(trace, weight) = generate(foo, (2, 4))
@@ -186,7 +194,7 @@ end
weight = project(trace::U, selection::Selection)

Estimate the probability that the selected choices take the values they do in a
trace.
trace.

Given a trace \$(x, t, r)\$ (`trace`) and a set of addresses \$A\$ (`selection`),
let \$u\$ denote the restriction of \$t\$ to \$A\$. Return the weight
@@ -223,7 +231,7 @@ end
Return the probability of proposing an assignment

Given arguments \$x\$ (`args`) and an assignment \$t\$ (`choices`) such that
\$p(t; x) > 0\$, sample \$r \\sim q(\\cdot; x, t)\$ and
\$p(t; x) > 0\$, sample \$r \\sim q(\\cdot; x, t)\$ and
return the weight (`weight`):
```math
\\log \\frac{p(r, t; x)}{q(r; x, t)}
@@ -256,8 +264,15 @@ return a weight (`weight`):
```math
\\log \\frac{p(r', t'; x') q(r; x, t)}{p(r, t; x) q(r'; x', t') q(t'; x', t + u)}
```

Note that `argdiffs` is expected to be the same length as `args`. If the
function that generated `trace` supports default values for trailing arguments,
then these arguments can be omitted from `args` and `argdiffs`. Note
that if the original `trace` was generated using non-default argument values,
then for each optional argument that is omitted, the old value will be
over-written by the default argument value in the updated trace.
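
A minimal sketch of this caveat, assuming a dynamic DSL function and a Gen version that includes this change (`bar` and the address `:z` are hypothetical names used only for illustration):

```julia
using Gen

# Hypothetical model with one optional trailing argument.
@gen function bar(a, b=0)
    return @trace(normal(a + b, 1.0), :z)
end

# The original trace uses a non-default value for `b`.
trace = simulate(bar, (1, 5))
get_args(trace)      # (1, 5)

# Omitting `b` from `args` (and its entry from `argdiffs`) resets it to the default.
(new_trace, weight, retdiff, discard) = update(trace, (1,), (NoChange(),), choicemap())
get_args(new_trace)  # (1, 0)
```

To keep the old value of an optional argument, pass it explicitly in `args` together with a matching entry in `argdiffs`.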
"""
function update(trace, ::Tuple, argdiffs::Tuple, ::ChoiceMap)
function update(trace, args::Tuple, argdiffs::Tuple, ::ChoiceMap)
error("Not implemented")
end

@@ -280,6 +295,13 @@ Return the new trace \$(x', t', r')\$ (`new_trace`) and the weight
\\log \\frac{p(r', t'; x') q(t; u', x) q(r; x, t)}{p(r, t; x) q(t'; u, x') q(r'; x', t')}
```
where \$u'\$ is the restriction of \$t'\$ to the complement of \$A\$.

Note that `argdiffs` is expected to be the same length as `args`. If the
function that generated `trace` supports default values for trailing arguments,
then these arguments can be omitted from `args` and `argdiffs`. Note
that if the original `trace` was generated using non-default argument values,
then for each optional argument that is omitted, the old value will be
over-written by the default argument value in the regenerated trace.
"""
function regenerate(trace, args::Tuple, argdiffs::Tuple, selection::Selection)
error("Not implemented")
@@ -298,6 +320,12 @@ with respect to the arguments \$x\$:
```math
∇_x \\left( \\log P(t; x) + J \\right)
```

The length of `arg_grads` will be equal to the number of arguments to the
function that generated `trace` (including any optional trailing arguments).
If an argument is not annotated with `(grad)`, the corresponding value in
`arg_grads` will be `nothing`.

Also increment the gradient accumulators for the trainable parameters \$Θ\$ of
the function by:
```math
@@ -320,6 +348,12 @@ with respect to the arguments \$x\$:
```math
∇_x \\left( \\log P(t; x) + J \\right)
```

The length of `arg_grads` will be equal to the number of arguments to the
function that generated `trace` (including any optional trailing arguments).
If an argument is not annotated with `(grad)`, the corresponding value in
`arg_grads` will be `nothing`.

Also given a set of addresses \$A\$ (`selection`) that are continuous-valued
random choices, return the following gradient (`choice_grads`) with respect to
the values of these choices:
6 changes: 4 additions & 2 deletions src/modeling_library/map/map.jl
@@ -1,9 +1,9 @@
##################
# map combinator #
# map combinator #
##################

# used for type dispatch on the VectorTrace type (e.g. we will also have a UnfoldType)
struct MapType end
struct MapType end

"""
gen_fn = Map(kernel::GenerativeFunction)
@@ -15,6 +15,8 @@ The length of each argument, which must be the same for each argument, determine
Each call to the input function is made under address namespace i for i=1..N.
The return value of the returned function has type `FunctionalCollections.PersistentVector{Y}` where `Y` is the type of the return value of the input function.
The map combinator is similar to the 'map' higher order function in functional programming, except that the map combinator returns a new generative function that must then be separately applied.

If `kernel` has optional trailing arguments, the corresponding `Vector` arguments can be omitted from calls to `Map(kernel)`.
"""
struct Map{T,U} <: GenerativeFunction{PersistentVector{T},VectorTrace{MapType,T,U}}
kernel::GenerativeFunction{T,U}
2 changes: 2 additions & 0 deletions src/modeling_library/unfold/unfold.jl
@@ -26,6 +26,8 @@ The returned generative function accepts the following arguments:
- The rest of the arguments (not including the state) that will be passed to each kernel application.

The return type of the returned generative function is `FunctionalCollections.PersistentVector{T}` where `T` is the return type of the kernel.

If `kernel` has optional trailing arguments, the corresponding arguments can be omitted from calls to `Unfold(kernel)`.
"""
struct Unfold{T,U} <: GenerativeFunction{PersistentVector{T},VectorTrace{UnfoldType,T,U}}
kernel::GenerativeFunction{T,U}