diff --git a/.travis.yml b/.travis.yml
index 36f0f9500..6270d6015 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -6,6 +6,9 @@ julia:
 - 1.4
 - 1.5
 - 1.6
+env:
+  - JULIA_NUM_THREADS=1
+  - JULIA_NUM_THREADS=2
 services:
 - docker
diff --git a/docs/make.jl b/docs/make.jl
index 3c2448478..7c82d4dfd 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -29,7 +29,6 @@ makedocs(
        "Learning Generative Functions" => "ref/learning.md"
    ],
    "Internals" => [
-        "Optimizing Trainable Parameters" => "ref/internals/parameter_optimization.md",
        "Modeling Language Implementation" => "ref/internals/language_implementation.md"
    ]
]
diff --git a/docs/src/index.md b/docs/src/index.md
index 161192b99..98aaf390a 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -6,7 +6,7 @@
Pages = [
    "getting_started.md",
    "tutorials.md",
-    "guide.md",
+    "guide.md"
]
Depth = 2
```
@@ -21,8 +21,8 @@
Pages = [
    "ref/parameter_optimization.md",
    "ref/inference.md",
    "ref/gfi.md",
-    "ref/distributions.md"
-    "ref/extending.md",
+    "ref/distributions.md",
+    "ref/extending.md"
]
Depth = 2
```
diff --git a/docs/src/ref/gfi.md b/docs/src/ref/gfi.md
index dde929722..a0a40126c 100644
--- a/docs/src/ref/gfi.md
+++ b/docs/src/ref/gfi.md
@@ -344,6 +344,7 @@ A generative function statically reports whether or not it is able to compute gr
The **trainable parameters** of a generative function are (unlike arguments and random choices) *state* of the generative function itself, and are not contained in the trace.
Generative functions that have trainable parameters maintain *gradient accumulators* for these parameters, which get incremented by the gradient induced by the given trace by a call to [`accumulate_param_gradients!`](@ref).
Users then use these accumulated gradients to update to the values of the trainable parameters.
+Use [`get_parameters`](@ref) to obtain the full set of trainable parameters that a generative function uses (see [Optimizing Trainable Parameters](@ref) for more details).
### Return value gradient
The set of elements (either arguments, random choices, or trainable parameters) for which gradients are available is called the **gradient source set**.
@@ -371,5 +372,5 @@ has_argument_grads
accepts_output_grad
accumulate_param_gradients!
choice_gradients
-get_params
+get_parameters
```
diff --git a/docs/src/ref/internals/parameter_optimization.md b/docs/src/ref/internals/parameter_optimization.md
index 48d88d494..e69de29bb 100644
--- a/docs/src/ref/internals/parameter_optimization.md
+++ b/docs/src/ref/internals/parameter_optimization.md
@@ -1,7 +0,0 @@
-# [Optimizing Trainable Parameters](@id optimizing-internal)
-
-To add support for a new type of gradient-based parameter update, create a new type with the following methods defined for the types of generative functions that are to be supported.
-```@docs
-Gen.init_update_state
-Gen.apply_update!
-```
diff --git a/docs/src/ref/learning.md b/docs/src/ref/learning.md
index 464d6d284..496dd051e 100644
--- a/docs/src/ref/learning.md
+++ b/docs/src/ref/learning.md
@@ -51,10 +51,10 @@ end
```
Let's suppose we are training the generative model.
-The first step is to initialize the values of the trainable parameters, which for generative functions constructed using the built-in modeling languages, we do with [`init_param!`](@ref):
+The first step is to initialize the values of the trainable parameters, which for generative functions constructed using the built-in modeling languages, we do with [`init_parameter!`](@ref):
```julia
-init_param!(model, :a, 0.)
-init_param!(model, :b, 0.)
+init_parameter!((model, :a), 0.0)
+init_parameter!((model, :b), 0.0)
```
Each trace in the collection contains the observed data from an independent draw from our model.
We can populate each trace with its observed data using [`generate`](@ref):
```julia
@@ -76,29 +76,29 @@ for trace in traces
    accumulate_param_gradients!(trace)
end
```
-Finally, we can construct and gradient-based update with [`ParamUpdate`](@ref) and apply it with [`apply!`](@ref).
+Finally, we can construct a gradient-based update with [`init_optimizer`](@ref) and apply it with [`apply_update!`](@ref).
We can put this all together into a function:
```julia
function train_model(data::Vector{ChoiceMap})
-    init_param!(model, :theta, 0.1)
+    init_parameter!((model, :theta), 0.1)
    traces = []
    for observations in data
        trace, = generate(model, model_args, observations)
        push!(traces, trace)
    end
-    update = ParamUpdate(FixedStepSizeGradientDescent(0.001), model)
+    optimizer = init_optimizer(FixedStepGradientDescent(0.001), model)
    for iter=1:max_iter
        objective = sum([get_score(trace) for trace in traces])
        println("objective: $objective")
        for trace in traces
            accumulate_param_gradients!(trace)
        end
-        apply!(update)
+        apply_update!(optimizer)
    end
end
```
-Note that using the same primitives ([`generate`](@ref) and [`accumulate_param_gradients!`](@ref)), you can compose various more sophisticated learning algorithms involving e.g. stochastic gradient descent and minibatches, and more sophisticated stochastic gradient optimizers like [`ADAM`](@ref).
+Note that using the same primitives ([`generate`](@ref) and [`accumulate_param_gradients!`](@ref)), you can compose various more sophisticated learning algorithms involving e.g. stochastic gradient descent and minibatches, and more sophisticated stochastic gradient optimizers.
For example, [`train!`](@ref) trains a generative function from complete data with minibatches.
## Learning from Incomplete Data
@@ -139,14 +139,14 @@ There are many variants possible, based on which Monte Carlo inference algorithm
For example:
```julia
function train_model(data::Vector{ChoiceMap})
-    init_param!(model, :theta, 0.1)
-    update = ParamUpdate(FixedStepSizeGradientDescent(0.001), model)
+    init_parameter!((model, :theta), 0.1)
+    optimizer = init_optimizer(FixedStepGradientDescent(0.001), model)
    for iter=1:max_iter
        traces = do_monte_carlo_inference(data)
        for trace in traces
            accumulate_param_gradients!(trace)
        end
-        apply!(update)
+        apply_update!(optimizer)
    end
end
@@ -160,14 +160,14 @@ end
Note that it is also possible to use a weighted collection of traces directly without resampling:
```julia
function train_model(data::Vector{ChoiceMap})
-    init_param!(model, :theta, 0.1)
-    update = ParamUpdate(FixedStepSizeGradientDescent(0.001), model)
+    init_parameter!((model, :theta), 0.1)
+    optimizer = init_optimizer(FixedStepGradientDescent(0.001), model)
    for iter=1:max_iter
        traces, weights = do_monte_carlo_inference_with_weights(data)
        for (trace, weight) in zip(traces, weights)
            accumulate_param_gradients!(trace, nothing, weight)
        end
-        apply!(update)
+        apply_update!(optimizer)
    end
end
```
@@ -209,7 +209,7 @@ Then, the traces of the model can be obtained by simulating from the variational
Instead of fitting the variational approximation from scratch for each observation, it is possible to fit an *inference model* instead, that takes as input the observation, and generates a distribution on latent variables as output (as in the wake sleep algorithm).
When we train the variational approximation by minimizing the evidence lower bound (ELBO) this is called amortized variational inference. Variational autencoders are an example. -It is possible to perform amortized variational inference using [`black_box_vi`](@ref) or [`black_box_vimco!`](@ref). +It is possible to perform amortized variational inference using [`black_box_vi!`](@ref) or [`black_box_vimco!`](@ref). ## References diff --git a/docs/src/ref/modeling.md b/docs/src/ref/modeling.md index 15be1990e..d55015284 100644 --- a/docs/src/ref/modeling.md +++ b/docs/src/ref/modeling.md @@ -254,6 +254,7 @@ See [Generative Function Interface](@ref) for more information about traces. A `@gen` function may begin with an optional block of *trainable parameter declarations*. The block consists of a sequence of statements, beginning with `@param`, that declare the name and Julia type for each trainable parameter. +The Julia type must be either a subtype of `Real` or subtype of `Array{<:Real}`. The function below has a single trainable parameter `theta` with type `Float64`: ```julia @gen function foo(prob::Float64) @@ -264,23 +265,22 @@ The function below has a single trainable parameter `theta` with type `Float64`: end ``` Trainable parameters obey the same scoping rules as Julia local variables defined at the beginning of the function body. -The value of a trainable parameter is undefined until it is initialized using [`init_param!`](@ref). +After the definition of the generative function, you must register all of the parameters used by the generative function using [`register_parameters!`](@ref) (this is not required if you instead use the [Static Modeling Language](@ref)): +```julia +register_parameters!(foo, [:theta]) +``` +The value of a trainable parameter is undefined until it is initialized using [`init_parameter!`](@ref): +```julia +init_parameter!((foo, :theta), 0.0) +``` In addition to the current value, each trainable parameter has a current **gradient accumulator** value. The gradient accumulator value has the same shape (e.g. array dimension) as the parameter value. -It is initialized to all zeros, and is incremented by [`accumulate_param_gradients!`](@ref). - -The following methods are exported for the trainable parameters of `@gen` functions: +It is initialized to all zeros, and is incremented by calling [`accumulate_param_gradients!`](@ref) on a trace. +Additional functions for retrieving and manipulating the values of trainable parameters and their gradient accumulators are described in [Optimizing Trainable Parameters](@ref). ```@docs -init_param! -get_param -get_param_grad -set_param! -zero_param_grad! +register_parameters! ``` -Trainable parameters are designed to be trained using gradient-based methods. -This is discussed in the next section. - ## Differentiable programming Given a trace of a `@gen` function, Gen supports automatic differentiation of the log probability (density) of all of the random choices made in the trace with respect to the following types of inputs: diff --git a/docs/src/ref/parameter_optimization.md b/docs/src/ref/parameter_optimization.md index 60e05f41d..943fbbaa8 100644 --- a/docs/src/ref/parameter_optimization.md +++ b/docs/src/ref/parameter_optimization.md @@ -1,33 +1,82 @@ # Optimizing Trainable Parameters -Trainable parameters of generative functions are initialized differently depending on the type of generative function. -Trainable parameters of the built-in modeling language are initialized with [`init_param!`](@ref). 
+## Parameter stores
-Gradient-based optimization of the trainable parameters of generative functions is based on interleaving two steps:
+Multiple traces of a generative function typically reference the same trainable parameters of the generative function, which are stored outside of the trace in a **parameter store**.
+Different types of generative functions may use different types of parameter stores.
+For example, the [`JuliaParameterStore`](@ref) (discussed below) stores parameters as Julia values in the memory of the Julia runtime process.
+Other types of parameter stores may store parameters in GPU memory, in a filesystem, or even remotely.
-- Incrementing gradient accumulators for trainable parameters by calling [`accumulate_param_gradients!`](@ref) on one or more traces.
+When generating a trace of a generative function with [`simulate`](@ref) or [`generate`](@ref), we may pass in an optional **parameter context**, which is a `Dict` that indicates which parameter store(s) to look up parameter values in.
+A generative function obtains a reference to a specific type of parameter store by looking up its key in the parameter context.
-- Updating the value of trainable parameters and resetting the gradient accumulators to zero, by calling [`apply!`](@ref) on a *parameter update*, as described below.
+If you are just learning Gen and are only using the built-in modeling language to write generative functions, you can ignore this complexity: there is a [`default_julia_parameter_store`](@ref) and a default parameter context [`default_parameter_context`](@ref) that points to this default Julia parameter store, and these defaults are used whenever a parameter context is not provided in the call to `simulate` or `generate`.
+```@docs
+default_parameter_context
+default_julia_parameter_store
+```
+
+## Julia parameter store
+
+Parameters declared using the `@param` keyword in the built-in modeling language are stored in a type of parameter store called a [`JuliaParameterStore`](@ref).
+A generative function can obtain a reference to a `JuliaParameterStore` by looking up the key [`JULIA_PARAMETER_STORE_KEY`](@ref) in a parameter context.
+This is how the built-in modeling language implementation finds the parameter stores to use for `@param`-declared parameters.
+Note that if you are defining your own [custom generative functions](@ref #Custom-generative-functions), you can also use a [`JuliaParameterStore`](@ref) (including the same parameter store used to store parameters of built-in modeling language generative functions) to store and optimize your trainable parameters.
-## Parameter update
+Different types of parameter stores provide different APIs for reading, writing, and updating the values of parameters and gradient accumulators for parameters.
+The `JuliaParameterStore` API is given below.
+The API uses tuples of the form `(gen_fn::GenerativeFunction, name::Symbol)` to identify parameters.
+(Note that most user learning code only needs to use [`init_parameter!`](@ref), as the other API functions are called by [Optimizers](@ref), which are discussed below.)
-A *parameter update* reads from the gradient accumulators for certain trainable parameters, updates the values of those parameters, and resets the gradient accumulators to zero.
-A paramter update is constructed by combining an *update configuration* with the set of trainable parameters to which the update should be applied:
```@docs
-ParamUpdate
+JuliaParameterStore
+init_parameter!
+increment_gradient!
+reset_gradient!
+get_parameter_value
+get_gradient
+JULIA_PARAMETER_STORE_KEY
```
-The set of possible update configurations is described in [Update configurations](@ref).
-An update is applied with:
+
+### Multi-threaded gradient accumulation
+
+Note that the [`increment_gradient!`](@ref) call is thread-safe, so that multiple threads can concurrently increment the gradient for the same parameters. This is helpful for parallelizing gradient computation for a batch of traces within stochastic gradient descent learning algorithms.
+
+## Optimizers
+
+Gradient-based optimization typically involves iterating between two steps:
+(i) computing gradients or estimates of gradients with respect to parameters, and
+(ii) updating the value of the parameters based on the gradient estimates according to some mathematical rule.
+Sometimes the optimization algorithm also has its own state that is separate from the value of the parameters and the gradient estimates.
+Gradient-based optimization algorithms in Gen are implemented by **optimizers**.
+Each type of parameter store provides implementations of optimizers for standard mathematical update rules.
+
+The mathematical rules are defined in **optimizer configuration** objects.
+The currently supported optimizer configurations are:
```@docs
-apply!
+FixedStepGradientDescent
+DecayStepGradientDescent
+```
+
+The most common way to construct an optimizer is via:
+```julia
+optimizer = init_optimizer(conf, gen_fn)
```
+which returns an optimizer that applies the mathematical rule defined by `conf` to all parameters used by `gen_fn` (even when the generative function uses parameters that are housed in multiple parameter stores).
+You can also pass a parameter context keyword argument to customize the parameter store(s) that the optimizer should use.
+Then, after accumulating gradients with [`accumulate_param_gradients!`](@ref), you can apply the update with:
+```julia
+apply_update!(optimizer)
+```
+
+The `init_optimizer` method described above constructs an optimizer that actually invokes multiple optimizers, one for each parameter store.
+To add support to a parameter store type for a new optimizer configuration type, you must implement the per-parameter-store optimizer methods:
-## Update configurations
+- `init_optimizer(conf, parameter_ids, store)`, which takes in an optimizer configuration object, a list of parameter IDs, and the parameter store in which to apply the updates, and returns an optimizer that mutates the given parameter store.
+
+- `apply_update!(optimizer)`, which takes a single argument (the optimizer) and applies its update rule, which mutates the value of the parameters in its parameter store (and typically also resets the values of the gradient accumulators to zero).
-Gen has built-in support for the following types of update configurations.
```@docs
-FixedStepGradientDescent
-GradientDescent
-ADAM
+init_optimizer
+apply_update!
```
-For adding new types of update configurations, see [Optimizing Trainable Parameters (Internal)](@ref optimizing-internal).
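A minimal end-to-end sketch of the parameter-context plumbing described above (the generative function `foo` is illustrative; the sketch assumes `JuliaParameterStore()` constructs an empty store and that `init_parameter!` accepts the target store as an optional trailing argument, defaulting to `default_julia_parameter_store`):
```julia
@gen function foo()
    @param theta::Float64
    x ~ normal(theta, 1.0)
    return x
end
register_parameters!(foo, [:theta])

store = JuliaParameterStore()               # assumed zero-argument constructor
init_parameter!((foo, :theta), 0.0, store)  # assumed optional trailing store argument

# a parameter context that points the built-in modeling language at this store
context = Dict(JULIA_PARAMETER_STORE_KEY => store)
trace = simulate(foo, (), context)
```
Because `increment_gradient!` is thread-safe, gradient accumulation over a batch of traces can be parallelized directly. A sketch, assuming `traces` is a `Vector` of traces and `optimizer` was constructed with `init_optimizer` for the same generative function:
```julia
Threads.@threads for i in 1:length(traces)
    accumulate_param_gradients!(traces[i])  # thread-safe increments of the accumulators
end
apply_update!(optimizer)
```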
diff --git a/docs/src/ref/selections.md b/docs/src/ref/selections.md index 172e31876..986173335 100644 --- a/docs/src/ref/selections.md +++ b/docs/src/ref/selections.md @@ -55,5 +55,4 @@ AllSelection HierarchicalSelection DynamicSelection StaticSelection -ComplementSelection ``` diff --git a/src/Gen.jl b/src/Gen.jl index a3a40e38b..13a540400 100644 --- a/src/Gen.jl +++ b/src/Gen.jl @@ -64,9 +64,6 @@ include("dynamic/dynamic.jl") # static IR generative function include("static_ir/static_ir.jl") -# optimization for built-in generative functions (dynamic and static IR) -include("builtin_optimization.jl") - # DSLs for defining dynamic embedded and static IR generative functions # 'Dynamic DSL' and 'Static DSL' include("dsl/dsl.jl") diff --git a/src/builtin_optimization.jl b/src/builtin_optimization.jl deleted file mode 100644 index 3a897577f..000000000 --- a/src/builtin_optimization.jl +++ /dev/null @@ -1,120 +0,0 @@ -""" - set_param!(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, value) - -Set the value of a trainable parameter of the generative function. - -NOTE: Does not update the gradient accumulator value. -""" -function set_param!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, value) - gf.params[name] = value -end - -""" - value = get_param(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol) - -Get the current value of a trainable parameter of the generative function. -""" -function get_param(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol) - gf.params[name] -end - -""" - value = get_param_grad(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol) - -Get the current value of the gradient accumulator for a trainable parameter of the generative function. -""" -function get_param_grad(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol) - gf.params_grad[name] -end - -""" - zero_param_grad!(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol) - -Reset the gradient accumlator for a trainable parameter of the generative function to all zeros. -""" -function zero_param_grad!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol) - gf.params_grad[name] = zero(gf.params[name]) -end - -""" - set_param_grad!(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, grad_value) - -Set the gradient accumlator for a trainable parameter of the generative function. -""" -function set_param_grad!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, grad_value) - gf.params_grad[name] = grad_value -end - -""" - init_param!(gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, value) - -Initialize the the value of a named trainable parameter of a generative function. - -Also generates the gradient accumulator for that parameter to `zero(value)`. - -Example: -```julia -init_param!(foo, :theta, 0.6) -``` -""" -function init_param!(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, name::Symbol, value) - set_param!(gf, name, value) - zero_param_grad!(gf, name) -end - -get_params(gf::Union{DynamicDSLFunction,StaticIRGenerativeFunction}) = keys(gf.params) - -export set_param!, get_param, get_param_grad, zero_param_grad!, set_param_grad!, init_param! 
- -######################################### -# gradient descent with fixed step size # -######################################### - -mutable struct FixedStepGradientDescentBuiltinDSLState - step_size::Float64 - gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction} - param_list::Vector -end - -function init_update_state(conf::FixedStepGradientDescent, - gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, param_list::Vector) - FixedStepGradientDescentBuiltinDSLState(conf.step_size, gen_fn, param_list) -end - -function apply_update!(state::FixedStepGradientDescentBuiltinDSLState) - for param_name in state.param_list - value = get_param(state.gen_fn, param_name) - grad = get_param_grad(state.gen_fn, param_name) - set_param!(state.gen_fn, param_name, value + grad * state.step_size) - zero_param_grad!(state.gen_fn, param_name) - end -end - -#################### -# gradient descent # -#################### - -mutable struct GradientDescentBuiltinDSLState - step_size_init::Float64 - step_size_beta::Float64 - gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction} - param_list::Vector - t::Int -end - -function init_update_state(conf::GradientDescent, - gen_fn::Union{DynamicDSLFunction,StaticIRGenerativeFunction}, param_list::Vector) - GradientDescentBuiltinDSLState(conf.step_size_init, conf.step_size_beta, - gen_fn, param_list, 1) -end - -function apply_update!(state::GradientDescentBuiltinDSLState) - step_size = state.step_size_init * (state.step_size_beta + 1) / (state.step_size_beta + state.t) - for param_name in state.param_list - value = get_param(state.gen_fn, param_name) - grad = get_param_grad(state.gen_fn, param_name) - set_param!(state.gen_fn, param_name, value + grad * step_size) - zero_param_grad!(state.gen_fn, param_name) - end - state.t += 1 -end diff --git a/src/dynamic/assess.jl b/src/dynamic/assess.jl index c583d5079..ffb7a9378 100644 --- a/src/dynamic/assess.jl +++ b/src/dynamic/assess.jl @@ -2,11 +2,22 @@ mutable struct GFAssessState choices::ChoiceMap weight::Float64 visitor::AddressVisitor - params::Dict{Symbol,Any} + active_gen_fn::DynamicDSLFunction # mutated by splicing + parameter_context::Dict + + function GFAssessState(gen_fn, choices, parameter_context) + new(choices, 0.0, AddressVisitor(), gen_fn, parameter_context) + end end -function GFAssessState(choices, params::Dict{Symbol,Any}) - GFAssessState(choices, 0., AddressVisitor(), params) +get_parameter_store(state::GFAssessState) = get_julia_store(state.parameter_context) + +get_parameter_id(state::GFAssessState, name::Symbol) = (state.active_gen_fn, name) + +get_active_gen_fn(state::GFAssessState) = state.active_gen_fn + +function set_active_gen_fn!(state::GFAssessState, gen_fn::GenerativeFunction) + state.active_gen_fn = gen_fn end function traceat(state::GFAssessState, dist::Distribution{T}, @@ -22,7 +33,7 @@ function traceat(state::GFAssessState, dist::Distribution{T}, # update weight state.weight += logpdf(dist, retval, args...) 
- retval + return retval end function traceat(state::GFAssessState, gen_fn::GenerativeFunction{T,U}, @@ -41,19 +52,13 @@ function traceat(state::GFAssessState, gen_fn::GenerativeFunction{T,U}, # update score state.weight += weight - retval -end - -function splice(state::GFAssessState, gen_fn::DynamicDSLFunction, args::Tuple) - prev_params = state.params - state.params = gen_fn.params - retval = exec(gen_fn, state, args) - state.params = prev_params - retval + return retval end -function assess(gen_fn::DynamicDSLFunction, args::Tuple, choices::ChoiceMap) - state = GFAssessState(choices, gen_fn.params) +function assess( + gen_fn::DynamicDSLFunction, args::Tuple, choices::ChoiceMap, + parameter_context::Dict) + state = GFAssessState(gen_fn, choices, parameter_context) retval = exec(gen_fn, state, args) unvisited = get_unvisited(get_visited(state.visitor), choices) @@ -61,5 +66,5 @@ function assess(gen_fn::DynamicDSLFunction, args::Tuple, choices::ChoiceMap) error("Assess did not visit the following constraint addresses:\n$unvisited") end - (state.weight, retval) + return (state.weight, retval) end diff --git a/src/dynamic/backprop.jl b/src/dynamic/backprop.jl index 9e02c0657..77e22e1ae 100644 --- a/src/dynamic/backprop.jl +++ b/src/dynamic/backprop.jl @@ -31,9 +31,9 @@ end end -################### +############################### # accumulate_param_gradients! # -################### +############################### mutable struct GFBackpropParamsState trace::DynamicDSLTrace @@ -41,33 +41,29 @@ mutable struct GFBackpropParamsState tape::InstructionTape visitor::AddressVisitor scale_factor::Float64 - - # only those tracked parameters that are in scope (not including splice calls) - tracked_params::Dict{Symbol,Any} - - # tracked parameters for all (nested) splice calls - splice_tracked_params::Dict{Tuple{GenerativeFunction,Symbol},Any} -end - -function track_params(tape, params) - tracked_params = Dict{Symbol,Any}() - for (name, value) in params - tracked_params[name] = track(value, tape) + active_gen_fn::GenerativeFunction + tracked_params::Dict{Tuple{GenerativeFunction,Symbol},Any} + + function GFBackpropParamsState(trace::DynamicDSLTrace, tape, scale_factor) + tracked_params = Dict{Tuple{GenerativeFunction,Symbol},Any}() + store = get_parameter_store(trace) + gen_fn = get_gen_fn(trace) + for (name, value) in get_local_parameters(store, gen_fn) + parameter_id = (gen_fn, name) + tracked_params[parameter_id] = track(value, tape) + end + score = track(0., tape) + new(trace, score, tape, AddressVisitor(), scale_factor, + gen_fn, tracked_params) end - tracked_params -end - -function GFBackpropParamsState(trace::DynamicDSLTrace, tape, params, scale_factor) - tracked_params = track_params(tape, params) - splice_tracked_params = Dict{Tuple{GenerativeFunction,Symbol},Any}() - score = track(0., tape) - GFBackpropParamsState(trace, score, tape, AddressVisitor(), scale_factor, - tracked_params, splice_tracked_params) end -function read_param(state::GFBackpropParamsState, name::Symbol) - value = state.tracked_params[name] - value +function read_param!(state::GFBackpropParamsState, name::Symbol) + parameter_id = (state.active_gen_fn, name) + if !(parameter_id in state.trace.registered_julia_parameters) + throw(ArgumentError("parameter $parameter_id was not registered using register_parameters!")) + end + return state.tracked_params[parameter_id] end function traceat(state::GFBackpropParamsState, dist::Distribution{T}, @@ -110,30 +106,20 @@ end function splice(state::GFBackpropParamsState, 
gen_fn::DynamicDSLFunction, args_maybe_tracked::Tuple) - - # save previous tracked parameter scope - prev_tracked_params = state.tracked_params - - # construct new tracked parameter scope - state.tracked_params = Dict{Symbol,Any}() - for name in keys(gen_fn.params) - if haskey(state.splice_tracked_params, (gen_fn, name)) - # parameter was already tracked in another splice call - state.tracked_params[name] = state.splice_tracked_params[(gen_fn, name)] - else + prev_gen_fn = state.active_gen_fn + state.active_gen_fn = gen_fn + store = get_parameter_store(state.trace) + for (name, value) in get_local_parameters(store, gen_fn) + parameter_id = (gen_fn, name) + if !haskey(state.tracked_params, parameter_id) # parameter was not already tracked - tracked = track(get_param(gen_fn, name), state.tape) - state.tracked_params[name] = tracked - state.splice_tracked_params[(gen_fn, name)] = tracked + tracked = track(value, state.tape) + state.tracked_params[parameter_id] = tracked end end - retval_maybe_tracked = exec(gen_fn, state, args_maybe_tracked) - - # restore previous tracked parameter scope - state.tracked_params = prev_tracked_params - - retval_maybe_tracked + state.active_gen_fn = prev_gen_fn + return retval_maybe_tracked end @noinline function ReverseDiff.special_reverse_exec!( @@ -185,7 +171,7 @@ end function accumulate_param_gradients!(trace::DynamicDSLTrace, retval_grad, scale_factor=1.) gen_fn = trace.gen_fn tape = new_tape() - state = GFBackpropParamsState(trace, tape, gen_fn.params, scale_factor) + state = GFBackpropParamsState(trace, tape, scale_factor) args = get_args(trace) args_maybe_tracked = (map(maybe_track, args, gen_fn.has_argument_grads, fill(tape, length(args)))...,) @@ -195,13 +181,10 @@ function accumulate_param_gradients!(trace::DynamicDSLTrace, retval_grad, scale_ reverse_pass!(tape) # increment the gradient accumulators for trainable parameters in scope - for (name, tracked) in state.tracked_params - gen_fn.params_grad[name] += deriv(tracked) * state.scale_factor - end - - # increment the gradient accumulators for trainable parameters in splice calls - for ((spliced_gen_fn, name), tracked) in state.splice_tracked_params - spliced_gen_fn.params_grad[name] += deriv(tracked) * state.scale_factor + store = get_parameter_store(trace) + for ((active_gen_fn, name), tracked) in state.tracked_params + parameter_id = (active_gen_fn, name) + increment_gradient!(parameter_id, deriv(tracked), state.scale_factor, store) end # return gradients with respect to arguments with gradients, or nothing @@ -211,30 +194,41 @@ function accumulate_param_gradients!(trace::DynamicDSLTrace, retval_grad, scale_ end -################## +#################### # choice_gradients # -################## +#################### mutable struct GFBackpropTraceState trace::DynamicDSLTrace score::TrackedReal tape::InstructionTape visitor::AddressVisitor - params::Dict{Symbol,Any} selection::Selection tracked_choices::Trie{Any,Union{TrackedReal,TrackedArray}} value_choices::DynamicChoiceMap gradient_choices::DynamicChoiceMap + active_gen_fn::GenerativeFunction end -function GFBackpropTraceState(trace, selection, params, tape) +function GFBackpropTraceState(trace, selection, tape) score = track(0., tape) visitor = AddressVisitor() tracked_choices = Trie{Any,Union{TrackedReal,TrackedArray}}() value_choices = choicemap() gradient_choices = choicemap() - GFBackpropTraceState(trace, score, tape, visitor, params, - selection, tracked_choices, value_choices, gradient_choices) + GFBackpropTraceState(trace, score, tape, 
visitor,
+        selection, tracked_choices, value_choices, gradient_choices,
+        get_gen_fn(trace))
+end
+
+get_parameter_store(state::GFBackpropTraceState) = get_parameter_store(state.trace)
+
+get_parameter_id(state::GFBackpropTraceState, name::Symbol) = (state.active_gen_fn, name)
+
+get_active_gen_fn(state::GFBackpropTraceState) = state.active_gen_fn
+
+function set_active_gen_fn!(state::GFBackpropTraceState, gen_fn::GenerativeFunction)
+    state.active_gen_fn = gen_fn
end

function fill_submaps!(
@@ -324,15 +318,6 @@ function traceat(state::GFBackpropTraceState, gen_fn::GenerativeFunction{T,U},
    retval_maybe_tracked
end

-function splice(state::GFBackpropTraceState, gen_fn::DynamicDSLFunction,
-                args_maybe_tracked::Tuple)
-    prev_params = state.params
-    state.params = gen_fn.params
-    retval = exec(gen_fn, state, args_maybe_tracked)
-    state.params = prev_params
-    retval
-end
-
@noinline function ReverseDiff.special_reverse_exec!(
    instruction::ReverseDiff.SpecialInstruction{BackpropTraceRecord})
    record = instruction.func
@@ -372,7 +357,7 @@ end
function choice_gradients(trace::DynamicDSLTrace, selection::Selection, retval_grad)
    gen_fn = trace.gen_fn
    tape = new_tape()
-    state = GFBackpropTraceState(trace, selection, gen_fn.params, tape)
+    state = GFBackpropTraceState(trace, selection, tape)
    args = get_args(trace)
    args_maybe_tracked = (map(maybe_track, args,
        gen_fn.has_argument_grads, fill(tape, length(args)))...,)
diff --git a/src/dynamic/dynamic.jl b/src/dynamic/dynamic.jl
index d83055444..37f4a337d 100644
--- a/src/dynamic/dynamic.jl
+++ b/src/dynamic/dynamic.jl
@@ -1,3 +1,5 @@
+export register_parameters!
+
include("trace.jl")

"""
@@ -8,15 +10,14 @@ A generative function based on a shallowly embedding modeling language based on
Constructed using the `@gen` keyword.
Most methods in the generative function interface involve a end-to-end execution of the function.
"""
-struct DynamicDSLFunction{T} <: GenerativeFunction{T,DynamicDSLTrace}
-    params_grad::Dict{Symbol,Any}
-    params::Dict{Symbol,Any}
+mutable struct DynamicDSLFunction{T} <: GenerativeFunction{T,DynamicDSLTrace}
    arg_types::Vector{Type}
    has_defaults::Bool
    arg_defaults::Vector{Union{Some{Any},Nothing}}
    julia_function::Function
    has_argument_grads::Vector{Bool}
    accepts_output_grad::Bool
+    parameters::Union{Nothing,Set{Tuple{GenerativeFunction,Symbol}},Function}
end

function DynamicDSLFunction(arg_types::Vector{Type},
@@ -24,33 +25,97 @@ function DynamicDSLFunction(arg_types::Vector{Type},
                            julia_function::Function,
                            has_argument_grads, ::Type{T},
                            accepts_output_grad::Bool) where {T}
-    params_grad = Dict{Symbol,Any}()
-    params = Dict{Symbol,Any}()
    has_defaults = any(arg -> arg != nothing, arg_defaults)
-    DynamicDSLFunction{T}(params_grad, params, arg_types,
+    return DynamicDSLFunction{T}(arg_types,
                has_defaults, arg_defaults,
                julia_function,
-                has_argument_grads, accepts_output_grad)
+                has_argument_grads, accepts_output_grad,
+                nothing)
+end
+
+function get_parameters(gen_fn::DynamicDSLFunction, parameter_context)
+    if isa(gen_fn.parameters, Set)
+        julia_store = get_julia_store(parameter_context)
+        parameter_stores_to_ids = Dict{Any,Set}(julia_store => gen_fn.parameters)
+        return parameter_stores_to_ids
+    elseif isa(gen_fn.parameters, Function)
+        return gen_fn.parameters(parameter_context)
+    else
+        # no parameters were registered
+        return Dict{Any,Set}()
+    end
+end
+
+"""
+    register_parameters!(gen_fn::DynamicDSLFunction, parameters)
+
+Register the trainable parameters that are used by a DML generative function.
+
+This includes all parameters used within any calls made by the generative function, and includes any parameters that may be used by any possible trace (stochastic control flow may cause a parameter to be used by one trace but not another).
+
+The second argument is either an iterable collection or a `Function` that takes a parameter context and returns a `Dict` that maps parameter stores to `Set`s of parameter IDs.
+When the second argument is an iterable collection, each element is either a `Symbol` that is the name of a parameter declared in the body of `gen_fn` using `@param`, or is a tuple `(other_gen_fn::GenerativeFunction, name::Symbol)` where `@param name` was declared in the body of `other_gen_fn`.
+The `Function` input is used when `gen_fn` uses parameters that come from more than one parameter store, including parameters that are housed in parameter stores that are not `JuliaParameterStore`s (e.g. if `gen_fn` invokes a generative function that executes in another non-Julia runtime).
+See [Optimizing Trainable Parameters](@ref) for details on parameter contexts and parameter stores.
+"""
+function register_parameters!(gen_fn::DynamicDSLFunction, parameters::Function)
+    if gen_fn.parameters !== nothing
+        throw(ArgumentError("parameters for $gen_fn were already registered"))
+    end
+    gen_fn.parameters = parameters
+    return nothing
end

-function DynamicDSLTrace(gen_fn::T, args) where {T<:DynamicDSLFunction}
+function register_parameters!(gen_fn::DynamicDSLFunction, parameters)
+    if gen_fn.parameters !== nothing
+        throw(ArgumentError("parameters for $gen_fn were already registered"))
+    end
+    gen_fn.parameters = Set{Tuple{GenerativeFunction,Symbol}}()
+    for param in parameters
+        if isa(param, Tuple{GenerativeFunction,Symbol})
+            push!(gen_fn.parameters, param)
+        elseif isa(param, Symbol)
+            push!(gen_fn.parameters, (gen_fn, param))
+        else
+            throw(ArgumentError("Invalid parameter declaration for DML generative function $gen_fn: $param"))
+        end
+    end
+    return nothing
+end
+
+function Base.show(io::IO, gen_fn::DynamicDSLFunction)
+    print(io, "Gen DML generative function: $(gen_fn.julia_function)")
+end
+
+function Base.show(io::IO, ::MIME"text/plain", gen_fn::DynamicDSLFunction)
+    print(io, "Gen DML generative function: $(gen_fn.julia_function)")
+end
+
+function DynamicDSLTrace(
+    gen_fn::T, args, parameter_store::JuliaParameterStore,
+    parameter_context, registered_julia_parameters) where {T<:DynamicDSLFunction}
    # pad args with default values, if available
    if gen_fn.has_defaults && length(args) < length(gen_fn.arg_defaults)
        defaults = gen_fn.arg_defaults[length(args)+1:end]
        defaults = map(x -> something(x), defaults)
        args = Tuple(vcat(collect(args), defaults))
    end
-    DynamicDSLTrace{T}(gen_fn, args)
+    return DynamicDSLTrace{T}(
+        gen_fn, args, parameter_store, parameter_context, registered_julia_parameters)
end

accepts_output_grad(gen_fn::DynamicDSLFunction) = gen_fn.accepts_output_grad

mutable struct GFUntracedState
-    params::Dict{Symbol,Any}
+    gen_fn::GenerativeFunction
+    parameter_store::JuliaParameterStore
end

+get_parameter_store(state::GFUntracedState) = state.parameter_store
+get_parameter_id(state::GFUntracedState, name::Symbol) = (state.gen_fn, name)
+
function (gen_fn::DynamicDSLFunction)(args...)
-    state = GFUntracedState(gen_fn.params)
+    state = GFUntracedState(gen_fn, default_julia_parameter_store)
    gen_fn.julia_function(state, args...)
end

@@ -58,6 +123,14 @@ function exec(gen_fn::DynamicDSLFunction, state, args::Tuple)
    gen_fn.julia_function(state, args...)
end +function splice(state, gen_fn::DynamicDSLFunction, args::Tuple) + prev_gen_fn = get_active_gen_fn(state) + state.active_gen_fn = gen_fn + retval = exec(gen_fn, state, args) + set_active_gen_fn!(state, prev_gen_fn) + return retval +end + # whether there is a gradient of score with respect to each argument # it returns 'nothing' for those arguemnts that don't have a derivatice has_argument_grads(gen::DynamicDSLFunction) = gen.has_argument_grads @@ -98,15 +171,13 @@ end function dynamic_param_impl(expr::Expr) @assert expr.head == :genparam "Not a Gen param expression." name = expr.args[1] - Expr(:(=), name, Expr(:call, GlobalRef(@__MODULE__, :read_param), state, QuoteNode(name))) + Expr(:(=), name, Expr(:call, GlobalRef(@__MODULE__, :read_param!), state, QuoteNode(name))) end -function read_param(state, name::Symbol) - if haskey(state.params, name) - state.params[name] - else - throw(UndefVarError(name)) - end +function read_param!(state, name::Symbol) + parameter_id = get_parameter_id(state, name) + store = get_parameter_store(state) + return get_parameter_value(parameter_id, store) end ################## diff --git a/src/dynamic/generate.jl b/src/dynamic/generate.jl index a89e0c352..866ca92f7 100644 --- a/src/dynamic/generate.jl +++ b/src/dynamic/generate.jl @@ -3,12 +3,26 @@ mutable struct GFGenerateState constraints::ChoiceMap weight::Float64 visitor::AddressVisitor - params::Dict{Symbol,Any} + active_gen_fn::DynamicDSLFunction # mutated by splicing + parameter_context::Dict + + function GFGenerateState(gen_fn, args, constraints, parameter_context) + parameter_store = get_julia_store(parameter_context) + registered_julia_parameters = get_parameters(gen_fn, parameter_context)[parameter_store] + trace = DynamicDSLTrace( + gen_fn, args, parameter_store, parameter_context, registered_julia_parameters) + return new(trace, constraints, 0., AddressVisitor(), gen_fn, parameter_context) + end end -function GFGenerateState(gen_fn, args, constraints, params) - trace = DynamicDSLTrace(gen_fn, args) - GFGenerateState(trace, constraints, 0., AddressVisitor(), params) +get_parameter_store(state::GFGenerateState) = get_parameter_store(state.trace) + +get_parameter_id(state::GFGenerateState, name::Symbol) = (state.active_gen_fn, name) + +get_active_gen_fn(state::GFGenerateState) = state.active_gen_fn + +function set_active_gen_fn!(state::GFGenerateState, gen_fn::GenerativeFunction) + state.active_gen_fn = gen_fn end function traceat(state::GFGenerateState, dist::Distribution{T}, @@ -40,7 +54,7 @@ function traceat(state::GFGenerateState, dist::Distribution{T}, state.weight += score end - retval + return retval end function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U}, @@ -55,7 +69,8 @@ function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U}, constraints = get_submap(state.constraints, key) # get subtrace - (subtrace, weight) = generate(gen_fn, args, constraints) + (subtrace, weight) = generate( + gen_fn, args, constraints, state.parameter_context) # add to the trace add_call!(state.trace, key, subtrace) @@ -66,22 +81,14 @@ function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U}, # get return value retval = get_retval(subtrace) - retval -end - -function splice(state::GFGenerateState, gen_fn::DynamicDSLFunction, - args::Tuple) - prev_params = state.params - state.params = gen_fn.params - retval = exec(gen_fn, state, args) - state.params = prev_params - retval + return retval end -function generate(gen_fn::DynamicDSLFunction, args::Tuple, - 
constraints::ChoiceMap) - state = GFGenerateState(gen_fn, args, constraints, gen_fn.params) +function generate( + gen_fn::DynamicDSLFunction, args::Tuple, constraints::ChoiceMap, + parameter_context::Dict) + state = GFGenerateState(gen_fn, args, constraints, parameter_context) retval = exec(gen_fn, state, args) set_retval!(state.trace, retval) - (state.trace, state.weight) + return (state.trace, state.weight) end diff --git a/src/dynamic/propose.jl b/src/dynamic/propose.jl index 7c630cc19..2642229f7 100644 --- a/src/dynamic/propose.jl +++ b/src/dynamic/propose.jl @@ -2,11 +2,23 @@ mutable struct GFProposeState choices::DynamicChoiceMap weight::Float64 visitor::AddressVisitor - params::Dict{Symbol,Any} + active_gen_fn::DynamicDSLFunction # mutated by splicing + parameter_context::Dict + + function GFProposeState( + gen_fn::GenerativeFunction, parameter_context) + return new(choicemap(), 0.0, AddressVisitor(), gen_fn, parameter_context) + end end -function GFProposeState(params::Dict{Symbol,Any}) - GFProposeState(choicemap(), 0., AddressVisitor(), params) +get_parameter_store(state::GFProposeState) = get_julia_store(state.parameter_context) + +get_parameter_id(state::GFProposeState, name::Symbol) = (state.active_gen_fn, name) + +get_active_gen_fn(state::GFProposeState) = state.active_gen_fn + +function set_active_gen_fn!(state::GFProposeState, gen_fn::GenerativeFunction) + state.active_gen_fn = gen_fn end function traceat(state::GFProposeState, dist::Distribution{T}, @@ -25,7 +37,7 @@ function traceat(state::GFProposeState, dist::Distribution{T}, # update weight state.weight += logpdf(dist, retval, args...) - retval + return retval end function traceat(state::GFProposeState, gen_fn::GenerativeFunction{T,U}, @@ -44,19 +56,11 @@ function traceat(state::GFProposeState, gen_fn::GenerativeFunction{T,U}, # update weight state.weight += weight - retval -end - -function splice(state::GFProposeState, gen_fn::DynamicDSLFunction, args::Tuple) - prev_params = state.params - state.params = gen_fn.params - retval = exec(gen_fn, state, args) - state.params = prev_params - retval + return retval end -function propose(gen_fn::DynamicDSLFunction, args::Tuple) - state = GFProposeState(gen_fn.params) +function propose(gen_fn::DynamicDSLFunction, args::Tuple, parameter_context::Dict) + state = GFProposeState(gen_fn, parameter_context) retval = exec(gen_fn, state, args) - (state.choices, state.weight, retval) + return (state.choices, state.weight, retval) end diff --git a/src/dynamic/regenerate.jl b/src/dynamic/regenerate.jl index 13d14d86f..2c35251e3 100644 --- a/src/dynamic/regenerate.jl +++ b/src/dynamic/regenerate.jl @@ -4,14 +4,23 @@ mutable struct GFRegenerateState selection::Selection weight::Float64 visitor::AddressVisitor - params::Dict{Symbol,Any} + active_gen_fn::DynamicDSLFunction # mutated by splicing + + function GFRegenerateState(gen_fn, args, prev_trace, selection) + visitor = AddressVisitor() + trace = initialize_from(prev_trace, args) + return new(prev_trace, trace, selection, 0.0, visitor, gen_fn) + end end -function GFRegenerateState(gen_fn, args, prev_trace, - selection, params) - visitor = AddressVisitor() - GFRegenerateState(prev_trace, DynamicDSLTrace(gen_fn, args), selection, - 0., visitor, params) +get_parameter_store(state::GFRegenerateState) = get_parameter_store(state.trace) + +get_parameter_id(state::GFRegenerateState, name::Symbol) = (state.active_gen_fn, name) + +get_active_gen_fn(state::GFRegenerateState) = state.active_gen_fn + +function 
set_active_gen_fn!(state::GFRegenerateState, gen_fn::GenerativeFunction) + state.active_gen_fn = gen_fn end function traceat(state::GFRegenerateState, dist::Distribution{T}, @@ -77,7 +86,8 @@ function traceat(state::GFRegenerateState, gen_fn::GenerativeFunction{T,U}, (subtrace, weight, _) = regenerate( prev_subtrace, args, map((_) -> UnknownChange(), args), subselection) else - (subtrace, weight) = generate(gen_fn, args, EmptyChoiceMap()) + (subtrace, weight) = generate( + gen_fn, args, EmptyChoiceMap(), state.trace.parameter_context) end # update weight @@ -92,15 +102,6 @@ function traceat(state::GFRegenerateState, gen_fn::GenerativeFunction{T,U}, retval end -function splice(state::GFRegenerateState, gen_fn::DynamicDSLFunction, - args::Tuple) - prev_params = state.params - state.params = gen_fn.params - retval = exec(gen_fn, state, args) - state.params = prev_params - retval -end - function regenerate_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, visited::EmptySelection) noise = 0. @@ -133,7 +134,7 @@ end function regenerate(trace::DynamicDSLTrace, args::Tuple, argdiffs::Tuple, selection::Selection) gen_fn = trace.gen_fn - state = GFRegenerateState(gen_fn, args, trace, selection, gen_fn.params) + state = GFRegenerateState(gen_fn, args, trace, selection) retval = exec(gen_fn, state, args) set_retval!(state.trace, retval) visited = state.visitor.visited diff --git a/src/dynamic/simulate.jl b/src/dynamic/simulate.jl index 7db1a213a..bc75b1ada 100644 --- a/src/dynamic/simulate.jl +++ b/src/dynamic/simulate.jl @@ -1,12 +1,27 @@ mutable struct GFSimulateState trace::DynamicDSLTrace visitor::AddressVisitor - params::Dict{Symbol,Any} + active_gen_fn::DynamicDSLFunction # mutated by splicing + parameter_context::Dict + + function GFSimulateState( + gen_fn::GenerativeFunction, args::Tuple, parameter_context) + parameter_store = get_julia_store(parameter_context) + registered_julia_parameters = get_parameters(gen_fn, parameter_context)[parameter_store] + trace = DynamicDSLTrace( + gen_fn, args, parameter_store, parameter_context, registered_julia_parameters) + return new(trace, AddressVisitor(), gen_fn, parameter_context) + end end -function GFSimulateState(gen_fn::GenerativeFunction, args::Tuple, params) - trace = DynamicDSLTrace(gen_fn, args) - GFSimulateState(trace, AddressVisitor(), params) +get_parameter_store(state::GFSimulateState) = get_parameter_store(state.trace) + +get_parameter_id(state::GFSimulateState, name::Symbol) = (state.active_gen_fn, name) + +get_active_gen_fn(state::GFSimulateState) = state.active_gen_fn + +function set_active_gen_fn!(state::GFSimulateState, gen_fn::GenerativeFunction) + state.active_gen_fn = gen_fn end function traceat(state::GFSimulateState, dist::Distribution{T}, @@ -36,7 +51,7 @@ function traceat(state::GFSimulateState, gen_fn::GenerativeFunction{T,U}, visit!(state.visitor, key) # get subtrace - subtrace = simulate(gen_fn, args) + subtrace = simulate(gen_fn, args, state.parameter_context) # add to the trace add_call!(state.trace, key, subtrace) @@ -47,18 +62,9 @@ function traceat(state::GFSimulateState, gen_fn::GenerativeFunction{T,U}, retval end -function splice(state::GFSimulateState, gen_fn::DynamicDSLFunction, - args::Tuple) - prev_params = state.params - state.params = gen_fn.params - retval = exec(gen_fn, state, args) - state.params = prev_params - retval -end - -function simulate(gen_fn::DynamicDSLFunction, args::Tuple) - state = GFSimulateState(gen_fn, args, gen_fn.params) +function simulate(gen_fn::DynamicDSLFunction, args::Tuple, 
parameter_context::Dict) + state = GFSimulateState(gen_fn, args, parameter_context) retval = exec(gen_fn, state, args) set_retval!(state.trace, retval) - state.trace + return state.trace end diff --git a/src/dynamic/trace.jl b/src/dynamic/trace.jl index 8c02eceb5..ea13135e6 100644 --- a/src/dynamic/trace.jl +++ b/src/dynamic/trace.jl @@ -37,14 +37,30 @@ mutable struct DynamicDSLTrace{T} <: Trace score::Float64 noise::Float64 args::Tuple + parameter_store::JuliaParameterStore + parameter_context::Dict + registered_julia_parameters::Set{Tuple{GenerativeFunction,Symbol}} # for runtime cross-check retval::Any - function DynamicDSLTrace{T}(gen_fn::T, args) where {T} + function DynamicDSLTrace{T}( + gen_fn::T, args, parameter_store::JuliaParameterStore, parameter_context, + registered_julia_parameters::Set{Tuple{GenerativeFunction,Symbol}}) where {T} trie = Trie{Any,ChoiceOrCallRecord}() # retval is not known yet - new(gen_fn, trie, true, 0, 0, args) + new( + gen_fn, trie, true, 0, 0, args, parameter_store, + parameter_context, registered_julia_parameters) end end +function initialize_from(other::DynamicDSLTrace, args) + gen_fn = get_gen_fn(other) + return DynamicDSLTrace( + gen_fn, args, other.parameter_store, other.parameter_context, + other.registered_julia_parameters) +end + +get_parameter_store(trace::DynamicDSLTrace) = trace.parameter_store + set_retval!(trace::DynamicDSLTrace, retval) = (trace.retval = retval) function has_choice(trace::DynamicDSLTrace, addr) diff --git a/src/dynamic/update.jl b/src/dynamic/update.jl index bf4a42de2..6085fe712 100644 --- a/src/dynamic/update.jl +++ b/src/dynamic/update.jl @@ -4,16 +4,26 @@ mutable struct GFUpdateState constraints::Any weight::Float64 visitor::AddressVisitor - params::Dict{Symbol,Any} discard::DynamicChoiceMap + active_gen_fn::DynamicDSLFunction # mutated by splicing + + function GFUpdateState(gen_fn, args, prev_trace, constraints) + visitor = AddressVisitor() + discard = choicemap() + parameter_store = get_parameter_store(prev_trace) + trace = initialize_from(prev_trace, args) + return new(prev_trace, trace, constraints, 0.0, visitor, discard, gen_fn) + end end -function GFUpdateState(gen_fn, args, prev_trace, constraints, params) - visitor = AddressVisitor() - discard = choicemap() - trace = DynamicDSLTrace(gen_fn, args) - GFUpdateState(prev_trace, trace, constraints, - 0., visitor, params, discard) +get_parameter_store(state::GFUpdateState) = get_parameter_store(state.trace) + +get_parameter_id(state::GFUpdateState, name::Symbol) = (state.active_gen_fn, name) + +get_active_gen_fn(state::GFUpdateState) = state.active_gen_fn + +function set_active_gen_fn!(state::GFUpdateState, gen_fn::GenerativeFunction) + state.active_gen_fn = gen_fn end function traceat(state::GFUpdateState, dist::Distribution{T}, @@ -90,7 +100,8 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U}, (subtrace, weight, _, discard) = update(prev_subtrace, args, map((_) -> UnknownChange(), args), constraints) else - (subtrace, weight) = generate(gen_fn, args, constraints) + (subtrace, weight) = generate( + gen_fn, args, constraints, state.trace.parameter_context) end # update the weight @@ -110,13 +121,13 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U}, retval end -function splice(state::GFUpdateState, gen_fn::DynamicDSLFunction, - args::Tuple) - prev_params = state.params - state.params = gen_fn.params +function splice( + state::GFUpdateState, gen_fn::DynamicDSLFunction, args::Tuple) + prev_gen_fn = state.active_gen_fn + 
state.active_gen_fn = gen_fn retval = exec(gen_fn, state, args) - state.params = prev_params - retval + state.active_gen_fn = prev_gen_fn + return retval end function update_delete_recurse(prev_trie::Trie{Any,ChoiceOrCallRecord}, @@ -187,7 +198,7 @@ end function update(trace::DynamicDSLTrace, arg_values::Tuple, arg_diffs::Tuple, constraints::ChoiceMap) gen_fn = trace.gen_fn - state = GFUpdateState(gen_fn, arg_values, trace, constraints, gen_fn.params) + state = GFUpdateState(gen_fn, arg_values, trace, constraints) retval = exec(gen_fn, state, arg_values) set_retval!(state.trace, retval) visited = get_visited(state.visitor) diff --git a/src/gen_fn_interface.jl b/src/gen_fn_interface.jl index f3881f43f..54ce4e403 100644 --- a/src/gen_fn_interface.jl +++ b/src/gen_fn_interface.jl @@ -1,3 +1,27 @@ +export GenerativeFunction +export has_argument_grads +export accepts_output_grad + + +export get_parameters + +export Trace +export get_args +export get_retval +export get_choices +export get_score +export get_gen_fn + +export simulate +export generate +export project +export propose +export assess +export update +export regenerate +export accumulate_param_gradients! +export choice_gradients + ########## # Traces # ########## @@ -85,12 +109,6 @@ Synonym for [`get_retval`](@ref). """ Base.getindex(trace::Trace) = get_retval(trace) -export get_args -export get_retval -export get_choices -export get_score -export get_gen_fn - ###################### # GenerativeFunction # ###################### @@ -105,6 +123,20 @@ abstract type GenerativeFunction{T,U <: Trace} end get_return_type(::GenerativeFunction{T,U}) where {T,U} = T get_trace_type(::GenerativeFunction{T,U}) where {T,U} = U +""" + parameters::Dict{Any,Set} = get_parameters( + gen_fn::GenerativeFunction, + parameter_context=default_parameter_context) + +Returns the parameters used by the generative function (including all of its calls). + +The parameters are returned in a `Dict` that maps parameter stores to sets of +parameter IDs within each store. +""" +function get_parameters(gen_fn::GenerativeFunction, parameter_context) + return Dict{Any,Set}() +end + """ bools::Tuple = has_argument_grads(gen_fn::Union{GenerativeFunction,Distribution}) @@ -135,7 +167,7 @@ Return an iterable over the trainable parameters of the generative function. get_params(::GenerativeFunction) = () """ - trace = simulate(gen_fn, args) + trace = simulate(gen_fn, args, parameter_context=default_parameter_context) Execute the generative function and return the trace. @@ -146,20 +178,17 @@ If `gen_fn` has optional trailing arguments (i.e., default values are provided), the optional arguments can be omitted from the `args` tuple. The generated trace will have default values filled in. """ -function simulate(::GenerativeFunction, ::Tuple) +function simulate(::GenerativeFunction, ::Tuple, parameter_context::Dict) error("Not implemented") end """ - (trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple) - -Return a trace of a generative function. - - (trace::U, weight) = generate(gen_fn::GenerativeFunction{T,U}, args::Tuple, - constraints::ChoiceMap) + (trace::U, weight) = generate( + gen_fn::GenerativeFunction{T,U}, args::Tuple, + constraints=EmptyChoiceMap(), parameter_context=default_parameter_context) Return a trace of a generative function that is consistent with the given -constraints on the random choices. +constraints on the random choices, if any. 
Given arguments \$x\$ (`args`) and assignment \$u\$ (`constraints`) (which is empty for the first form), sample \$t \\sim q(\\cdot; u, x)\$ and \$r \\sim q(\\cdot; x, t)\$, and return the trace \$(x, t, r)\$ (`trace`). @@ -182,13 +211,10 @@ Example with constraint that address `:z` takes value `true`. (trace, weight) = generate(foo, (2, 4), choicemap((:z, true)) ``` """ -function generate(::GenerativeFunction, ::Tuple, ::ChoiceMap) +function generate(::GenerativeFunction, ::Tuple, ::ChoiceMap, parameter_context::Dict) error("Not implemented") end -function generate(gen_fn::GenerativeFunction, args::Tuple) - generate(gen_fn, args, EmptyChoiceMap()) -end """ weight = project(trace::U, selection::Selection) @@ -207,8 +233,10 @@ function project(trace, selection::Selection) error("Not implemented") end + """ - (choices, weight, retval) = propose(gen_fn::GenerativeFunction, args::Tuple) + (choices, weight, retval) = propose( + gen_fn::GenerativeFunction, args::Tuple, parameter_context=default_parameter_context) Sample an assignment and compute the probability of proposing that assignment. @@ -219,14 +247,17 @@ t)\$, and return \$t\$ \\log \\frac{p(r, t; x)}{q(r; x, t)} ``` """ -function propose(gen_fn::GenerativeFunction, args::Tuple) - trace = simulate(gen_fn, args) +function propose(gen_fn::GenerativeFunction, args::Tuple, parameter_context::Dict) + trace = simulate(gen_fn, args, parameter_context) weight = get_score(trace) - (get_choices(trace), weight, get_retval(trace)) + return (get_choices(trace), weight, get_retval(trace)) end + """ - (weight, retval) = assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) + (weight, retval) = assess( + gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap, + parameter_context=default_parameter_context) Return the probability of proposing an assignment @@ -238,14 +269,16 @@ return the weight (`weight`): ``` It is an error if \$p(t; x) = 0\$. """ -function assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) - (trace, weight) = generate(gen_fn, args, choices) - (weight, get_retval(trace)) +function assess( + gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap, parameter_context::Dict) + (trace, weight) = generate(gen_fn, args, choices, parameter_context) + return (weight, get_retval(trace)) end + """ - (new_trace, weight, retdiff, discard) = update(trace, args::Tuple, argdiffs::Tuple, - constraints::ChoiceMap) + (new_trace, weight, retdiff, discard) = update( + trace, args::Tuple, argdiffs::Tuple, constraints::ChoiceMap) Update a trace by changing the arguments and/or providing new values for some existing random choice(s) and values for some newly introduced random choice(s). @@ -290,8 +323,8 @@ function update(trace, constraints::ChoiceMap) end """ - (new_trace, weight, retdiff) = regenerate(trace, args::Tuple, argdiffs::Tuple, - selection::Selection) + (new_trace, weight, retdiff) = regenerate( + trace, args::Tuple, argdiffs::Tuple, selection::Selection) Update a trace by changing the arguments and/or randomly sampling new values for selected random choices using the internal proposal distribution family. @@ -408,18 +441,3 @@ end function choice_gradients(trace) choice_gradients(trace, EmptySelection(), nothing) end - -export GenerativeFunction -export Trace -export has_argument_grads -export accepts_output_grad -export get_params -export simulate -export generate -export project -export propose -export assess -export update -export regenerate -export accumulate_param_gradients! 
-export choice_gradients diff --git a/src/inference/importance.jl b/src/inference/importance.jl index fac3c0e4d..c0103c5b4 100644 --- a/src/inference/importance.jl +++ b/src/inference/importance.jl @@ -17,14 +17,20 @@ The second variant uses a custom proposal distribution defined by the given gene All addresses of random choices sampled by the proposal should also be sampled by the model function. Setting `verbose=true` prints a progress message every sample. """ -function importance_sampling(model::GenerativeFunction{T,U}, model_args::Tuple, - observations::ChoiceMap, - num_samples::Int, verbose=false) where {T,U} +function importance_sampling( + model::GenerativeFunction{T,U}, model_args::Tuple, + observations::ChoiceMap, num_samples::Int; + verbose=false, multithreaded=false) where {T,U} traces = Vector{U}(undef, num_samples) log_weights = Vector{Float64}(undef, num_samples) - for i=1:num_samples - verbose && println("sample: $i of $num_samples") - (traces[i], log_weights[i]) = generate(model, model_args, observations) + if multithreaded + Threads.@threads for i in 1:num_samples + importance_sampling_iter!(traces, log_weights, model, model_args, observations, i, verbose) + end + else + for i=1:num_samples + importance_sampling_iter!(traces, log_weights, model, model_args, observations, i, verbose) + end end log_total_weight = logsumexp(log_weights) log_ml_estimate = log_total_weight - log(num_samples) @@ -32,18 +38,33 @@ function importance_sampling(model::GenerativeFunction{T,U}, model_args::Tuple, return (traces, log_normalized_weights, log_ml_estimate) end -function importance_sampling(model::GenerativeFunction{T,U}, model_args::Tuple, - observations::ChoiceMap, - proposal::GenerativeFunction, proposal_args::Tuple, - num_samples::Int, verbose=false) where {T,U} +function importance_sampling_iter!( + traces::Vector, log_weights::Vector{Float64}, + model::GenerativeFunction, model_args::Tuple, + observations::ChoiceMap, i::Int, verbose::Bool) + (traces[i], log_weights[i]) = generate(model, model_args, observations) + verbose && Core.println("sample: $i of $(length(log_weights)) completed in thread $(Threads.threadid())") + return nothing +end + +function importance_sampling( + model::GenerativeFunction{T,U}, model_args::Tuple, + observations::ChoiceMap, proposal::GenerativeFunction, proposal_args::Tuple, + num_samples::Int; verbose=false, multithreaded=false) where {T,U} traces = Vector{U}(undef, num_samples) log_weights = Vector{Float64}(undef, num_samples) - for i=1:num_samples - verbose && println("sample: $i of $num_samples") - (proposed_choices, proposal_weight, _) = propose(proposal, proposal_args) - constraints = merge(observations, proposed_choices) - (traces[i], model_weight) = generate(model, model_args, constraints) - log_weights[i] = model_weight - proposal_weight + if multithreaded + Threads.@threads for i=1:num_samples + importance_sampling_iter!( + traces, log_weights, model, model_args, + observations, proposal, proposal_args, i, verbose) + end + else + for i=1:num_samples + importance_sampling_iter!( + traces, log_weights, model, model_args, + observations, proposal, proposal_args, i, verbose) + end end log_total_weight = logsumexp(log_weights) log_ml_estimate = log_total_weight - log(num_samples) @@ -51,6 +72,20 @@ function importance_sampling(model::GenerativeFunction{T,U}, model_args::Tuple, return (traces, log_normalized_weights, log_ml_estimate) end +function importance_sampling_iter!( + traces::Vector, log_weights::Vector{Float64}, + model::GenerativeFunction,
model_args::Tuple, + observations::ChoiceMap, + proposal::GenerativeFunction, proposal_args::Tuple, + i::Int, verbose::Bool) + (proposed_choices, proposal_weight, _) = propose(proposal, proposal_args) + constraints = merge(observations, proposed_choices) + (traces[i], model_weight) = generate(model, model_args, constraints) + log_weights[i] = model_weight - proposal_weight + verbose && Core.println("sample $i of $(length(log_weights)) completed in thread $(Threads.threadid())") + return nothing +end + """ (trace, lml_est) = importance_resampling(model::GenerativeFunction, model_args::Tuple, observations::ChoiceMap, num_samples::Int, diff --git a/src/inference/train.jl b/src/inference/train.jl index 9e64869cd..8c2bf8828 100644 --- a/src/inference/train.jl +++ b/src/inference/train.jl @@ -1,7 +1,8 @@ """ train!(gen_fn::GenerativeFunction, data_generator::Function, - update::ParamUpdate, - num_epoch, epoch_size, num_minibatch, minibatch_size; verbose::Bool=false) + optimizer; + num_epoch=1, epoch_size=1, num_minibatch=1, minibatch_size=1, + verbose::Bool=false) Train the given generative function to maximize the expected conditional log probability (density) that `gen_fn` generates the assignment `constraints` @@ -22,7 +23,7 @@ taken under the marginal distribution on `inputs` determined by the data generator. """ function train!(gen_fn::GenerativeFunction, data_generator::Function, - update::ParamUpdate; + optimizer; num_epoch=1, epoch_size=1, num_minibatch=1, minibatch_size=1, evaluation_size=epoch_size, verbose=false, callback=(epoch, minibatch, minibatch_objective) -> nothing) @@ -55,7 +56,7 @@ function train!(gen_fn::GenerativeFunction, data_generator::Function, minibatch_objective += weight accumulate_param_gradients!(trace) end - apply!(update) + apply_update!(optimizer) minibatch_objective /= minibatch_size callback(epoch, minibatch, minibatch_objective) end @@ -87,7 +88,7 @@ end p::GenerativeFunction, p_args::Tuple, q::GenerativeFunction, get_q_args::Function) -Simulate a trace of p representing a training example, and use to update the gradients of the trainable parameters of q. +Simulate a trace of p representing a training example, and use it to accumulate gradients for the trainable parameters of q. Used for training q via maximum expected conditional likelihood. Random choices will be mapped from p to q based on their address. @@ -101,7 +102,7 @@ function lecture!( q_args = get_q_args(p_trace) q_trace, score = generate(q, q_args, get_choices(p_trace)) # NOTE: q won't make all the random choices that p does accumulate_param_gradients!(q_trace) - score + return score end """ @@ -109,7 +110,7 @@ end p::GenerativeFunction, p_args::Tuple, q::GenerativeFunction, get_q_args::Function) -Simulate a batch of traces of p representing training samples, and use them to update the gradients of the trainable parameters of q. +Simulate a batch of traces of p representing training samples, and use them to accumulate gradients for the trainable parameters of q. Like `lecture!` but q is batched, and must make random choices for training sample i under hierarchical address namespace i::Int (e.g. i => :z). get_q_args maps a vector of traces of p to an argument tuple of q. @@ -128,7 +129,7 @@ function lecture_batched!( q_args = get_q_args(p_traces) q_trace, score = generate(q_batched, q_args, constraints) # NOTE: q won't make all the random choices that p does accumulate_param_gradients!(q_trace) - score / batch_size + return score / batch_size end export train!
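For orientation, here is a minimal usage sketch of the renamed training API above. It is a sketch under assumptions, not part of the patch: the generative function `student` and the `data_generator` are hypothetical, `data_generator` is assumed to return an `(inputs, constraints)` tuple, and the built-in modeling language's `@param` declarations are assumed to work as before.

```julia
using Gen

# Hypothetical model with two trainable parameters (:a and :b).
@gen function student(x::Float64)
    @param a::Float64
    @param b::Float64
    y ~ normal(a * x + b, 1.0)
    return y
end

# Hypothetical data generator: one training example per call,
# returned as an (inputs, constraints) tuple.
function data_generator()
    x = rand()
    constraints = choicemap((:y, 2.0 * x + 1.0))
    return ((x,), constraints)
end

# Initialize the parameters (in the default Julia parameter store).
init_parameter!((student, :a), 0.0)
init_parameter!((student, :b), 0.0)

# Build an optimizer over all parameters used by `student`, then train.
optimizer = init_optimizer(FixedStepGradientDescent(0.001), student)
train!(student, data_generator, optimizer;
       num_epoch=10, epoch_size=100, num_minibatch=10, minibatch_size=10,
       verbose=true)
```

Since `train!` now calls `apply_update!(optimizer)` internally, the optimizer returned by `init_optimizer` needs no further setup here.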
diff --git a/src/inference/variational.jl b/src/inference/variational.jl index 88736bf8b..09713cc7c 100644 --- a/src/inference/variational.jl +++ b/src/inference/variational.jl @@ -16,7 +16,7 @@ function single_sample_gradient_estimate!( accumulate_param_gradients!(var_trace, nothing, log_weight * scale_factor) # unbiased estimate of objective function, and trace - (log_weight, var_trace, model_trace) + return (log_weight, var_trace, model_trace) end function vimco_geometric_baselines(log_weights) @@ -29,12 +29,12 @@ function vimco_geometric_baselines(log_weights) baselines[i] = logsumexp(log_weights) - log(num_samples) log_weights[i] = temp end - baselines + return baselines end function logdiffexp(x, y) m = max(x, y) - m + log(exp(x - m) - exp(y - m)) + return m + log(exp(x - m) - exp(y - m)) end function vimco_arithmetic_baselines(log_weights) @@ -46,7 +46,7 @@ function vimco_arithmetic_baselines(log_weights) log_f_hat = log_sum_f_without_i - log(num_samples - 1) baselines[i] = logsumexp(log_sum_f_without_i, log_f_hat) - log(num_samples) end - baselines + return baselines end # black box, VIMCO gradient estimator @@ -85,28 +85,28 @@ function multi_sample_gradient_estimate!( # collection of traces and normalized importance weights, and estimate of # objective function - (L, traces, weights_normalized) + return (L, traces, weights_normalized) end -function _maybe_accumulate_param_grad!(trace, update::ParamUpdate, scale_factor::Real) +function _maybe_accumulate_param_grad!(trace, optimizer, scale_factor::Real) return accumulate_param_gradients!(trace, nothing, scale_factor) end -function _maybe_accumulate_param_grad!(trace, update::Nothing, scale_factor::Real) +function _maybe_accumulate_param_grad!(trace, optimizer::Nothing, scale_factor::Real) end """ (elbo_estimate, traces, elbo_history) = black_box_vi!( model::GenerativeFunction, model_args::Tuple, - [model_update::ParamUpdate,] + [model_optimizer,] observations::ChoiceMap, var_model::GenerativeFunction, var_model_args::Tuple, - var_model_update::ParamUpdate; + var_model_optimizer; options...) Fit the parameters of a variational model (`var_model`) to the posterior distribution implied by the given `model` and `observations` using stochastic -gradient methods. Users may optionally specify a `model_update` to jointly +gradient methods. Users may optionally specify a `model_optimizer` to jointly update the parameters of `model`. # Additional arguments: @@ -117,39 +117,41 @@ update the parameters of `model`. - `callback`: Callback function that takes `(iter, traces, elbo_estimate)` as input, where `iter` is the iteration number and `traces` are samples from `var_model` for that iteration. +- `multithreaded`: if `true`, gradient estimation may use multiple threads. 
""" function black_box_vi!( model::GenerativeFunction, model_args::Tuple, - model_update::Union{ParamUpdate,Nothing}, + model_optimizer, observations::ChoiceMap, var_model::GenerativeFunction, var_model_args::Tuple, - var_model_update::ParamUpdate; + var_model_optimizer; iters=1000, samples_per_iter=100, verbose=false, - callback=(iter, traces, elbo_estimate) -> nothing) + callback=(iter, traces, elbo_estimate) -> nothing, + multithreaded=false) var_traces = Vector{Any}(undef, samples_per_iter) model_traces = Vector{Any}(undef, samples_per_iter) + log_weights = Vector{Float64}(undef, samples_per_iter) elbo_history = Vector{Float64}(undef, iters) for iter=1:iters # compute gradient estimate and objective function estimate - elbo_estimate = 0.0 - # TODO multithread (note that this would require accumulate_param_gradients! to be threadsafe) - for sample=1:samples_per_iter - - # accumulate the variational family gradients - (log_weight, var_trace, model_trace) = single_sample_gradient_estimate!( - var_model, var_model_args, - model, model_args, observations, 1/samples_per_iter) - elbo_estimate += (log_weight / samples_per_iter) - - # accumulate the generative model gradients - _maybe_accumulate_param_grad!(model_trace, model_update, 1.0 / samples_per_iter) - - # record the traces - var_traces[sample] = var_trace - model_traces[sample] = model_trace + if multithreaded + Threads.@threads for i in 1:samples_per_iter + black_box_vi_iter!( + var_traces, model_traces, log_weights, i, samples_per_iter, + var_model, var_model_args, + model, model_args, observations, model_optimizer) + end + else + for i in 1:samples_per_iter + black_box_vi_iter!( + var_traces, model_traces, log_weights, i, samples_per_iter, + var_model, var_model_args, + model, model_args, observations, model_optimizer) + end end + elbo_estimate = sum(log_weights) elbo_history[iter] = elbo_estimate # print it @@ -159,38 +161,63 @@ function black_box_vi!( callback(iter, var_traces, elbo_estimate) # update parameters of variational family - apply!(var_model_update) + apply_update!(var_model_optimizer) # update parameters of generative model - if !isnothing(model_update) - apply!(model_update) + if !isnothing(model_optimizer) + apply_update!(model_optimizer) end end - (elbo_history[end], var_traces, elbo_history, model_traces) + return (elbo_history[end], var_traces, elbo_history, model_traces) +end + +function black_box_vi_iter!( + var_traces::Vector, model_traces::Vector, log_weights::Vector{Float64}, + i::Int, n::Int, + var_model::GenerativeFunction, var_model_args::Tuple, + model::GenerativeFunction, model_args::Tuple, + observations::ChoiceMap, + model_optimizer) + + # accumulate the variational family gradients + (log_weight, var_trace, model_trace) = single_sample_gradient_estimate!( + var_model, var_model_args, + model, model_args, observations, 1.0 / n) + log_weights[i] = log_weight / n + + # accumulate the generative model gradients + _maybe_accumulate_param_grad!(model_trace, model_optimizer, 1.0 / n) + + # record the traces + var_traces[i] = var_trace + model_traces[i] = model_trace + + return nothing end + black_box_vi!(model::GenerativeFunction, model_args::Tuple, observations::ChoiceMap, var_model::GenerativeFunction, var_model_args::Tuple, - var_model_update::ParamUpdate; options...) = + var_model_optimizer; options...) = black_box_vi!(model, model_args, nothing, observations, - var_model, var_model_args, var_model_update; options...) + var_model, var_model_args, var_model_optimizer; options...) 
""" (iwelbo_estimate, traces, iwelbo_history) = black_box_vimco!( model::GenerativeFunction, model_args::Tuple, - [model_update::ParamUpdate,] + [model_optimizer,] observations::ChoiceMap, var_model::GenerativeFunction, var_model_args::Tuple, - var_model_update::ParamUpdate, + var_model_optimizer::CompositeOptimizer, grad_est_samples::Int; options...) Fit the parameters of a variational model (`var_model`) to the posterior distribution implied by the given `model` and `observations` using stochastic gradient methods applied to the [Variational Inference with Monte Carlo Objectives](https://arxiv.org/abs/1602.06725) (VIMCO) lower bound on the -marginal likelihood. Users may optionally specify a `model_update` to jointly +marginal likelihood. Users may optionally specify a `model_optimizer` to jointly update the parameters of `model`. # Additional arguments: @@ -205,42 +232,45 @@ update the parameters of `model`. - `callback`: Callback function that takes `(iter, traces, elbo_estimate)` as input, where `iter` is the iteration number and `traces` are samples from `var_model` for that iteration. +- `multithreaded`: if `true`, gradient estimation may use multiple threads. """ function black_box_vimco!( model::GenerativeFunction, model_args::Tuple, - model_update::Union{ParamUpdate,Nothing}, observations::ChoiceMap, + model_optimizer::Union{CompositeOptimizer,Nothing}, observations::ChoiceMap, var_model::GenerativeFunction, var_model_args::Tuple, - var_model_update::ParamUpdate, grad_est_samples::Int; + var_model_optimizer::CompositeOptimizer, grad_est_samples::Int; iters=1000, samples_per_iter=100, geometric=true, verbose=false, - callback=(iter, traces, elbo_estimate) -> nothing) + callback=(iter, traces, elbo_estimate) -> nothing, + multithreaded=false) resampled_var_traces = Vector{Any}(undef, samples_per_iter) model_traces = Vector{Any}(undef, samples_per_iter) + log_weights = Vector{Float64}(undef, samples_per_iter) iwelbo_history = Vector{Float64}(undef, iters) for iter=1:iters # compute gradient estimate and objective function estimate - iwelbo_estimate = 0. 
- for sample=1:samples_per_iter - - # accumulate the variational family gradients - (est, original_var_traces, weights) = multi_sample_gradient_estimate!( - var_model, var_model_args, - model, model_args, observations, grad_est_samples, - 1/samples_per_iter, geometric) - iwelbo_estimate += (est / samples_per_iter) - - # record a variational trace obtained by resampling from the weighted collection - resampled_var_traces[sample] = original_var_traces[categorical(weights)] - - # accumulate the generative model gradient estimator - for (var_trace, weight) in zip(original_var_traces, weights) - constraints = merge(observations, get_choices(var_trace)) - (model_trace, _) = generate(model, model_args, constraints) - _maybe_accumulate_param_grad!(model_trace, model_update, weight / samples_per_iter) + if multithreaded + Threads.@threads for i in 1:samples_per_iter + black_box_vimco_iter!( + resampled_var_traces, log_weights, + i, samples_per_iter, + var_model, var_model_args, model, model_args, + observations, geometric, grad_est_samples, + model_optimizer) + end + else + for i in 1:samples_per_iter + black_box_vimco_iter!( + resampled_var_traces, log_weights, + i, samples_per_iter, + var_model, var_model_args, model, model_args, + observations, geometric, grad_est_samples, + model_optimizer) end end + iwelbo_estimate = sum(log_weights) iwelbo_history[iter] = iwelbo_estimate # print it @@ -250,11 +280,11 @@ function black_box_vimco!( callback(iter, resampled_var_traces, iwelbo_estimate) # update parameters of variational family - apply!(var_model_update) + apply_update!(var_model_optimizer) # update parameters of generative model - if !isnothing(model_update) - apply!(model_update) + if !isnothing(model_optimizer) + apply_update!(model_optimizer) end end @@ -262,13 +292,41 @@ function black_box_vimco!( (iwelbo_history[end], resampled_var_traces, iwelbo_history, model_traces) end +function black_box_vimco_iter!( + resampled_var_traces::Vector, log_weights::Vector{Float64}, + i::Int, samples_per_iter::Int, + var_model::GenerativeFunction, var_model_args::Tuple, + model::GenerativeFunction, model_args::Tuple, + observations::ChoiceMap, geometric::Bool, grad_est_samples::Int, + model_optimizer) + + # accumulate the variational family gradients + (est, original_var_traces, weights) = multi_sample_gradient_estimate!( + var_model, var_model_args, + model, model_args, observations, grad_est_samples, + 1/samples_per_iter, geometric) + log_weights[i] = est / samples_per_iter + + # record a variational trace obtained by resampling from the weighted collection + resampled_var_traces[i] = original_var_traces[categorical(weights)] + + # accumulate the generative model gradient estimator + for (var_trace, weight) in zip(original_var_traces, weights) + constraints = merge(observations, get_choices(var_trace)) + (model_trace, _) = generate(model, model_args, constraints) + _maybe_accumulate_param_grad!(model_trace, model_optimizer, weight / samples_per_iter) + end + + return nothing +end + black_box_vimco!(model::GenerativeFunction, model_args::Tuple, observations::ChoiceMap, var_model::GenerativeFunction, var_model_args::Tuple, - var_model_update::ParamUpdate, + var_model_optimizer::CompositeOptimizer, grad_est_samples::Int; options...) = black_box_vimco!(model, model_args, nothing, observations, - var_model, var_model_args, var_model_update, + var_model, var_model_args, var_model_optimizer, grad_est_samples; options...) export black_box_vi!, black_box_vimco! 
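A corresponding sketch for `black_box_vimco!`, which additionally takes the number of samples used per VIMCO gradient estimate. It reuses the hypothetical `model`, `approx`, and `observations` from the previous sketch and constructs `DecayStepGradientDescent` by keyword; treat it as illustrative only.

```julia
# Reuses the hypothetical `model`, `approx`, and `observations` defined in
# the black_box_vi! sketch above.
var_optimizer = init_optimizer(
    DecayStepGradientDescent(step_size_init=1.0, step_size_beta=1000.0), approx)

grad_est_samples = 10  # samples used by each VIMCO gradient estimate
(iwelbo_estimate, var_traces, iwelbo_history, model_traces) = black_box_vimco!(
    model, (), observations,
    approx, (), var_optimizer, grad_est_samples;
    iters=500, samples_per_iter=20, geometric=true, multithreaded=true)
```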
diff --git a/src/modeling_library/call_at/call_at.jl b/src/modeling_library/call_at/call_at.jl index 0c5b997bc..ac2393f47 100644 --- a/src/modeling_library/call_at/call_at.jl +++ b/src/modeling_library/call_at/call_at.jl @@ -64,6 +64,8 @@ function accepts_output_grad(gen_fn::CallAtCombinator) accepts_output_grad(gen_fn.kernel) end +get_parameters(gen_fn::CallAtCombinator, context) = get_parameters(gen_fn.kernel, context) + unpack_call_at_args(args) = (args[end], args[1:end-1]) @@ -89,13 +91,14 @@ function simulate(gen_fn::CallAtCombinator, args::Tuple) CallAtTrace(gen_fn, subtrace, key) end -function generate(gen_fn::CallAtCombinator{T,U,K}, args::Tuple, - choices::ChoiceMap) where {T,U,K} +function generate( + gen_fn::CallAtCombinator{T,U,K}, args::Tuple, + choices::ChoiceMap, parameter_context::Dict) where {T,U,K} (key, kernel_args) = unpack_call_at_args(args) submap = get_submap(choices, key) - (subtrace, weight) = generate(gen_fn.kernel, kernel_args, submap) + (subtrace, weight) = generate(gen_fn.kernel, kernel_args, submap, parameter_context) trace = CallAtTrace(gen_fn, subtrace, key) - (trace, weight) + return (trace, weight) end function project(trace::CallAtTrace, selection::Selection) diff --git a/src/modeling_library/custom_determ.jl b/src/modeling_library/custom_determ.jl index 24d6d90f2..f5fadf641 100644 --- a/src/modeling_library/custom_determ.jl +++ b/src/modeling_library/custom_determ.jl @@ -93,12 +93,16 @@ function accumulate_param_gradients_determ!( gradient_with_state(gen_fn, state, args, retgrad) end -function simulate(gen_fn::CustomDetermGF{T,S}, args::Tuple) where {T,S} +function simulate( + gen_fn::CustomDetermGF{T,S}, args::Tuple, + parameter_context::Dict) where {T,S} retval, state = apply_with_state(gen_fn, args) CustomDetermGFTrace{T,S}(retval, state, args, gen_fn) end -function generate(gen_fn::CustomDetermGF{T,S}, args::Tuple, choices::ChoiceMap) where {T,S} +function generate( + gen_fn::CustomDetermGF{T,S}, args::Tuple, + choices::ChoiceMap, parameter_context::Dict) where {T,S} if !isempty(choices) error("Deterministic generative function makes no random choices") end @@ -107,7 +111,9 @@ function generate(gen_fn::CustomDetermGF{T,S}, args::Tuple, choices::ChoiceMap) trace, 0. end -function update(trace::CustomDetermGFTrace{T,S}, args::Tuple, argdiffs::Tuple, choices::ChoiceMap) where {T,S} +function update( + trace::CustomDetermGFTrace{T,S}, args::Tuple, argdiffs::Tuple, + choices::ChoiceMap) where {T,S} if !isempty(choices) error("Deterministic generative function makes no random choices") end @@ -129,7 +135,8 @@ function accumulate_param_gradients!(trace::CustomDetermGFTrace, retgrad, scale_ accumulate_param_gradients_determ!(trace.gen_fn, trace.state, trace.args, retgrad, scale_factor) end -export CustomDetermGF, CustomDetermGFTrace, apply_with_state, update_with_state, gradient_with_state, accumulate_param_gradients_determ! +export CustomDetermGF, CustomDetermGFTrace, apply_with_state, update_with_state +export gradient_with_state, accumulate_param_gradients_determ! 
#################### # CustomGradientGF # diff --git a/src/modeling_library/map/assess.jl b/src/modeling_library/map/assess.jl index 74538d9d4..1d4e62a16 100644 --- a/src/modeling_library/map/assess.jl +++ b/src/modeling_library/map/assess.jl @@ -3,20 +3,24 @@ mutable struct MapAssessState{T} retvals::Vector{T} end -function process_new!(gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap, - key::Int, state::MapAssessState{T}) where {T,U} +function process_new!( + gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap, + key::Int, state::MapAssessState{T}, + parameter_context) where {T,U} kernel_args = get_args_for_key(args, key) submap = get_submap(choices, key) - (weight, retval) = assess(gen_fn.kernel, kernel_args, submap) + (weight, retval) = assess(gen_fn.kernel, kernel_args, submap, parameter_context) state.weight += weight state.retvals[key] = retval end -function assess(gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap) where {T,U} +function assess( + gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap, + parameter_context::Dict) where {T,U} len = length(args[1]) state = MapAssessState{T}(0., Vector{T}(undef,len)) for key=1:len - process_new!(gen_fn, args, choices, key, state) + process_new!(gen_fn, args, choices, key, state, parameter_context) end - (state.weight, PersistentVector{T}(state.retvals)) + return (state.weight, PersistentVector{T}(state.retvals)) end diff --git a/src/modeling_library/map/generate.jl b/src/modeling_library/map/generate.jl index 8a5c2e8da..cc886a4a6 100644 --- a/src/modeling_library/map/generate.jl +++ b/src/modeling_library/map/generate.jl @@ -7,31 +7,35 @@ mutable struct MapGenerateState{T,U} num_nonempty::Int end -function process!(gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap, - key::Int, state::MapGenerateState{T,U}) where {T,U} +function process!( + gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap, + key::Int, state::MapGenerateState{T,U}, + parameter_context) where {T,U} local subtrace::U local retval::T kernel_args = get_args_for_key(args, key) submap = get_submap(choices, key) - (subtrace, weight) = generate(gen_fn.kernel, kernel_args, submap) + (subtrace, weight) = generate(gen_fn.kernel, kernel_args, submap, parameter_context) state.weight += weight state.noise += project(subtrace, EmptySelection()) state.num_nonempty += (isempty(get_choices(subtrace)) ? 
0 : 1) state.score += get_score(subtrace) state.subtraces[key] = subtrace retval = get_retval(subtrace) - state.retval[key] = retval + return state.retval[key] = retval end -function generate(gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap) where {T,U} +function generate( + gen_fn::Map{T,U}, args::Tuple, choices::ChoiceMap, + parameter_context::Dict) where {T,U} len = length(args[1]) state = MapGenerateState{T,U}(0., 0., 0., Vector{U}(undef,len), Vector{T}(undef,len), 0) # TODO check for keys that aren't valid constraints for key=1:len - process!(gen_fn, args, choices, key, state) + process!(gen_fn, args, choices, key, state, parameter_context) end trace = VectorTrace{MapType,T,U}(gen_fn, PersistentVector{U}(state.subtraces), PersistentVector{T}(state.retval), args, state.score, state.noise, len, state.num_nonempty) - (trace, state.weight) + return (trace, state.weight) end diff --git a/src/modeling_library/map/map.jl b/src/modeling_library/map/map.jl index 1bb695ff7..6335dc7ec 100644 --- a/src/modeling_library/map/map.jl +++ b/src/modeling_library/map/map.jl @@ -26,6 +26,7 @@ export Map has_argument_grads(map_gf::Map) = has_argument_grads(map_gf.kernel) accepts_output_grad(map_gf::Map) = accepts_output_grad(map_gf.kernel) +get_parameters(map_gf::Map, parameter_context) = get_parameters(map_gf.kernel, parameter_context) function (gen_fn::Map)(args...) (_, _, retval) = propose(gen_fn, args) diff --git a/src/modeling_library/map/propose.jl b/src/modeling_library/map/propose.jl index c0ca14330..f6929a13a 100644 --- a/src/modeling_library/map/propose.jl +++ b/src/modeling_library/map/propose.jl @@ -4,22 +4,25 @@ mutable struct MapProposeState{T} retvals::Vector{T} end -function process_new!(gen_fn::Map{T,U}, args::Tuple, key::Int, - state::MapProposeState{T}) where {T,U} +function process_new!( + gen_fn::Map{T,U}, args::Tuple, key::Int, + state::MapProposeState{T}, + parameter_context) where {T,U} local subtrace::U kernel_args = get_args_for_key(args, key) - (submap, weight, retval) = propose(gen_fn.kernel, kernel_args) + (submap, weight, retval) = propose(gen_fn.kernel, kernel_args, parameter_context) set_submap!(state.choices, key, submap) state.weight += weight state.retvals[key] = retval end -function propose(gen_fn::Map{T,U}, args::Tuple) where {T,U} +function propose( + gen_fn::Map{T,U}, args::Tuple, parameter_context::Dict) where {T,U} len = length(args[1]) choices = choicemap() state = MapProposeState{T}(choices, 0., Vector{T}(undef,len)) for key=1:len - process_new!(gen_fn, args, key, state) + process_new!(gen_fn, args, key, state, parameter_context) end - (state.choices, state.weight, PersistentVector{T}(state.retvals)) + return (state.choices, state.weight, PersistentVector{T}(state.retvals)) end diff --git a/src/modeling_library/map/simulate.jl b/src/modeling_library/map/simulate.jl index 216e140c8..610ce0db1 100644 --- a/src/modeling_library/map/simulate.jl +++ b/src/modeling_library/map/simulate.jl @@ -7,11 +7,12 @@ mutable struct MapSimulateState{T,U} end function process!(gen_fn::Map{T,U}, args::Tuple, - key::Int, state::MapSimulateState{T,U}) where {T,U} + key::Int, state::MapSimulateState{T,U}, + parameter_context) where {T,U} local subtrace::U local retval::T kernel_args = get_args_for_key(args, key) - subtrace = simulate(gen_fn.kernel, kernel_args) + subtrace = simulate(gen_fn.kernel, kernel_args, parameter_context) state.noise += project(subtrace, EmptySelection()) state.num_nonempty += (isempty(get_choices(subtrace)) ? 
0 : 1) state.score += get_score(subtrace) @@ -20,11 +21,11 @@ function process!(gen_fn::Map{T,U}, args::Tuple, state.retval[key] = retval end -function simulate(gen_fn::Map{T,U}, args::Tuple) where {T,U} +function simulate(gen_fn::Map{T,U}, args::Tuple, parameter_context::Dict) where {T,U} len = length(args[1]) state = MapSimulateState{T,U}(0., 0., Vector{U}(undef,len), Vector{T}(undef,len), 0) for key=1:len - process!(gen_fn, args, key, state) + process!(gen_fn, args, key, state, parameter_context) end VectorTrace{MapType,T,U}(gen_fn, PersistentVector{U}(state.subtraces), PersistentVector{T}(state.retval), diff --git a/src/modeling_library/recurse/recurse.jl b/src/modeling_library/recurse/recurse.jl index 7b4d1d23d..15339fb7b 100644 --- a/src/modeling_library/recurse/recurse.jl +++ b/src/modeling_library/recurse/recurse.jl @@ -131,6 +131,13 @@ end # TODO accepts_output_grad(::Recurse) = false +function get_parameters(gen_fn::Recurse, parameter_context) + parameter_stores_to_ids = Dict{Any,Vector}() + merge!(parameter_stores_to_ids, get_parameters(gen_fn.production_kernel, parameter_context)) + merge!(parameter_stores_to_ids, get_parameters(gen_fn.aggregation_kernel, parameter_context)) + return parameter_stores_to_ids +end + function (gen_fn::Recurse)(args...) (_, _, retval) = propose(gen_fn, args) retval @@ -197,7 +204,9 @@ end # simulate # ############ -function simulate(gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}) where {S,T,U,V,W,X,Y} +function simulate( + gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}, + parameter_context::Dict) where {S,T,U,V,W,X,Y} (root_production_input::U, root_idx::Int) = args production_traces = PersistentHashMap{Int,S}() aggregation_traces = PersistentHashMap{Int,T}() @@ -213,7 +222,7 @@ function simulate(gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}) where {S,T cur = first(prod_to_visit) delete!(prod_to_visit, cur) input = get_production_input(gen_fn, cur, production_traces, root_idx, root_production_input) - subtrace = simulate(gen_fn.production_kern, (input,)) + subtrace = simulate(gen_fn.production_kern, (input,), parameter_context) score += get_score(subtrace) production_traces = assoc(production_traces, cur, subtrace) children_inputs::Vector{U} = get_retval(subtrace).children @@ -232,7 +241,7 @@ function simulate(gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}) where {S,T local subtrace::T local input::Tuple{V,Vector{W}} input = get_aggregation_input(gen_fn, cur, production_traces, aggregation_traces) - subtrace = simulate(gen_fn.aggregation_kern, input) + subtrace = simulate(gen_fn.aggregation_kern, input, parameter_context) score += get_score(subtrace) aggregation_traces = assoc(aggregation_traces, cur, subtrace) if !isempty(get_choices(subtrace)) @@ -249,8 +258,10 @@ end # generate # ############ -function generate(gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}, - constraints::ChoiceMap) where {S,T,U,V,W,X,Y} +function generate( + gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}, + constraints::ChoiceMap, + parameter_context::Dict) where {S,T,U,V,W,X,Y} (root_production_input::U, root_idx::Int) = args production_traces = PersistentHashMap{Int,S}() aggregation_traces = PersistentHashMap{Int,T}() @@ -268,7 +279,7 @@ function generate(gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}, delete!(prod_to_visit, cur) input = get_production_input(gen_fn, cur, production_traces, root_idx, root_production_input) subconstraints = get_production_constraints(constraints, cur) - (subtrace, subweight) = generate(gen_fn.production_kern, 
(input,), subconstraints) + (subtrace, subweight) = generate(gen_fn.production_kern, (input,), subconstraints, parameter_context) score += get_score(subtrace) production_traces = assoc(production_traces, cur, subtrace) weight += subweight @@ -289,7 +300,7 @@ function generate(gen_fn::Recurse{S,T,U,V,W,X,Y}, args::Tuple{U,Int}, local input::Tuple{V,Vector{W}} input = get_aggregation_input(gen_fn, cur, production_traces, aggregation_traces) subconstraints = get_aggregation_constraints(constraints, cur) - (subtrace, subweight) = generate(gen_fn.aggregation_kern, input, subconstraints) + (subtrace, subweight) = generate(gen_fn.aggregation_kern, input, subconstraints, parameter_context) score += get_score(subtrace) aggregation_traces = assoc(aggregation_traces, cur, subtrace) weight += subweight @@ -567,7 +578,7 @@ function update(trace::RecurseTrace{S,T,U,V,W,X,Y}, else # the node does not exist already (and none of its children exist either) - (subtrace, ) = generate(gen_fn.production_kern, input, subconstraints) + (subtrace, ) = generate(gen_fn.production_kern, input, subconstraints, parameter_context) # update trace, weight, and score production_traces = assoc(production_traces, cur, subtrace) @@ -649,7 +660,7 @@ function update(trace::RecurseTrace{S,T,U,V,W,X,Y}, # if the node does not exist (but its children do, since we created them already) else - (subtrace, _) = generate(gen_fn.aggregation_kern, input, subconstraints) + (subtrace, _) = generate(gen_fn.aggregation_kern, input, subconstraints, parameter_context) # update trace, weight, and score aggregation_traces = assoc(aggregation_traces, cur, subtrace) diff --git a/src/modeling_library/switch/assess.jl b/src/modeling_library/switch/assess.jl index 4371eb8a4..187b48d00 100644 --- a/src/modeling_library/switch/assess.jl +++ b/src/modeling_library/switch/assess.jl @@ -4,23 +4,32 @@ mutable struct SwitchAssessState{T} SwitchAssessState{T}(weight::Float64) where T = new{T}(weight) end -function process!(gen_fn::Switch{C, N, K, T}, - index::Int, - args::Tuple, - choices::ChoiceMap, - state::SwitchAssessState{T}) where {C, N, K, T} - (weight, retval) = assess(getindex(gen_fn.branches, index), args, choices) +function process!( + gen_fn::Switch{C, N, K, T}, + index::Int, args::Tuple, + choices::ChoiceMap, + state::SwitchAssessState{T}, + parameter_context) where {C, N, K, T} + (weight, retval) = assess( + getindex(gen_fn.branches, index), args, choices, parameter_context) state.weight = weight state.retval = retval end -@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchAssessState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state) +@inline function process!( + gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, + choices::ChoiceMap, state::SwitchAssessState{T}, + parameter_context) where {C, N, K, T} + return process!( + gen_fn, getindex(gen_fn.cases, index), args, choices, state, parameter_context) +end -function assess(gen_fn::Switch{C, N, K, T}, - args::Tuple, - choices::ChoiceMap) where {C, N, K, T} +function assess( + gen_fn::Switch{C, N, K, T}, args::Tuple, + choices::ChoiceMap, + parameter_context::Dict) where {C, N, K, T} index = args[1] state = SwitchAssessState{T}(0.0) - process!(gen_fn, index, args[2 : end], choices, state) - return state.weight, state.retval + process!(gen_fn, index, args[2 : end], choices, state, parameter_context) + return (state.weight, state.retval) end diff --git a/src/modeling_library/switch/generate.jl 
b/src/modeling_library/switch/generate.jl index bd03f632e..eeed2c5d4 100644 --- a/src/modeling_library/switch/generate.jl +++ b/src/modeling_library/switch/generate.jl @@ -8,27 +8,38 @@ mutable struct SwitchGenerateState{T} SwitchGenerateState{T}(score::Float64, noise::Float64, weight::Float64) where T = new{T}(score, noise, weight) end -function process!(gen_fn::Switch{C, N, K, T}, - index::Int, - args::Tuple, - choices::ChoiceMap, - state::SwitchGenerateState{T}) where {C, N, K, T} +function process!( + gen_fn::Switch{C, N, K, T}, + index::Int, args::Tuple, + choices::ChoiceMap, + state::SwitchGenerateState{T}, + parameter_context) where {C, N, K, T} - (subtrace, weight) = generate(getindex(gen_fn.branches, index), args, choices) + (subtrace, weight) = generate( + getindex(gen_fn.branches, index), args, choices, parameter_context) state.index = index state.subtrace = subtrace state.weight += weight state.retval = get_retval(subtrace) end -@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchGenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state) +@inline function process!( + gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, + choices::ChoiceMap, state::SwitchGenerateState{T}, + parameter_context) where {C, N, K, T} + return process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state, parameter_context) +end -function generate(gen_fn::Switch{C, N, K, T}, - args::Tuple, - choices::ChoiceMap) where {C, N, K, T} +function generate( + gen_fn::Switch{C, N, K, T}, + args::Tuple, choices::ChoiceMap, + parameter_context::Dict) where {C, N, K, T} index = args[1] state = SwitchGenerateState{T}(0.0, 0.0, 0.0) - process!(gen_fn, index, args[2 : end], choices, state) - return SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise), state.weight + process!(gen_fn, index, args[2 : end], choices, state, parameter_context) + trace = SwitchTrace{T}( + gen_fn, state.index, state.subtrace, state.retval, + args[2 : end], state.score, state.noise) + return (trace, state.weight) end diff --git a/src/modeling_library/switch/propose.jl b/src/modeling_library/switch/propose.jl index b4df1d97f..abd38a04e 100644 --- a/src/modeling_library/switch/propose.jl +++ b/src/modeling_library/switch/propose.jl @@ -5,25 +5,32 @@ mutable struct SwitchProposeState{T} SwitchProposeState{T}(choices, weight) where T = new{T}(choices, weight) end -function process!(gen_fn::Switch{C, N, K, T}, - index::Int, - args::Tuple, - state::SwitchProposeState{T}) where {C, N, K, T} - - (submap, weight, retval) = propose(getindex(gen_fn.branches, index), args) +function process!( + gen_fn::Switch{C, N, K, T}, + index::Int, args::Tuple, + state::SwitchProposeState{T}, + parameter_context) where {C, N, K, T} + (submap, weight, retval) = propose( + getindex(gen_fn.branches, index), args, parameter_context) state.choices = submap state.weight += weight state.retval = retval end -@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchProposeState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state) +@inline function process!( + gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, + state::SwitchProposeState{T}, + parameter_context) where {C, N, K, T} + return process!(gen_fn, getindex(gen_fn.cases, index), args, state, parameter_context) +end -function propose(gen_fn::Switch{C, N, K, T}, - args::Tuple) where {C, N, K, T} +function 
propose( + gen_fn::Switch{C, N, K, T}, args::Tuple, + parameter_context::Dict) where {C, N, K, T} index = args[1] choices = choicemap() state = SwitchProposeState{T}(choices, 0.0) - process!(gen_fn, index, args[2:end], state) + process!(gen_fn, index, args[2:end], state, parameter_context) - return state.choices, state.weight, state.retval + return (state.choices, state.weight, state.retval) end diff --git a/src/modeling_library/switch/simulate.jl b/src/modeling_library/switch/simulate.jl index fc4b3b02a..52683b78f 100644 --- a/src/modeling_library/switch/simulate.jl +++ b/src/modeling_library/switch/simulate.jl @@ -7,12 +7,13 @@ mutable struct SwitchSimulateState{T} SwitchSimulateState{T}(score::Float64, noise::Float64) where T = new{T}(score, noise) end -function process!(gen_fn::Switch{C, N, K, T}, - index::Int, - args::Tuple, - state::SwitchSimulateState{T}) where {C, N, K, T} +function process!( + gen_fn::Switch{C, N, K, T}, + index::Int, args::Tuple, + state::SwitchSimulateState{T}, + parameter_context) where {C, N, K, T} local retval::T - subtrace = simulate(getindex(gen_fn.branches, index), args) + subtrace = simulate(getindex(gen_fn.branches, index), args, parameter_context) state.index = index state.noise += project(subtrace, EmptySelection()) state.subtrace = subtrace @@ -20,13 +21,19 @@ function process!(gen_fn::Switch{C, N, K, T}, state.retval = get_retval(subtrace) end -@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchSimulateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state) +@inline function process!( + gen_fn::Switch{C, N, K, T}, index::C, + args::Tuple, state::SwitchSimulateState{T}, + parameter_context) where {C, N, K, T} + return process!(gen_fn, getindex(gen_fn.cases, index), args, state, parameter_context) +end -function simulate(gen_fn::Switch{C, N, K, T}, - args::Tuple) where {C, N, K, T} +function simulate( + gen_fn::Switch{C, N, K, T}, + args::Tuple, parameter_context::Dict) where {C, N, K, T} index = args[1] state = SwitchSimulateState{T}(0.0, 0.0) - process!(gen_fn, index, args[2 : end], state) + process!(gen_fn, index, args[2 : end], state, parameter_context) return SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise) end diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl index 821143448..1dfc71fe2 100644 --- a/src/modeling_library/switch/switch.jl +++ b/src/modeling_library/switch/switch.jl @@ -19,6 +19,14 @@ has_argument_grads(switch_fn::Switch) = map(zip(map(has_argument_grads, switch_f end accepts_output_grad(switch_fn::Switch) = all(accepts_output_grad, switch_fn.branches) +function get_parameters(gen_fn::Switch, parameter_context) + parameter_stores_to_ids = Dict{Any,Vector}() + for branch_gen_fn in gen_fn.branches + merge!(parameter_stores_to_ids, get_parameters(branch_gen_fn, parameter_context)) + end + return parameter_stores_to_ids +end + function (gen_fn::Switch)(index::Int, args...)
(_, _, retval) = propose(gen_fn, (index, args...)) retval diff --git a/src/modeling_library/unfold/assess.jl b/src/modeling_library/unfold/assess.jl index 4199f77da..727805af2 100644 --- a/src/modeling_library/unfold/assess.jl +++ b/src/modeling_library/unfold/assess.jl @@ -4,24 +4,28 @@ mutable struct UnfoldAssessState{T} state::T end -function process_new!(gen_fn::Unfold{T,U}, params::Tuple, choices::ChoiceMap, - key::Int, state::UnfoldAssessState{T}) where {T,U} +function process_new!( + gen_fn::Unfold{T,U}, params::Tuple, choices::ChoiceMap, + key::Int, state::UnfoldAssessState{T}, + parameter_context) where {T,U} local new_state::T kernel_args = (key, state.state, params...) submap = get_submap(choices, key) - (weight, new_state) = assess(gen_fn.kernel, kernel_args, submap) + (weight, new_state) = assess(gen_fn.kernel, kernel_args, submap, parameter_context) state.weight += weight state.retvals[key] = new_state state.state = new_state end -function assess(gen_fn::Unfold{T,U}, args::Tuple, choices::ChoiceMap) where {T,U} +function assess( + gen_fn::Unfold{T,U}, args::Tuple, choices::ChoiceMap, + parameter_context::Dict) where {T,U} len = args[1] init_state = args[2] params = args[3:end] state = UnfoldAssessState{T}(0., Vector{T}(undef,len), init_state) for key=1:len - process_new!(gen_fn, params, choices, key, state) + process_new!(gen_fn, params, choices, key, state, parameter_context) end - (state.weight, PersistentVector{T}(state.retvals)) + return (state.weight, PersistentVector{T}(state.retvals)) end diff --git a/src/modeling_library/unfold/generate.jl b/src/modeling_library/unfold/generate.jl index 3ef9a78b9..d20ab1844 100644 --- a/src/modeling_library/unfold/generate.jl +++ b/src/modeling_library/unfold/generate.jl @@ -8,13 +8,15 @@ mutable struct UnfoldGenerateState{T,U} state::T end -function process!(gen_fn::Unfold{T,U}, params::Tuple, choices::ChoiceMap, - key::Int, state::UnfoldGenerateState{T,U}) where {T,U} +function process!( + gen_fn::Unfold{T,U}, params::Tuple, choices::ChoiceMap, + key::Int, state::UnfoldGenerateState{T,U}, + parameter_context) where {T,U} local subtrace::U local new_state::T kernel_args = (key, state.state, params...) submap = get_submap(choices, key) - (subtrace, weight) = generate(gen_fn.kernel, kernel_args, submap) + (subtrace, weight) = generate(gen_fn.kernel, kernel_args, submap, parameter_context) state.weight += weight state.noise += project(subtrace, EmptySelection()) state.num_nonempty += (isempty(get_choices(subtrace)) ? 
0 : 1) @@ -22,20 +24,22 @@ function process!(gen_fn::Unfold{T,U}, params::Tuple, choices::ChoiceMap, state.subtraces[key] = subtrace new_state = get_retval(subtrace) state.state = new_state - state.retval[key] = new_state + return state.retval[key] = new_state end -function generate(gen_fn::Unfold{T,U}, args::Tuple, choices::ChoiceMap) where {T,U} +function generate( + gen_fn::Unfold{T,U}, args::Tuple, choices::ChoiceMap, + parameter_context::Dict) where {T,U} len = args[1] init_state = args[2] params = args[3:end] state = UnfoldGenerateState{T,U}(0., 0., 0., Vector{U}(undef,len), Vector{T}(undef,len), 0, init_state) for key=1:len - process!(gen_fn, params, choices, key, state) + process!(gen_fn, params, choices, key, state, parameter_context) end trace = VectorTrace{UnfoldType,T,U}(gen_fn, PersistentVector{U}(state.subtraces), PersistentVector{T}(state.retval), args, state.score, state.noise, len, state.num_nonempty) - (trace, state.weight) + return (trace, state.weight) end diff --git a/src/modeling_library/unfold/propose.jl b/src/modeling_library/unfold/propose.jl index 8863fbd4a..411467afd 100644 --- a/src/modeling_library/unfold/propose.jl +++ b/src/modeling_library/unfold/propose.jl @@ -5,8 +5,10 @@ mutable struct UnfoldProposeState{T} state::T end -function process_new!(gen_fn::Unfold{T,U}, params::Tuple, key::Int, - state::UnfoldProposeState{T}) where {T,U} +function process_new!( + gen_fn::Unfold{T,U}, params::Tuple, key::Int, + state::UnfoldProposeState{T}, + parameter_context) where {T,U} local new_state::T kernel_args = (key, state.state, params...) (submap, weight, new_state) = propose(gen_fn.kernel, kernel_args) @@ -16,14 +18,15 @@ function process_new!(gen_fn::Unfold{T,U}, params::Tuple, key::Int, state.state = new_state end -function propose(gen_fn::Unfold{T,U}, args::Tuple) where {T,U} +function propose( + gen_fn::Unfold{T,U}, args::Tuple, parameter_context::Dict) where {T,U} len = args[1] init_state = args[2] params = args[3:end] choices = choicemap() state = UnfoldProposeState{T}(choices, 0., Vector{T}(undef,len), init_state) for key=1:len - process_new!(gen_fn, params, key, state) + process_new!(gen_fn, params, key, state, parameter_context) end - (state.choices, state.weight, PersistentVector{T}(state.retvals)) + return (state.choices, state.weight, PersistentVector{T}(state.retvals)) end diff --git a/src/modeling_library/unfold/simulate.jl b/src/modeling_library/unfold/simulate.jl index e161e64a2..6eea21da5 100644 --- a/src/modeling_library/unfold/simulate.jl +++ b/src/modeling_library/unfold/simulate.jl @@ -7,12 +7,14 @@ mutable struct UnfoldSimulateState{T,U} state::T end -function process!(gen_fn::Unfold{T,U}, params::Tuple, - key::Int, state::UnfoldSimulateState{T,U}) where {T,U} +function process!( + gen_fn::Unfold{T,U}, params::Tuple, + key::Int, state::UnfoldSimulateState{T,U}, + parameter_context) where {T,U} local subtrace::U local new_state::T kernel_args = (key, state.state, params...) - subtrace = simulate(gen_fn.kernel, kernel_args) + subtrace = simulate(gen_fn.kernel, kernel_args, parameter_context) state.noise += project(subtrace, EmptySelection()) state.num_nonempty += (isempty(get_choices(subtrace)) ? 
0 : 1) state.score += get_score(subtrace) @@ -22,14 +24,14 @@ function process!(gen_fn::Unfold{T,U}, params::Tuple, state.retval[key] = new_state end -function simulate(gen_fn::Unfold{T,U}, args::Tuple) where {T,U} +function simulate(gen_fn::Unfold{T,U}, args::Tuple, parameter_context::Dict) where {T,U} len = args[1] init_state = args[2] params = args[3:end] state = UnfoldSimulateState{T,U}(0., 0., Vector{U}(undef,len), Vector{T}(undef,len), 0, init_state) for key=1:len - process!(gen_fn, params, key, state) + process!(gen_fn, params, key, state, parameter_context) end VectorTrace{UnfoldType,T,U}(gen_fn, PersistentVector{U}(state.subtraces), PersistentVector{T}(state.retval), diff --git a/src/modeling_library/unfold/unfold.jl b/src/modeling_library/unfold/unfold.jl index 44238e3b7..45f240cd5 100644 --- a/src/modeling_library/unfold/unfold.jl +++ b/src/modeling_library/unfold/unfold.jl @@ -42,6 +42,9 @@ end # TODO accepts_output_grad(gen_fn::Unfold) = false +get_parameters(gen_fn::Unfold, parameter_context) = get_parameters(gen_fn.kernel, parameter_context) + + function (gen_fn::Unfold)(args...) (_, _, retval) = propose(gen_fn, args) retval diff --git a/src/optimization.jl b/src/optimization.jl index 8617f6c27..b1b2430f4 100644 --- a/src/optimization.jl +++ b/src/optimization.jl @@ -1,107 +1,565 @@ +import Parameters + +# TODO we should modify the semantics of the log probability contribution to +# the gradient so that everything is gradient descent instead of ascent. this +# will also fix the misnomer names +# +# TODO add tests specifically for JuliaParameterStore etc. + +export in_place_add! + +export FixedStepGradientDescent +export DecayStepGradientDescent +export init_optimizer +export apply_update! + +export JuliaParameterStore +export init_parameter! +export increment_gradient! +export reset_gradient! +export get_parameter_value +export get_gradient +export JULIA_PARAMETER_STORE_KEY + +export default_julia_parameter_store +export default_parameter_context + +############################ +# optimizer specifications # +############################ + """ - state = init_update_state(conf, gen_fn::GenerativeFunction, param_list::Vector) + conf = FixedStepGradientDescent(step_size) -Get the initial state for a parameter update to the given parameters of the given generative function. +Configuration for stochastic gradient descent update with fixed step size. +""" +Parameters.@with_kw struct FixedStepGradientDescent + step_size::Float64 +end -`param_list` is a vector of references to parameters of `gen_fn`. -`conf` configures the update. """ -function init_update_state end + conf = GradientDescent(step_size_init, step_size_beta) +Configuration for stochastic gradient descent update with step size given by `(t::Int) -> step_size_init * (step_size_beta + 1) / (step_size_beta + t)` where `t` is the iteration number. """ - apply_update!(state) +Parameters.@with_kw struct DecayStepGradientDescent + step_size_init::Float64 + step_size_beta::Float64 +end + +# TODO add gradient descent with momentum +# TODO add update + +################# +# in_place_add! # +################# + +function in_place_add! 
end + +function in_place_add!(value::Array, increment) + @simd for i in 1:length(value) + value[i] += increment[i] + end + return value +end + +# this exists so user can use the same function on scalars and arrays +function in_place_add!(value::Real, increment::Real) + return value + increment +end + + +########################### +# thread-safe accumulator # +########################### + +mutable struct Accumulator{T<:Union{Real,Array}} + value::T + lock::ReentrantLock +end + +Accumulator(value) = Accumulator(value, ReentrantLock()) + +# NOTE: not thread-safe because it may return a reference to the Array +get_value(accum::Accumulator) = accum.value + +function fill_with_zeros!(accum::Accumulator{T}) where {T <: Real} + lock(accum.lock) + try + accum.value = zero(T) + finally + unlock(accum.lock) + end + return accum +end + +function fill_with_zeros!(accum::Accumulator{<:Array{T}}) where {T} + lock(accum.lock) + try + fill!(accum.value, zero(T)) + finally + unlock(accum.lock) + end + return accum +end + +function in_place_add!(accum::Accumulator{<:Real}, increment::Real, scale_factor::Real) + lock(accum.lock) + try + accum.value = accum.value + increment * scale_factor + finally + unlock(accum.lock) + end + return accum +end + +function in_place_add!(accum::Accumulator{<:Real}, increment::Real) + lock(accum.lock) + try + accum.value = accum.value + increment + finally + unlock(accum.lock) + end + return accum +end + +function in_place_add!(accum::Accumulator{<:Array}, increment, scale_factor::Real) + lock(accum.lock) + try + @simd for i in 1:length(accum.value) + accum.value[i] += increment[i] * scale_factor + end + finally + unlock(accum.lock) + end + return accum +end + +function in_place_add!(accum::Accumulator{<:Array}, increment) + lock(accum.lock) + try + @simd for i in 1:length(accum.value) + accum.value[i] += increment[i] + end + finally + unlock(accum.lock) + end + return accum +end + + + + +################################### +# parameter stores and optimizers # +################################### -Apply one parameter update, mutating the values of the trainable parameters, and possibly also the given state. """ -function apply_update! end + optimizer = init_optimizer( + conf, parameter_ids::Vector, + store=default_julia_parameter_store) +Initialize an iterative gradient-based optimizer that mutates a single parameter store. + +The first argument defines the mathematical behavior of the update, the second argument defines the set of parameters to which the update should be applied at each iteration, and the third argument gives the location of the parameter values and their gradient accumulators. + +Add support for new parameter store types or new optimization configurations by (i) defining a new Julia type for an optimizer that applies the configuration type to the parameter store type, and (ii) implementing this method so that it returns an instance of your new type, and (iii) implement `apply_update!` for your new type. """ - update = ParamUpdate(conf, param_lists...) +function init_optimizer(conf, parameter_ids, store=default_julia_parameter_store) + error("Not implemented") +end -Return an update configured by `conf` that applies to set of parameters defined by `param_lists`. +""" + apply_update!(optimizer) -Each element in `param_lists` value is is pair of a generative function and a vector of its parameter references. +Apply one iteration of a gradient-based optimization update. -**Example**. 
To construct an update that applies a gradient descent update to the parameters `:a` and `:b` of generative function `foo` and the parameter `:theta` of generative function `:bar`: +Extend this method to add support for new parameter store types or new optimization configurations. +""" +function apply_update!(optimizer) + error("Not implemented") +end -```julia -update = ParamUpdate(GradientDescent(0.001, 100), foo => [:a, :b], bar => [:theta]) -``` +struct CompositeOptimizer + conf::Any + optimizers::Dict{Any,Any} + function CompositeOptimizer(conf, parameter_stores_to_ids) + optimizers = Dict{Any,Any}() + for (store, parameter_ids) in parameter_stores_to_ids + optimizers[store] = init_optimizer(conf, collect(parameter_ids), store) + end + new(conf, optimizers) + end +end ------------------------------------------------------------------------------------------- -Syntactic sugar for the constructor form above. +""" - update = ParamUpdate(conf, gen_fn::GenerativeFunction) + optimizer = init_optimizer(conf, parameter_stores_to_ids::Dict{Any,Vector}) -Return an update configured by `conf` that applies to all trainable parameters owned by the given generative function. +Construct a special type of optimizer that updates parameters in multiple parameter stores. -Note that trainable parameters not owned by the given generative function will not be updated, even if they are used during execution of the function. +The second argument defines the set of parameters to which the update should be applied at each iteration, +The parameters are given in a map from parameter store to a vector of IDs of parameters within that parameter store. + +NOTE: You do _not_ need to extend this method to extend support for new parameter store types or new optimization configurations. +""" +function init_optimizer(conf, parameter_stores_to_ids::Dict) + return CompositeOptimizer(conf, parameter_stores_to_ids) +end + +""" + optimizer = init_optimizer( + conf, gen_fn::GenerativeFunction; parameter_context=default_parameter_context) + +Convenience method that constructs an optimizer that updates all parameters used by the given generative function, even when the parameters exist in multiple parameter stores. +""" +function init_optimizer( + conf, gen_fn::GenerativeFunction; parameter_context=default_parameter_context) + return CompositeOptimizer(conf, get_parameters(gen_fn, parameter_context)) +end + + +function apply_update!(composite_opt::CompositeOptimizer) + for opt in values(composite_opt.optimizers) + apply_update!(opt) + end + return nothing +end -**Example**. If generative function `foo` has parameters `:a` and `:b`, to construct an update that applies a gradient descent update to the parameters `:a` and `:b`: +######### +# Julia # +######### + +struct JuliaParameterStore + values::Dict{GenerativeFunction,Dict{Symbol,Any}} + gradient_accumulators::Dict{GenerativeFunction,Dict{Symbol,Accumulator}} +end + +""" + + store = JuliaParameterStore() + +Construct a parameter store stores the state of parameters in the memory of the Julia runtime as Julia values. + +There is a global Julia parameter store automatically created and named `Gen.default_julia_parameter_store`. + +Gradient accumulation is thread-safe (see [`increment_gradient!`](@ref)). 
+""" +function JuliaParameterStore() + return JuliaParameterStore( + Dict{GenerativeFunction,Dict{Symbol,Any}}(), + Dict{GenerativeFunction,Dict{Symbol,Accumulator}}()) +end + +function get_local_parameters(store::JuliaParameterStore, gen_fn) + if !haskey(store.values, gen_fn) + return Dict{Symbol,Any}() + else + return store.values[gen_fn] + end +end + +# for looking up in a parameter context when tracing (simulate, generate) +# once a trace is generated, it is bound to use a particular store +""" + JULIA_PARAMETER_STORE_KEY + +If a parameter context contains a value for this key, then the value is a `JuliaParameterStore`. +""" +const JULIA_PARAMETER_STORE_KEY = :julia_parameter_store + +function get_julia_store(context::Dict) + return context[JULIA_PARAMETER_STORE_KEY]::JuliaParameterStore +end + +""" + default_julia_parameter_store::JuliaParameterStore + +The default global Julia parameter store. +""" +const default_julia_parameter_store = JuliaParameterStore() + + +""" + init_parameter!( + id::Tuple{GenerativeFunction,Symbol}, value, + store::JuliaParameterStore=default_julia_parameter_store) + +Initialize the the value of a named trainable parameter of a generative function. + +Also initializes the gradient accumulator for that parameter to `zero(value)`. + +Example: ```julia -update = ParamUpdate(GradientDescent(0.001, 100), foo) +init_parameter!((foo, :theta), 0.6) ``` + +Not thread-safe. """ -struct ParamUpdate - states::Dict{GenerativeFunction,Any} - conf::Any - function ParamUpdate(conf, param_lists...) - states = Dict{GenerativeFunction,Any}() - for (gen_fn, param_list) in param_lists - states[gen_fn] = init_update_state(conf, gen_fn, param_list) - end - new(states, conf) +function init_parameter!( + id::Tuple{GenerativeFunction,Symbol}, value, + store::JuliaParameterStore=default_julia_parameter_store) + (gen_fn, name) = id + if !haskey(store.values, gen_fn) + store.values[gen_fn] = Dict{Symbol,Any}() end - function ParamUpdate(conf, gen_fn::GenerativeFunction) - param_lists = Dict(gen_fn => collect(get_params(gen_fn))) - ParamUpdate(conf, param_lists...) + store.values[gen_fn][name] = value + if !haskey(store.gradient_accumulators, gen_fn) + store.gradient_accumulators[gen_fn] = Dict{Symbol,Any}() end + store.gradient_accumulators[gen_fn][name] = Accumulator(zero(value)) + return nothing end +""" + reset_gradient!( + id::Tuple{GenerativeFunction,Symbol}, + store::JuliaParameterStore=default_julia_parameter_store) +Reset the gradient accumulator for a trainable parameter. + +Not thread-safe. """ - apply!(update::ParamUpdate) +function reset_gradient!( + id::Tuple{GenerativeFunction,Symbol}, + store::JuliaParameterStore=default_julia_parameter_store) + (gen_fn, name) = id + local value::Any + try + value = store.values[gen_fn][name] + catch KeyError + @error "parameter $name of $gen_fn was not initialized" + rethrow() + end + if !haskey(store.gradient_accumulators, gen_fn) + store.gradient_accumulators[gen_fn] = Dict{Symbol,Any}() + end + if haskey(store.gradient_accumulators[gen_fn], name) + fill_with_zeros!(store.gradient_accumulators[gen_fn][name]) + else + store.gradient_accumulators[gen_fn][name] = Accumulator(zero(value)) + end + return nothing +end -Perform one step of the update. 
""" -function apply!(update::ParamUpdate) - for (_, state) in update.states - apply_update!(state) + increment_gradient!( + id::Tuple{GenerativeFunction,Symbol}, increment, scale_factor::Real, + store::JuliaParameterStore=default_julia_parameter_store) + +Increment the gradient accumulator for a parameter. + +The increment is scaled by the given scale_factor. + +Thread-safe (multiple threads can increment the gradient of the same parameter concurrently). +""" +function increment_gradient!( + id::Tuple{GenerativeFunction,Symbol}, increment, scale_factor, + store::JuliaParameterStore=default_julia_parameter_store) + (gen_fn, name) = id + try + in_place_add!(store.gradient_accumulators[gen_fn][name], increment, scale_factor) + catch KeyError + @error "parameter $name of $gen_fn was not initialized" + rethrow() end - nothing + return nothing end """ - conf = FixedStepGradientDescent(step_size) + increment_gradient!( + id::Tuple{GenerativeFunction,Symbol}, increment, + store::JuliaParameterStore=default_julia_parameter_store) -Configuration for stochastic gradient descent update with fixed step size. +Increment the gradient accumulator for a parameter. + +Thread-safe (multiple threads can increment the gradient of the same parameter concurrently). """ -struct FixedStepGradientDescent - step_size::Float64 +function increment_gradient!( + id::Tuple{GenerativeFunction,Symbol}, increment, + store::JuliaParameterStore=default_julia_parameter_store) + accumulator = get_gradient_accumulator(id, store) + in_place_add!(accumulator, increment) + return nothing end + """ - conf = GradientDescent(step_size_init, step_size_beta) + accum::Accumulator = get_gradient_accumulator!( + id::Tuple{GenerativeFunction,Symbol}, + store::JuliaParameterStore=default_julia_parameter_store) -Configuration for stochastic gradient descent update with step size given by `(t::Int) -> step_size_init * (step_size_beta + 1) / (step_size_beta + t)` where `t` is the iteration number. +Return the gradient accumulator for a parameter. + +Not thread-safe. """ -struct GradientDescent - step_size_init::Float64 - step_size_beta::Float64 +function get_gradient_accumulator( + id::Tuple{GenerativeFunction,Symbol}, + store::JuliaParameterStore=default_julia_parameter_store) + (gen_fn, name) = id + try + return store.gradient_accumulators[gen_fn][name] + catch KeyError + @error "parameter $name of $gen_fn was not initialized" + rethrow() + end +end + +""" + value::Union{Real,Array} = get_parameter_value( + id::Tuple{GenerativeFunction,Symbol}, + store::JuliaParameterStore=default_julia_parameter_store) + +Get the current value of a parameter. + +Not thread-safe. +""" +function get_parameter_value( + id::Tuple{GenerativeFunction,Symbol}, + store::JuliaParameterStore=default_julia_parameter_store) + (gen_fn, name) = id + try + return store.values[gen_fn][name] + catch KeyError + @error "parameter $name of $gen_fn was not initialized" + rethrow() + end +end + +""" + set_parameter_value!( + id::Tuple{GenerativeFunction,Symbol}, value::Union{Real,Array}, + store::JuliaParameterStore=default_julia_parameter_store) + +Set the value of a parameter. + +Not thread-safe. 
+""" +function set_parameter_value!( + id::Tuple{GenerativeFunction,Symbol}, value::Union{Real,Array}, + store::JuliaParameterStore=default_julia_parameter_store) + (gen_fn, name) = id + try + store.values[gen_fn][name] = value + catch KeyError + @error "parameter $name of $gen_fn was not initialized" + rethrow() + end + return nothing end """ - conf = ADAM(learning_rate, beta1, beta2, epsilon) + gradient::Union{Real,Array} = get_gradient( + id::Tuple{GenerativeFunction,Symbol}, + store::JuliaParameterStore=default_julia_parameter_store) + +Get the current value of the gradient accumulator for a parameter. -Configuration for ADAM update. +Not thread-safe. """ -struct ADAM - learning_rate::Float64 - beta1::Float64 - beta2::Float64 - epsilon::Float64 +function get_gradient( + id::Tuple{GenerativeFunction,Symbol}, + store::JuliaParameterStore=default_julia_parameter_store) + (gen_fn, name) = id + try + return get_value(store.gradient_accumulators[gen_fn][name]) + catch KeyError + @error "parameter $name of $gen_fn was not initialized" + rethrow() + end +end + +##################################################### +# Optimizer implementations for JuliaParameterStore # +##################################################### + +mutable struct FixedStepGradientDescentJulia + conf::FixedStepGradientDescent + store::JuliaParameterStore + parameters::Vector +end + +function init_optimizer( + conf::FixedStepGradientDescent, + parameters::Vector, + store::JuliaParameterStore=default_julia_parameter_store) + return FixedStepGradientDescentJulia(conf, store, parameters) +end + +function apply_update!(opt::FixedStepGradientDescentJulia) + for parameter_id::Tuple{GenerativeFunction,Symbol} in opt.parameters + value = get_parameter_value(parameter_id, opt.store) + gradient = get_gradient(parameter_id, opt.store) + new_value = in_place_add!(value, gradient * opt.conf.step_size) + set_parameter_value!(parameter_id, new_value, opt.store) + reset_gradient!(parameter_id, opt.store) + end +end + +mutable struct DecayStepGradientDescentJulia + conf::DecayStepGradientDescent + store::JuliaParameterStore + parameters::Vector + t::Int +end + +function init_optimizer( + conf::DecayStepGradientDescent, + parameters::Vector, + store::JuliaParameterStore=default_julia_parameter_store) + return DecayStepGradientDescentJulia(conf, store, parameters, 1) +end + +function apply_update!(opt::DecayStepGradientDescentJulia) + step_size_init = opt.conf.step_size_init + step_size_beta = opt.conf.step_size_beta + step_size = step_size_init * (step_size_beta + 1) / (step_size_beta + opt.t) + for parameter_id::Tuple{GenerativeFunction,Symbol} in opt.parameters + value = get_parameter_value(parameter_id, opt.store) + gradient = get_gradient(parameter_id, opt.store) + new_value = in_place_add!(value, gradient * step_size) + set_parameter_value!(parameter_id, new_value, opt.store) + reset_gradient!(parameter_id, opt.store) + end + opt.t += 1 +end + +# TODO implement other optimizers (ADAM, etc.) 
+ +############################# +# default parameter context # +############################# + + +""" + default_parameter_context::Dict + +The default global parameter context, which is initialized to contain the mapping: + + JULIA_PARAMETER_STORE_KEY => Gen.default_julia_parameter_store +""" +const default_parameter_context = Dict{Symbol,Any}( + JULIA_PARAMETER_STORE_KEY => default_julia_parameter_store) + +function get_parameters(gen_fn::GenerativeFunction) + return get_parameters(gen_fn, default_parameter_context) +end + +function simulate(gen_fn::GenerativeFunction, args::Tuple) + return simulate(gen_fn, args, default_parameter_context) +end + +function generate(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) + return generate(gen_fn, args, choices, default_parameter_context) +end + +function generate(gen_fn::GenerativeFunction, args::Tuple) + return generate(gen_fn, args, EmptyChoiceMap(), default_parameter_context) +end + +function assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap) + return assess(gen_fn, args, choices, default_parameter_context) end -export ParameterSet, ParamUpdate, apply! -export FixedStepGradientDescent, GradientDescent, ADAM +propose(gen_fn::GenerativeFunction, args::Tuple) = propose(gen_fn, args, default_parameter_context) diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index 96f63da0c..b42613a9e 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -1,3 +1,5 @@ +# TODO this code needs to be simplified + struct BackpropTraceMode end struct BackpropParamsMode end @@ -19,40 +21,43 @@ maybe_tracked_value_var(node::JuliaNode) = Symbol("$(maybe_tracked_value_prefix) const maybe_tracked_arg_prefix = gensym("maybe_tracked_arg") maybe_tracked_arg_var(node::JuliaNode, i::Int) = Symbol("$(maybe_tracked_arg_prefix)_$(node.name)_$i") -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::TrainableParameterNode) - # TODO: only need to mark it if we are doing backprop params - push!(fwd_marked, node) +function forward_marking!( + selected_choices, selected_calls, fwd_marked, + node::TrainableParameterNode, mode) + if mode == BackpropParamsMode() + push!(fwd_marked, node) + end end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::ArgumentNode) +function forward_marking!(selected_choices, selected_calls, fwd_marked, node::ArgumentNode, mode) if node.compute_grad push!(fwd_marked, node) end end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::JuliaNode) +function forward_marking!(selected_choices, selected_calls, fwd_marked, node::JuliaNode, mode) if any(input_node in fwd_marked for input_node in node.inputs) push!(fwd_marked, node) end end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::RandomChoiceNode) +function forward_marking!(selected_choices, selected_calls, fwd_marked, node::RandomChoiceNode, mode) if node in selected_choices push!(fwd_marked, node) end end -function fwd_pass!(selected_choices, selected_calls, fwd_marked, node::GenerativeFunctionCallNode) +function forward_marking!(selected_choices, selected_calls, fwd_marked, node::GenerativeFunctionCallNode, mode) if node in selected_calls || any(input_node in fwd_marked for input_node in node.inputs) push!(fwd_marked, node) end end -function back_pass!(back_marked, node::TrainableParameterNode) end +function backward_marking!(back_marked, node::TrainableParameterNode) end -function back_pass!(back_marked, node::ArgumentNode) end +function 
backward_marking!(back_marked, node::ArgumentNode) end -function back_pass!(back_marked, node::JuliaNode) +function backward_marking!(back_marked, node::JuliaNode) if node in back_marked for input_node in node.inputs push!(back_marked, input_node) @@ -60,7 +65,7 @@ function back_pass!(back_marked, node::JuliaNode) end end -function back_pass!(back_marked, node::RandomChoiceNode) +function backward_marking!(back_marked, node::RandomChoiceNode) # the logpdf of every random choice is a SINK for input_node in node.inputs push!(back_marked, input_node) @@ -69,27 +74,33 @@ function back_pass!(back_marked, node::RandomChoiceNode) push!(back_marked, node) end -function back_pass!(back_marked, node::GenerativeFunctionCallNode) +function backward_marking!(back_marked, node::GenerativeFunctionCallNode) # the logpdf of every generative function call is a SINK - for input_node in node.inputs - push!(back_marked, input_node) + # (we could ask whether the generative function is deterministic or not + # as a perforance optimization, because only stochsatic generative functions + # actually have a non-trivial logpdf) + for (input_node, has_grad) in zip(node.inputs, has_argument_grads(node.generative_function)) + if has_grad + push!(back_marked, input_node) + end end end -function fwd_codegen!(stmts, fwd_marked, back_marked, node::TrainableParameterNode) +function forward_codegen!(stmts, fwd_marked, back_marked, node::TrainableParameterNode) if node in back_marked - push!(stmts, :($(node.name) = $(QuoteNode(get_param))($(QuoteNode(get_gen_fn))(trace), - $(QuoteNode(node.name))))) + push!(stmts, :($(node.name) = $(QuoteNode(get_parameter_value))(trace, $(QuoteNode(node.name))))) end if node in fwd_marked && node in back_marked # initialize gradient to zero - push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) + # NOTE: we are avoiding allocating a new gradient accumulator for this function + # instead, we are using the threadsafe gradient accumulator directly.. + push!(stmts, :($(gradient_var(node)) = $(QuoteNode(get_gradient_accumulator))(trace, $(QuoteNode(node.name))))) end end -function fwd_codegen!(stmts, fwd_marked, back_marked, node::ArgumentNode) +function forward_codegen!(stmts, fwd_marked, back_marked, node::ArgumentNode) if node in fwd_marked && node in back_marked # initialize gradient to zero @@ -97,9 +108,9 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::ArgumentNode) end end -function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode) +function forward_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode) - if node in back_marked && any(input_node in fwd_marked for input_node in node.inputs) + if (node in fwd_marked) && (node in back_marked) # tracked forward execution tape = tape_var(node) @@ -121,27 +132,20 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::JuliaNode) # initialize gradient to zero push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) - else - # regular forward execution. + elseif node in back_marked - # we need the value for initializing gradient to zero (to get the type - # and e.g. shape), and for reference by other nodes during - # back_codegen! we could be more selective about which JuliaNodes need - # to be evalutaed, that is a performance optimization for the future + # regular forward execution. 
args = map((input_node) -> input_node.name, node.inputs) push!(stmts, :($(node.name) = $(QuoteNode(node.fn))($(args...)))) end end -function fwd_codegen!(stmts, fwd_marked, back_marked, node::RandomChoiceNode) - # for reference by other nodes during back_codegen! - # could performance optimize this away - push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) - +function forward_codegen!(stmts, fwd_marked, back_marked, node::RandomChoiceNode) # every random choice is in back_marked, since it affects it logpdf, but # also possibly due to other downstream usage of the value @assert node in back_marked + push!(stmts, :($(node.name) = trace.$(get_value_fieldname(node)))) if node in fwd_marked # the only way we are fwd_marked is if this choice was selected @@ -152,52 +156,51 @@ function fwd_codegen!(stmts, fwd_marked, back_marked, node::RandomChoiceNode) end end -function fwd_codegen!(stmts, fwd_marked, back_marked, node::GenerativeFunctionCallNode) - # for reference by other nodes during back_codegen! - # could performance optimize this away - subtrace_fieldname = get_subtrace_fieldname(node) - push!(stmts, :($(node.name) = get_retval(trace.$subtrace_fieldname))) +function forward_codegen!(stmts, fwd_marked, back_marked, node::GenerativeFunctionCallNode) + + if node in back_marked + # for reference by other nodes during backward_codegen! + subtrace_fieldname = get_subtrace_fieldname(node) + push!(stmts, :($(node.name) = $(QuoteNode(get_retval))(trace.$subtrace_fieldname))) + end # NOTE: we will still potentially run choice_gradients recursively on the generative function, # we just might not use its return value gradient. - if node in fwd_marked && node in back_marked + if (node in fwd_marked) && (node in back_marked) # we are fwd_marked if an input was fwd_marked, or if we were selected internally push!(stmts, :($(gradient_var(node)) = zero($(node.name)))) end end -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::TrainableParameterNode, mode) +function backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::TrainableParameterNode, mode) + if mode == BackpropParamsMode() - # handle case when it is the return node - if node === ir.return_node && node in fwd_marked - @assert node in back_marked - push!(stmts, :(isnothing(retval_grad) && error("Required return value gradient but got nothing"))) - push!(stmts, :($(gradient_var(node)) += retval_grad)) - end - - if node in fwd_marked && node in back_marked - cur_param_grad = :($(QuoteNode(get_param_grad))(trace.$static_ir_gen_fn_ref, - $(QuoteNode(node.name)))) - push!(stmts, :($(QuoteNode(set_param_grad!))(trace.$static_ir_gen_fn_ref, - $(QuoteNode(node.name)), - $cur_param_grad + $(gradient_var(node))))) + # handle case when it is the return node + if node === ir.return_node && node in fwd_marked + @assert node in back_marked + push!(stmts, :(isnothing(retval_grad) && error("Required return value gradient but got nothing"))) + push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(node)), retval_grad, scale_factor))) + end end end -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::ArgumentNode, mode) +function backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::ArgumentNode, mode) # handle case when it is the return node if node === ir.return_node && node in fwd_marked @assert node in back_marked - push!(stmts, :($(gradient_var(node)) += retval_grad)) + push!(stmts, :($(gradient_var(node)) = 
$(QuoteNode(in_place_add!))( + $(gradient_var(node)), retval_grad))) end end -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::JuliaNode, mode) +function backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::JuliaNode, mode) # handle case when it is the return node if node === ir.return_node && node in fwd_marked @assert node in back_marked - push!(stmts, :($(gradient_var(node)) += retval_grad)) + push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(node)), retval_grad))) end if node in back_marked && any(input_node in fwd_marked for input_node in node.inputs) @@ -209,19 +212,29 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node: for (i, input_node) in enumerate(node.inputs) if input_node in fwd_marked arg_maybe_tracked = maybe_tracked_arg_var(node, i) - push!(stmts, :($(gradient_var(input_node)) += $(QuoteNode(deriv))($arg_maybe_tracked))) + if isa(input_node, TrainableParameterNode) + @assert mode == BackpropParamsMode() + push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(input_node)), $(QuoteNode(deriv))($arg_maybe_tracked), scale_factor))) + else + push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(input_node)), $(QuoteNode(deriv))($arg_maybe_tracked)))) + end end end end end -function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, - node::RandomChoiceNode, logpdf_grad::Symbol) +function backward_codegen_random_choice_to_inputs!( + stmts, ir, fwd_marked, back_marked, + node::RandomChoiceNode, logpdf_grad::Symbol, + mode) + # only evaluate the gradient of the logpdf if we need to if any(input_node in fwd_marked for input_node in node.inputs) || node in fwd_marked args = map((input_node) -> input_node.name, node.inputs) - push!(stmts, :($logpdf_grad = logpdf_grad($(node.dist), $(node.name), $(args...)))) + push!(stmts, :($logpdf_grad = $(QuoteNode(Gen.logpdf_grad))($(node.dist), $(node.name), $(args...)))) end # increment gradients of input nodes that are in fwd_marked @@ -231,46 +244,57 @@ function back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marke if !has_argument_grads(node.dist)[i] error("Distribution $(node.dist) does not have logpdf gradient for argument $i") end - push!(stmts, :($(gradient_var(input_node)) += $logpdf_grad[$(QuoteNode(i+1))])) + input_node_grad = gradient_var(input_node) + increment = :($logpdf_grad[$(QuoteNode(i+1))]) + if isa(input_node, TrainableParameterNode) && mode == BackpropParamsMode() + push!(stmts, :($input_node_grad = $(QuoteNode(in_place_add!))( + $input_node_grad, $increment, scale_factor))) + else + push!(stmts, :($input_node_grad = $(QuoteNode(in_place_add!))( + $input_node_grad, $increment))) + end end end # handle case when it is the return node if node === ir.return_node && node in fwd_marked @assert node in back_marked - push!(stmts, :($(gradient_var(node)) += retval_grad)) + push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(node)), retval_grad))) end end -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, - node::RandomChoiceNode, ::BackpropTraceMode) +function backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, + node::RandomChoiceNode, mode::BackpropTraceMode) logpdf_grad = gensym("logpdf_grad") # backpropagate to the inputs - back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) + 
backward_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad, mode) # backpropagate to the value (if it was selected) if node in fwd_marked if !has_output_grad(node.dist) error("Distribution $dist does not logpdf gradient for its output value") end - push!(stmts, :($(gradient_var(node)) += $logpdf_grad[1])) + push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(node)), $logpdf_grad[1]))) end end -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, - node::RandomChoiceNode, ::BackpropParamsMode) +function backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, + node::RandomChoiceNode, mode::BackpropParamsMode) logpdf_grad = gensym("logpdf_grad") - back_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad) + backward_codegen_random_choice_to_inputs!(stmts, ir, fwd_marked, back_marked, node, logpdf_grad, mode) end -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, +function backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::GenerativeFunctionCallNode, mode::BackpropTraceMode) # handle case when it is the return node if node === ir.return_node && node in fwd_marked @assert node in back_marked - push!(stmts, :($(gradient_var(node)) += retval_grad)) + push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(node)), retval_grad))) end if node in fwd_marked @@ -285,7 +309,7 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, push!(stmts, :($call_selection = EmptySelection())) end retval_grad = node in back_marked ? gradient_var(node) : :(nothing) - push!(stmts, :(($input_grads, $value_trie, $gradient_trie) = choice_gradients( + push!(stmts, :(($input_grads, $value_trie, $gradient_trie) = $(QuoteNode(choice_gradients))( trace.$subtrace_fieldname, $call_selection, $retval_grad))) end @@ -293,34 +317,46 @@ function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, for (i, input_node) in enumerate(node.inputs) if input_node in fwd_marked @assert input_node in back_marked # this ensured its gradient will have been initialized - push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) + input_node_grad = gradient_var(input_node) + increment = :($input_grads[$(QuoteNode(i))]) + push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))( + $input_node_grad, $increment))) end end # NOTE: the value_trie and gradient_trie are dealt with later end -function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, +function backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::GenerativeFunctionCallNode, mode::BackpropParamsMode) # handle case when it is the return node if node === ir.return_node && node in fwd_marked @assert node in back_marked - push!(stmts, :($(gradient_var(node)) += retval_grad)) + push!(stmts, :($(gradient_var(node)) = $(QuoteNode(in_place_add!))( + $(gradient_var(node)), retval_grad))) end if node in fwd_marked input_grads = gensym("call_input_grads") subtrace_fieldname = get_subtrace_fieldname(node) retval_grad = node in back_marked ? 
gradient_var(node) : :(nothing) - push!(stmts, :($input_grads = accumulate_param_gradients!(trace.$subtrace_fieldname, $retval_grad))) + push!(stmts, :($input_grads = $(QuoteNode(accumulate_param_gradients!))(trace.$subtrace_fieldname, $retval_grad, scale_factor))) end # increment gradients of input nodes that are in fwd_marked - for (i, input_node) in enumerate(node.inputs) - if input_node in fwd_marked + for (i, (input_node, has_grad)) in enumerate(zip(node.inputs, has_argument_grads(node.generative_function))) + if input_node in fwd_marked && has_grad @assert input_node in back_marked # this ensured its gradient will have been initialized - push!(stmts, :($(gradient_var(input_node)) += $input_grads[$(QuoteNode(i))])) + input_node_grad = gradient_var(input_node) + increment = :($input_grads[$(QuoteNode(i))]) + if isa(input_node, TrainableParameterNode) + push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))( + $input_node_grad, $increment, scale_factor))) + else + push!(stmts, :($(gradient_var(input_node)) = $(QuoteNode(in_place_add!))( + $input_node_grad, $increment))) + end end end end @@ -364,15 +400,15 @@ function get_selected_choices(schema::StaticAddressSchema, ir::StaticIR) push!(selected_choices, node) end end - selected_choices + return selected_choices end function get_selected_calls(::EmptyAddressSchema, ::StaticIR) - Set{GenerativeFunctionCallNode}() + return Set{GenerativeFunctionCallNode}() end function get_selected_calls(::AllAddressSchema, ir::StaticIR) - Set{GenerativeFunctionCallNode}(ir.call_nodes) + return Set{GenerativeFunctionCallNode}(ir.call_nodes) end function get_selected_calls(schema::StaticAddressSchema, ir::StaticIR) @@ -383,7 +419,7 @@ function get_selected_calls(schema::StaticAddressSchema, ir::StaticIR) push!(selected_calls, node) end end - selected_calls + return selected_calls end function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, @@ -403,30 +439,30 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, # forward marking pass fwd_marked = Set{StaticIRNode}() for node in ir.nodes - fwd_pass!(selected_choices, selected_calls, fwd_marked, node) + forward_marking!(selected_choices, selected_calls, fwd_marked, node, BackpropTraceMode()) end # backward marking pass back_marked = Set{StaticIRNode}() push!(back_marked, ir.return_node) for node in reverse(ir.nodes) - back_pass!(back_marked, node) + backward_marking!(back_marked, node) end stmts = Expr[] # unpack arguments from the trace arg_names = Symbol[arg_node.name for arg_node in ir.arg_nodes] - push!(stmts, :($(Expr(:tuple, arg_names...)) = get_args(trace))) + push!(stmts, :($(Expr(:tuple, arg_names...)) = $(QuoteNode(get_args))(trace))) # forward code-generation pass (initialize gradients to zero, create needed references) for node in ir.nodes - fwd_codegen!(stmts, fwd_marked, back_marked, node) + forward_codegen!(stmts, fwd_marked, back_marked, node) end # backward code-generation pass (increment gradients) for node in reverse(ir.nodes) - back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node, BackpropTraceMode()) + backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node, BackpropTraceMode()) end # assemble value_trie and gradient_trie @@ -442,11 +478,11 @@ function codegen_choice_gradients(trace_type::Type{T}, selection_type::Type, # return values push!(stmts, :(return ($input_grads, $value_trie, $gradient_trie))) - Expr(:block, stmts...) + return Expr(:block, stmts...) 
end function codegen_accumulate_param_gradients!(trace_type::Type{T}, - retval_grad_type::Type) where {T<:StaticIRTrace} + retval_grad_type::Type, scale_factor_type) where {T<:StaticIRTrace} gen_fn_type = get_gen_fn_type(trace_type) ir = get_ir(gen_fn_type) @@ -458,33 +494,36 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T}, selected_calls = Set{GenerativeFunctionCallNode}( node for node in ir.nodes if isa(node, GenerativeFunctionCallNode)) - # forward marking pass + # forward marking pass (propagate forward from 'sources') fwd_marked = Set{StaticIRNode}() for node in ir.nodes - fwd_pass!(selected_choices, selected_calls, fwd_marked, node) + forward_marking!(selected_choices, selected_calls, fwd_marked, node, BackpropParamsMode()) end - # backward marking pass + # backward marking pass (propagate backwards from 'sinks') back_marked = Set{StaticIRNode}() push!(back_marked, ir.return_node) for node in reverse(ir.nodes) - back_pass!(back_marked, node) + backward_marking!(back_marked, node) end stmts = Expr[] # unpack arguments from the trace arg_names = Symbol[arg_node.name for arg_node in ir.arg_nodes] - push!(stmts, :($(Expr(:tuple, arg_names...)) = get_args(trace))) + push!(stmts, :($(Expr(:tuple, arg_names...)) = $(QuoteNode(get_args))(trace))) - # forward code-generation pass (initialize gradients to zero, create needed references) + # forward code-generation pass + # any node that is backward-marked creates a variable for its current value + # any node that is forward-marked and backwards marked initializes a gradient variable for node in ir.nodes - fwd_codegen!(stmts, fwd_marked, back_marked, node) + forward_codegen!(stmts, fwd_marked, back_marked, node) end - # backward code-generation pass (increment gradients) + # backward code-generation pass + # any node that is forward-marked and backwards marked increments its gradient variable for node in reverse(ir.nodes) - back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node, BackpropParamsMode()) + backward_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node, BackpropParamsMode()) end # gradients with respect to inputs @@ -494,19 +533,21 @@ function codegen_accumulate_param_gradients!(trace_type::Type{T}, # return values push!(stmts, :(return $input_grads)) - Expr(:block, stmts...) + return Expr(:block, stmts...) 
end push!(generated_functions, quote -@generated function $(GlobalRef(Gen, :choice_gradients))(trace::T, selection::$(QuoteNode(Selection)), - retval_grad) where {T<:$(QuoteNode(StaticIRTrace))} - $(QuoteNode(codegen_choice_gradients))(trace, selection, retval_grad) +@generated function $(GlobalRef(Gen, :choice_gradients))( + trace::T, selection::$(QuoteNode(Selection)), + retval_grad) where {T<:$(QuoteNode(StaticIRTrace))} + return $(QuoteNode(codegen_choice_gradients))(trace, selection, retval_grad) end end) push!(generated_functions, quote -@generated function $(GlobalRef(Gen, :accumulate_param_gradients!))(trace::T, retval_grad) where {T<:$(QuoteNode(StaticIRTrace))} - $(QuoteNode(codegen_accumulate_param_gradients!))(trace, retval_grad) +@generated function $(GlobalRef(Gen, :accumulate_param_gradients!))( + trace::T, retval_grad, scale_factor) where {T<:$(QuoteNode(StaticIRTrace))} + return $(QuoteNode(codegen_accumulate_param_gradients!))(trace, retval_grad, scale_factor) end end) diff --git a/src/static_ir/dag.jl b/src/static_ir/dag.jl index c82658892..65c1ab7bc 100644 --- a/src/static_ir/dag.jl +++ b/src/static_ir/dag.jl @@ -204,6 +204,25 @@ function set_accepts_output_grad!(builder::StaticIRBuilder, value::Bool) builder.accepts_output_grad = value end +function get_parameters(ir::StaticIR, gen_fn::GenerativeFunction, parameter_context) + julia_store = get_julia_store(parameter_context) + parameters = Dict(julia_store => Set{Tuple{GenerativeFunction,Symbol}}()) + for param_node in ir.trainable_param_nodes + push!(parameters[julia_store], (gen_fn, param_node.name)) + end + for call_node in ir.call_nodes + callee_parameters = get_parameters(call_node.generative_function, parameter_context) + for (store, ids::Set) in callee_parameters + if haskey(parameters, store) + union!(parameters[store], ids) + else + parameters[store] = ids + end + end + end + return parameters +end + export StaticIR, StaticIRBuilder, build_ir export add_trainable_param_node! export add_argument_node! 
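The `get_parameters` method added to `dag.jl` above returns a map from parameter store to the set of `(gen_fn, name)` parameter IDs reachable from a generative function, merging IDs collected from callees. A minimal sketch of how that map is produced and consumed, using illustrative dynamic-DSL functions `inner` and `outer` (the same pattern is exercised by the `test/dsl/dynamic_dsl.jl` changes later in this diff):

```julia
using Gen

@gen function inner(x)
    @param theta1::Float64
    return @trace(normal(x + theta1, 1.0), :z)
end
register_parameters!(inner, [:theta1])

@gen function outer(x)
    @param theta2::Float64
    mu = @trace(inner(x), :inner)
    return @trace(normal(mu + theta2, 1.0), :y)
end
register_parameters!(outer, [(inner, :theta1), :theta2])

# One entry per parameter store in use; each value is a set of (gen_fn, name) IDs,
# including those of callees.
store_to_ids = Gen.get_parameters(outer, Gen.default_parameter_context)
ids = store_to_ids[Gen.default_julia_parameter_store]
@assert (outer, :theta2) in ids
@assert (inner, :theta1) in ids

# The same map is what the gen_fn convenience method of init_optimizer consumes.
optimizer = init_optimizer(FixedStepGradientDescent(0.001), outer)
```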
diff --git a/src/static_ir/generate.jl b/src/static_ir/generate.jl index 3557f19b1..808a7e722 100644 --- a/src/static_ir/generate.jl +++ b/src/static_ir/generate.jl @@ -6,7 +6,9 @@ end function process!(::StaticIRGenerateState, node, options) end function process!(state::StaticIRGenerateState, node::TrainableParameterNode, options) - push!(state.stmts, :($(node.name) = $(QuoteNode(get_param))(gen_fn, $(QuoteNode(node.name))))) + push!(state.stmts, :($(node.name) = $(QuoteNode(get_parameter_value))( + (gen_fn, $(QuoteNode(node.name))), + $parameter_store_fieldname))) end function process!(state::StaticIRGenerateState, node::ArgumentNode, options) @@ -84,6 +86,7 @@ function codegen_generate(gen_fn_type::Type{T}, args, push!(stmts, :($total_noise_fieldname = 0.)) push!(stmts, :($weight = 0.)) push!(stmts, :($num_nonempty_fieldname = 0)) + push!(stmts, :($parameter_store_fieldname = $(QuoteNode(get_julia_store))(parameter_context))) # unpack arguments arg_names = Symbol[arg_node.name for arg_node in ir.arg_nodes] @@ -99,7 +102,7 @@ function codegen_generate(gen_fn_type::Type{T}, args, push!(stmts, :($return_value_fieldname = $(ir.return_node.name))) # construct trace - push!(stmts, :($static_ir_gen_fn_ref = gen_fn)) + push!(stmts, :($static_ir_gen_fn_fieldname = gen_fn)) push!(stmts, :($trace = $(QuoteNode(trace_type))($(fieldnames(trace_type)...)))) # return trace and weight @@ -109,8 +112,10 @@ function codegen_generate(gen_fn_type::Type{T}, args, end push!(generated_functions, quote -@generated function $(GlobalRef(Gen, :generate))(gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), - args::$(QuoteNode(Tuple)), constraints::$(QuoteNode(ChoiceMap))) +@generated function $(GlobalRef(Gen, :generate))( + gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), + args::$(QuoteNode(Tuple)), constraints::$(QuoteNode(ChoiceMap)), + parameter_context::Dict) $(QuoteNode(codegen_generate))(gen_fn, args, constraints) end end) diff --git a/src/static_ir/simulate.jl b/src/static_ir/simulate.jl index 48c215a3c..4b356653b 100644 --- a/src/static_ir/simulate.jl +++ b/src/static_ir/simulate.jl @@ -5,7 +5,9 @@ end function process!(::StaticIRSimulateState, node, options) end function process!(state::StaticIRSimulateState, node::TrainableParameterNode, options) - push!(state.stmts, :($(node.name) = $(QuoteNode(get_param))(gen_fn, $(QuoteNode(node.name))))) + push!(state.stmts, :($(node.name) = $(QuoteNode(get_parameter_value))( + (gen_fn, $(QuoteNode(node.name))), + $parameter_store_fieldname))) end function process!(state::StaticIRSimulateState, node::ArgumentNode, options) @@ -47,7 +49,9 @@ function process!(state::StaticIRSimulateState, node::GenerativeFunctionCallNode push!(state.stmts, :($total_noise_fieldname += $(GlobalRef(Gen, :project))($subtrace, $(GlobalRef(Gen, :EmptySelection))()))) end -function codegen_simulate(gen_fn_type::Type{T}, args) where {T <: StaticIRGenerativeFunction} +function codegen_simulate( + gen_fn_type::Type{T}, args, + parameter_context_type) where {T <: StaticIRGenerativeFunction} ir = get_ir(gen_fn_type) options = get_options(gen_fn_type) @@ -57,6 +61,7 @@ function codegen_simulate(gen_fn_type::Type{T}, args) where {T <: StaticIRGenera push!(stmts, :($total_score_fieldname = 0.)) push!(stmts, :($total_noise_fieldname = 0.)) push!(stmts, :($num_nonempty_fieldname = 0)) + push!(stmts, :($parameter_store_fieldname = $(QuoteNode(get_julia_store))(parameter_context))) # unpack arguments arg_names = Symbol[arg_node.name for arg_node in ir.arg_nodes] @@ -73,7 +78,7 @@ function 
codegen_simulate(gen_fn_type::Type{T}, args) where {T <: StaticIRGenera # construct trace trace_type = get_trace_type(gen_fn_type) - push!(stmts, :($static_ir_gen_fn_ref = gen_fn)) + push!(stmts, :($static_ir_gen_fn_fieldname = gen_fn)) push!(stmts, :($trace = $(QuoteNode(trace_type))($(fieldnames(trace_type)...)))) # return trace @@ -83,7 +88,9 @@ function codegen_simulate(gen_fn_type::Type{T}, args) where {T <: StaticIRGenera end push!(generated_functions, quote -@generated function $(GlobalRef(Gen, :simulate))(gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), args::Tuple) - $(QuoteNode(codegen_simulate))(gen_fn, args) +@generated function $(GlobalRef(Gen, :simulate))( + gen_fn::$(QuoteNode(StaticIRGenerativeFunction)), + args::Tuple, parameter_context::Dict) + $(QuoteNode(codegen_simulate))(gen_fn, args, parameter_context) end end) diff --git a/src/static_ir/static_ir.jl b/src/static_ir/static_ir.jl index d5d7e237f..3e72a777b 100644 --- a/src/static_ir/static_ir.jl +++ b/src/static_ir/static_ir.jl @@ -46,10 +46,10 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati has_argument_grads = tuple(map((node) -> node.compute_grad, ir.arg_nodes)...) accepts_output_grad = ir.accepts_output_grad + show_str = "Gen SML generative function: $name" + gen_fn_defn = quote struct $gen_fn_type_name <: $(QuoteNode(StaticIRGenerativeFunction)){$return_type,$trace_type} - params_grad::Dict{Symbol,Any} - params::Dict{Symbol,Any} end (gen_fn::$gen_fn_type_name)(args...) = $(GlobalRef(Gen, :propose))(gen_fn, args)[3] $(GlobalRef(Gen, :get_ir))(::$gen_fn_type_name) = $(QuoteNode(ir)) @@ -57,11 +57,18 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati $(GlobalRef(Gen, :get_trace_type))(::Type{$gen_fn_type_name}) = $trace_struct_name $(GlobalRef(Gen, :has_argument_grads))(::$gen_fn_type_name) = $(QuoteNode(has_argument_grads)) $(GlobalRef(Gen, :accepts_output_grad))(::$gen_fn_type_name) = $(QuoteNode(accepts_output_grad)) - $(GlobalRef(Gen, :get_gen_fn))(trace::$trace_struct_name) = $(Expr(:(.), :trace, QuoteNode(static_ir_gen_fn_ref))) + $(GlobalRef(Gen, :get_gen_fn))(trace::$trace_struct_name) = $(Expr(:(.), :trace, QuoteNode(static_ir_gen_fn_fieldname))) $(GlobalRef(Gen, :get_gen_fn_type))(::Type{$trace_struct_name}) = $gen_fn_type_name $(GlobalRef(Gen, :get_options))(::Type{$gen_fn_type_name}) = $(QuoteNode(options)) + function $(GlobalRef(Gen, :get_parameters))(gen_fn::$gen_fn_type_name, context) + return $(GlobalRef(Gen, :get_parameters))($(QuoteNode(ir)), gen_fn, context) + end + function Base.show(io::IO, ::MIME"text/plain", gen_fn::$gen_fn_type_name) + return $(QuoteNode(show_str)) + end + end - Expr(:block, trace_defns, gen_fn_defn, Expr(:call, gen_fn_type_name, :(Dict{Symbol,Any}()), :(Dict{Symbol,Any}()))) + Expr(:block, trace_defns, gen_fn_defn, Expr(:call, gen_fn_type_name)) end include("print_ir.jl") diff --git a/src/static_ir/trace.jl b/src/static_ir/trace.jl index de2c84b30..4e5fe0539 100644 --- a/src/static_ir/trace.jl +++ b/src/static_ir/trace.jl @@ -37,6 +37,33 @@ static_get_submap(::StaticIRTraceAssmt, ::Val) = EmptyChoiceMap() abstract type StaticIRTrace <: Trace end +# fixed fields shared by all StaticIRTraces +const num_nonempty_fieldname = :num_nonempty +const total_score_fieldname = :score +const total_noise_fieldname = :noise +const return_value_fieldname = :retval +const parameter_store_fieldname = :parameter_store +const static_ir_gen_fn_fieldname = :gen_fn + +# other fields based on user-defined variable names are 
prefixed to avoid collisions +get_value_fieldname(node::ArgumentNode) = Symbol("#arg#_$(node.name)") +get_value_fieldname(node::RandomChoiceNode) = Symbol("#choice_value#_$(node.addr)") +get_value_fieldname(node::JuliaNode) = Symbol("#julia#_$(node.name)") +get_score_fieldname(node::RandomChoiceNode) = Symbol("#choice_score#_$(node.addr)") +get_subtrace_fieldname(node::GenerativeFunctionCallNode) = Symbol("#subtrace#_$(node.addr)") + +# getters + +function get_parameter_value(trace::StaticIRTrace, name) + parameter_id = (get_gen_fn(trace), name) + return get_parameter_value(parameter_id, trace.parameter_store) +end + +function get_gradient_accumulator(trace::StaticIRTrace, name) + parameter_id = (get_gen_fn(trace), name) + return get_gradient_accumulator(parameter_id, trace.parameter_store) +end + @inline function static_get_subtrace(trace::StaticIRTrace, addr) error("Not implemented") end @@ -52,36 +79,8 @@ end return Gen.static_get_subtrace(trace, Val(first))[rest] end -const arg_prefix = gensym("arg") -const choice_value_prefix = gensym("choice_value") -const choice_score_prefix = gensym("choice_score") -const subtrace_prefix = gensym("subtrace") -const julia_prefix = gensym("julia_prefix") - -function get_value_fieldname(node::ArgumentNode) - Symbol("$(arg_prefix)_$(node.name)") -end - -function get_value_fieldname(node::RandomChoiceNode) - Symbol("$(choice_value_prefix)_$(node.addr)") -end -function get_value_fieldname(node::JuliaNode) - Symbol("$(julia_prefix)_$(node.name)") -end -function get_score_fieldname(node::RandomChoiceNode) - Symbol("$(choice_score_prefix)_$(node.addr)") -end - -function get_subtrace_fieldname(node::GenerativeFunctionCallNode) - Symbol("$(subtrace_prefix)_$(node.addr)") -end - -const num_nonempty_fieldname = gensym("num_nonempty") -const total_score_fieldname = gensym("score") -const total_noise_fieldname = gensym("noise") -const return_value_fieldname = gensym("retval") struct TraceField fieldname::Symbol @@ -115,17 +114,17 @@ function get_trace_fields(ir::StaticIR, options::StaticIRGenerativeFunctionOptio push!(fields, TraceField(total_noise_fieldname, QuoteNode(Float64))) push!(fields, TraceField(num_nonempty_fieldname, QuoteNode(Int))) push!(fields, TraceField(return_value_fieldname, ir.return_node.typ)) + push!(fields, TraceField(parameter_store_fieldname, QuoteNode(JuliaParameterStore))) + push!(fields, TraceField(static_ir_gen_fn_fieldname, QuoteNode(Any))) return fields end -const static_ir_gen_fn_ref = gensym("gen_fn") - function generate_trace_struct(ir::StaticIR, trace_struct_name::Symbol, options::StaticIRGenerativeFunctionOptions) mutable = false fields = get_trace_fields(ir, options) field_exprs = map((f) -> Expr(:(::), f.fieldname, f.typ), fields) Expr(:struct, mutable, Expr(:(<:), trace_struct_name, QuoteNode(StaticIRTrace)), - Expr(:block, field_exprs..., Expr(:(::), static_ir_gen_fn_ref, QuoteNode(Any)))) + Expr(:block, field_exprs...)) end function generate_isempty(trace_struct_name::Symbol) diff --git a/src/static_ir/update.jl b/src/static_ir/update.jl index 3a491c7a8..a47ece6a7 100644 --- a/src/static_ir/update.jl +++ b/src/static_ir/update.jl @@ -144,7 +144,7 @@ end function process_codegen!(stmts, ::ForwardPassState, back::BackwardPassState, node::TrainableParameterNode, ::AbstractUpdateMode, options) if node in back.marked - push!(stmts, :($(node.name) = $(QuoteNode(get_param))($(QuoteNode(get_gen_fn))(trace), $(QuoteNode(node.name))))) + push!(stmts, :($(node.name) = $(QuoteNode(get_parameter_value))(trace, $(QuoteNode(node.name))))) 
end end @@ -420,15 +420,16 @@ end function generate_new_trace!(stmts::Vector{Expr}, trace_type::Type, options) if options.track_diffs - # note that the generative function is the last field - constructor_args = map((name) -> Expr(:call, QuoteNode(strip_diff), name), - fieldnames(trace_type)[1:end-1]) - push!(stmts, :($trace = $(QuoteNode(trace_type))($(constructor_args...), - $(Expr(:(.), :trace, QuoteNode(static_ir_gen_fn_ref)))))) + # NOTE: applying 'strip_diff' to all of the fields, including fields that are not diffed, is a hack + constructor_args = map( + (name) -> Expr(:call, QuoteNode(strip_diff), name), + fieldnames(trace_type)) else - push!(stmts, :($static_ir_gen_fn_ref = $(Expr(:(.), :trace, QuoteNode(static_ir_gen_fn_ref))))) - push!(stmts, :($trace = $(QuoteNode(trace_type))($(fieldnames(trace_type)...)))) + constructor_args = fieldnames(trace_type) end + push!(stmts, :($static_ir_gen_fn_fieldname = $(Expr(:(.), :trace, QuoteNode(static_ir_gen_fn_fieldname))))) + push!(stmts, :($parameter_store_fieldname = $(Expr(:(.), :trace, QuoteNode(parameter_store_fieldname))))) + push!(stmts, :($trace = $(QuoteNode(trace_type))($(constructor_args...)))) end function generate_discard!(stmts::Vector{Expr}, diff --git a/test/dsl/dynamic_dsl.jl b/test/dsl/dynamic_dsl.jl index 25e3a50f4..f7458d6ee 100644 --- a/test/dsl/dynamic_dsl.jl +++ b/test/dsl/dynamic_dsl.jl @@ -285,6 +285,7 @@ end z = @trace(normal(mu_z + theta1, 1), :z) return z + mu_z end + register_parameters!(bar, [:theta1]) @gen (grad) function foo((grad)(mu_a::Float64)) @param theta2::Float64 @@ -294,9 +295,17 @@ end c = a * b * @trace(bar(a), :bar) return @trace(normal(c, 1), :out) + (theta2 * 3) end + register_parameters!(foo, [(bar, :theta1), :theta2]) - init_param!(bar, :theta1, 0.) - init_param!(foo, :theta2, 0.) + store_to_ids = Gen.get_parameters(foo, Gen.default_parameter_context) + @test length(store_to_ids) == 1 + ids = store_to_ids[Gen.default_julia_parameter_store] + @test length(ids) == 2 + @test (foo, :theta2) in ids + @test (bar, :theta1) in ids + + init_parameter!((bar, :theta1), 0.0) + init_parameter!((foo, :theta2), 0.0) function f(mu_a, a, b, z, out) lpdf = 0. @@ -358,8 +367,8 @@ end @test isapprox(mu_a_grad, finite_diff(f, (mu_a, a, b, z, out), 1, dx)) # check parameter gradient - theta1_grad = get_param_grad(bar, :theta1) - theta2_grad = get_param_grad(foo, :theta2) + theta1_grad = get_gradient((bar, :theta1)) + theta2_grad = get_gradient((foo, :theta2)) @test isapprox(theta1_grad, logpdf_grad(normal, z, a, 1)[2]) @test isapprox(theta2_grad, 3 * 2) @@ -371,17 +380,18 @@ end @param theta::Float64 return theta end - - init_param!(baz, :theta, 0.) + register_parameters!(baz, [:theta]) + init_parameter!((baz, :theta), 0.0) @gen (grad) function foo() return @trace(baz()) end + register_parameters!(foo, [(baz, :theta)]) (trace, _) = generate(foo, ()) retval_grad = 2. accumulate_param_gradients!(trace, retval_grad) - @test isapprox(get_param_grad(baz, :theta), retval_grad) + @test isapprox(get_gradient((baz, :theta)), retval_grad) end @testset "gradient descent with fixed step size" begin @@ -389,14 +399,18 @@ end @param theta::Float64 return theta end - init_param!(foo, :theta, 0.) + register_parameters!(foo, [:theta]) + + init_parameter!((foo, :theta), 0.0) + (trace, ) = generate(foo, ()) - accumulate_param_gradients!(trace, 1.) 
+ accumulate_param_gradients!(trace, 1.0) + @test isapprox(get_gradient((foo, :theta)), 1.0) conf = FixedStepGradientDescent(0.001) - state = Gen.init_update_state(conf, foo, [:theta]) - Gen.apply_update!(state) - @test isapprox(get_param(foo, :theta), 0.001) - @test isapprox(get_param_grad(foo, :theta), 0.) + optimizer = init_optimizer(conf, [(foo, :theta)]) + apply_update!(optimizer) + @test isapprox(get_parameter_value((foo, :theta)), 0.001) + @test isapprox(get_gradient((foo, :theta)), 0.0) end @testset "gradient descent with shrinking step size" begin @@ -404,14 +418,18 @@ end @param theta::Float64 return theta end - init_param!(foo, :theta, 0.) + register_parameters!(foo, [:theta]) + + init_parameter!((foo, :theta), 0.0) + (trace, ) = generate(foo, ()) accumulate_param_gradients!(trace, 1.) - conf = GradientDescent(0.001, 1000) - state = Gen.init_update_state(conf, foo, [:theta]) - Gen.apply_update!(state) - @test isapprox(get_param(foo, :theta), 0.001) - @test isapprox(get_param_grad(foo, :theta), 0.) + @test isapprox(get_gradient((foo, :theta)), 1.0) + conf = DecayStepGradientDescent(0.001, 1000) + optimizer = init_optimizer(conf, [(foo, :theta)]) + apply_update!(optimizer) + @test isapprox(get_parameter_value((foo, :theta)), 0.001) + @test isapprox(get_gradient((foo, :theta)), 0.0) end @testset "multi-component addresses" begin diff --git a/test/inference/importance_sampling.jl b/test/inference/importance_sampling.jl index 4a4842b21..c2e9ac413 100644 --- a/test/inference/importance_sampling.jl +++ b/test/inference/importance_sampling.jl @@ -15,33 +15,36 @@ n = 4 - (traces, log_weights, lml_est) = importance_sampling(model, (), observations, n) - @test length(traces) == n - @test length(log_weights) == n - @test isapprox(logsumexp(log_weights), 0., atol=1e-14) - @test !isnan(lml_est) - for trace in traces - @test get_choices(trace)[:y] == y + for multithreaded in [false, true] + (traces, log_weights, lml_est) = importance_sampling( + model, (), observations, n; multithreaded=multithreaded) + @test length(traces) == n + @test length(log_weights) == n + @test isapprox(logsumexp(log_weights), 0., atol=1e-14) + @test !isnan(lml_est) + for trace in traces + @test get_choices(trace)[:y] == y + end end - (traces, log_weights, lml_est) = importance_sampling(model, (), observations, proposal, (), n) - @test length(traces) == n - @test length(log_weights) == n - @test isapprox(logsumexp(log_weights), 0., atol=1e-14) - @test !isnan(lml_est) - for trace in traces - @test get_choices(trace)[:y] == y + for multithreaded in [false, true] + (traces, log_weights, lml_est) = importance_sampling( + model, (), observations, proposal, (), n; + multithreaded=multithreaded) + @test length(traces) == n + @test length(log_weights) == n + @test isapprox(logsumexp(log_weights), 0., atol=1e-14) + @test !isnan(lml_est) + for trace in traces + @test get_choices(trace)[:y] == y + end end (trace, lml_est) = importance_resampling(model, (), observations, n) - @test isapprox(logsumexp(log_weights), 0., atol=1e-14) @test !isnan(lml_est) @test get_choices(trace)[:y] == y (trace, lml_est) = importance_resampling(model, (), observations, proposal, (), n) - @test isapprox(logsumexp(log_weights), 0., atol=1e-14) @test !isnan(lml_est) @test get_choices(trace)[:y] == y end - - diff --git a/test/inference/train.jl b/test/inference/train.jl index 9e3aee7a7..4a09c217e 100644 --- a/test/inference/train.jl +++ b/test/inference/train.jl @@ -48,6 +48,7 @@ end @trace(bernoulli(prob_y), :y) end + register_parameters!(student, 
[:theta1, :theta2, :theta3, :theta4, :theta5]) function data_generator() (choices, _, retval) = propose(teacher, ()) @@ -64,11 +65,11 @@ # theta2 -> inf (prob_y -> 1) # theta3 -> -inf (prob_y -> 0) - init_param!(student, :theta1, 0.) - init_param!(student, :theta2, 0.) - init_param!(student, :theta3, 0.) - init_param!(student, :theta4, 0.) - init_param!(student, :theta5, 0.) + init_parameter!((student, :theta1), 0.0) + init_parameter!((student, :theta2), 0.0) + init_parameter!((student, :theta3), 0.0) + init_parameter!((student, :theta4), 0.0) + init_parameter!((student, :theta5), 0.0) # check gradients using finite differences on a simulated batch minibatch_size = 100 @@ -80,12 +81,13 @@ accumulate_param_gradients!(student_trace, nothing) end for name in [:theta1, :theta2, :theta3, :theta4, :theta5] - actual = get_param_grad(student, name) + actual = get_gradient((student, name)) dx = 1e-6 - value = get_param(student, name) + value = get_parameter_value((student, name)) # evaluate total log density at value + dx - set_param!(student, name, value + dx) + init_parameter!((student, name), value + dx) + lpdf_pos = 0. for i=1:minibatch_size (incr, _) = assess(student, inputs[i], constraints[i]) @@ -93,7 +95,7 @@ end # evaluate total log density at value - dx - set_param!(student, name, value - dx) + init_parameter!((student, name), value - dx) lpdf_neg = 0. for i=1:minibatch_size (incr, _) = assess(student, inputs[i], constraints[i]) @@ -103,23 +105,23 @@ expected = (lpdf_pos - lpdf_neg) / (2 * dx) @test isapprox(actual, expected, atol=1e-4) - set_param!(student, name, value) + init_parameter!((student, name), value) end # use stochastic gradient descent - update = ParamUpdate(GradientDescent(0.01, 1000000), student) - train!(student, data_generator, update, + optimizer = init_optimizer(DecayStepGradientDescent(0.01, 1000000), student) + train!(student, data_generator, optimizer, num_epoch=2000, epoch_size=50, num_minibatch=1, minibatch_size=50, verbose=false) # p(x | z=0) = p(x | z=1) = 0.5 - @test isapprox(get_param(student, :theta1), 0., atol=0.2) + @test isapprox(get_parameter_value((student, :theta1)), 0.0, atol=0.2) # y | z, x = xor(x, z) - @test get_param(student, :theta2) < -5 - @test get_param(student, :theta3) > 5 - @test get_param(student, :theta4) > 5 - @test get_param(student, :theta5) < -5 + @test get_parameter_value((student, :theta2)) < -5 + @test get_parameter_value((student, :theta3)) > 5 + @test get_parameter_value((student, :theta4)) > 5 + @test get_parameter_value((student, :theta5)) < -5 end @@ -141,15 +143,16 @@ end z = @trace(normal(x + theta, exp(log_std)), :z) return z end + register_parameters!(q, [:theta, :log_std]) # train simple q using lecture! to compute gradients - init_param!(q, :theta, 0.) - init_param!(q, :log_std, 0.) - update = ParamUpdate(FixedStepGradientDescent(1e-4), q) + init_parameter!((q, :theta), 0.0) + init_parameter!((q, :log_std), 0.0) + optimizer = init_optimizer(FixedStepGradientDescent(1e-4), q) score = Inf for iter=1:100 score = sum([lecture!(p, (), q, tr -> (tr[:x],)) for _=1:1000]) / 1000 - apply!(update) + apply_update!(optimizer) end score = sum([lecture!(p, (), q, tr -> (tr[:x],)) for _=1:10000]) / 10000 @test isapprox(score, -0.21, atol=5e-2) @@ -163,15 +166,16 @@ end @trace(normal(means[i], exp(log_std)), i => :z) end end + register_parameters!(q_batched, [:theta, :log_std]) # train simple q using lecture_batched! to compute gradients - init_param!(q_batched, :theta, 0.) - init_param!(q_batched, :log_std, 0.) 
- update = ParamUpdate(FixedStepGradientDescent(0.001), q_batched) + init_parameter!((q_batched, :theta), 0.0) + init_parameter!((q_batched, :log_std), 0.0) + optimizer = init_optimizer(FixedStepGradientDescent(0.001), q_batched) score = Inf for iter=1:100 score = lecture_batched!(p, (), q_batched, trs -> (map(tr -> tr[:x], trs),), 1000) - apply!(update) + apply_update!(optimizer) end score = sum([lecture!(p, (), q, tr -> (tr[:x],)) for _=1:10000]) / 10000 @test isapprox(score, -0.21, atol=5e-2) diff --git a/test/inference/variational.jl b/test/inference/variational.jl index babcfb481..f44070873 100644 --- a/test/inference/variational.jl +++ b/test/inference/variational.jl @@ -15,41 +15,48 @@ @trace(normal(slope_mu, exp(slope_log_std)), :slope) @trace(normal(intercept_mu, exp(intercept_log_std)), :intercept) end - - # to regular black box variational inference - init_param!(approx, :slope_mu, 0.) - init_param!(approx, :slope_log_std, 0.) - init_param!(approx, :intercept_mu, 0.) - init_param!(approx, :intercept_log_std, 0.) + register_parameters!(approx, [:slope_mu, :slope_log_std, :intercept_mu, :intercept_log_std]) observations = choicemap() - update = ParamUpdate(GradientDescent(1, 100000), approx) - update = ParamUpdate(GradientDescent(1., 1000), approx) - black_box_vi!(model, (), observations, approx, (), update; - iters=2000, samples_per_iter=100, verbose=false) - slope_mu = get_param(approx, :slope_mu) - slope_log_std = get_param(approx, :slope_log_std) - intercept_mu = get_param(approx, :intercept_mu) - intercept_log_std = get_param(approx, :intercept_log_std) - @test isapprox(slope_mu, -1., atol=0.001) - @test isapprox(slope_log_std, 0.5, atol=0.001) - @test isapprox(intercept_mu, 1., atol=0.001) - @test isapprox(intercept_log_std, 2.0, atol=0.001) + optimizer = init_optimizer(DecayStepGradientDescent(1., 1000), approx) + + # test regular black box variational inference + for multithreaded in [false, true] + init_parameter!((approx, :slope_mu), 0.0) + init_parameter!((approx, :slope_log_std), 0.0) + init_parameter!((approx, :intercept_mu), 0.0) + init_parameter!((approx, :intercept_log_std), 0.0) + black_box_vi!(model, (), observations, approx, (), optimizer; + iters=2000, samples_per_iter=100, verbose=false, multithreaded=multithreaded) + + slope_mu = get_parameter_value((approx, :slope_mu)) + slope_log_std = get_parameter_value((approx, :slope_log_std)) + intercept_mu = get_parameter_value((approx, :intercept_mu)) + intercept_log_std = get_parameter_value((approx, :intercept_log_std)) + @test isapprox(slope_mu, -1., atol=0.001) + @test isapprox(slope_log_std, 0.5, atol=0.001) + @test isapprox(intercept_mu, 1., atol=0.001) + @test isapprox(intercept_log_std, 2.0, atol=0.001) + end # smoke test for black box variational inference with Monte Carlo objectives - init_param!(approx, :slope_mu, 0.) - init_param!(approx, :slope_log_std, 0.) - init_param!(approx, :intercept_mu, 0.) - init_param!(approx, :intercept_log_std, 0.) - black_box_vimco!(model, (), observations, approx, (), update, 20; - iters=50, samples_per_iter=100, verbose=false, geometric=false) - - init_param!(approx, :slope_mu, 0.) - init_param!(approx, :slope_log_std, 0.) - init_param!(approx, :intercept_mu, 0.) - init_param!(approx, :intercept_log_std, 0.) 
- black_box_vimco!(model, (), observations, approx, (), update, 20; - iters=50, samples_per_iter=100, verbose=false, geometric=true) + for multithreaded in [false, true] + init_parameter!((approx, :slope_mu), 0.0) + init_parameter!((approx, :slope_log_std), 0.0) + init_parameter!((approx, :intercept_mu), 0.0) + init_parameter!((approx, :intercept_log_std), 0.0) + black_box_vimco!(model, (), observations, approx, (), optimizer, 20; + iters=50, samples_per_iter=100, verbose=false, geometric=false, + multithreaded=multithreaded) + + init_parameter!((approx, :slope_mu), 0.0) + init_parameter!((approx, :slope_log_std), 0.0) + init_parameter!((approx, :intercept_mu), 0.0) + init_parameter!((approx, :intercept_log_std), 0.0) + black_box_vimco!(model, (), observations, approx, (), optimizer, 20; + iters=50, samples_per_iter=100, verbose=false, geometric=true, + multithreaded=multithreaded) + end end @@ -66,6 +73,7 @@ end {(:x, i)} ~ normal(z, 1) end end + register_parameters!(model, [:theta]) @gen function approx(xs) @param mu_coeffs::Vector{Float64} # 2 x 1; should be [opt_theta / 2, 0.5] @@ -75,6 +83,7 @@ end {(:z, i)} ~ normal(mu, exp(log_std)) end end + register_parameters!(approx, [:mu_coeffs, :log_std]) observations = choicemap() xs = Float64[] @@ -97,7 +106,7 @@ end {(:z, i)} ~ normal(posterior_means[i], sqrt(1.0 / posterior_precisions)) end end - init_param!(model, :theta, opt_theta) + init_parameter!((model, :theta), opt_theta) approx_trace = simulate(optimum_approx, ()) (model_trace, _) = generate(model, (), merge(get_choices(approx_trace), observations)) # note that p(z1..zn, x1..xn) / p(z1..zn | x1..xn) = p(x1...xn) - for all z1..zn @@ -105,31 +114,31 @@ end println("true optimum log_marginal_likelihood: $log_marginal_likelihood") # using BBVI with score function estimator - init_param!(model, :theta, 0.0) - init_param!(approx, :mu_coeffs, zeros(2)) - init_param!(approx, :log_std, 0.0) - approx_update = ParamUpdate(FixedStepGradientDescent(0.0001), approx) - model_update = ParamUpdate(FixedStepGradientDescent(0.002), model) + init_parameter!((model, :theta), 0.0) + init_parameter!((approx, :mu_coeffs), zeros(2)) + init_parameter!((approx, :log_std), 0.0) + approx_optimizer = init_optimizer(FixedStepGradientDescent(0.0001), approx) + model_optimizer = init_optimizer(FixedStepGradientDescent(0.002), model) @time (_, _, elbo_history, _) = - black_box_vi!(model, (), model_update, observations, - approx, (xs,), approx_update; + black_box_vi!(model, (), model_optimizer, observations, + approx, (xs,), approx_optimizer; iters=3000, samples_per_iter=20, verbose=false) - @test isapprox(get_param(model, :theta), opt_theta, atol=1e-1) - println("final theta: $(get_param(model, :theta))") + @test isapprox(get_parameter_value((model, :theta)), opt_theta, atol=1e-1) + println("final theta: $(get_parameter_value((model, :theta)))") println("final elbo estimate: $(elbo_history[end])") @test isapprox(elbo_history[end], log_marginal_likelihood, rtol=0.1) # using VIMCO - init_param!(model, :theta, 0.0) - init_param!(approx, :mu_coeffs, zeros(2)) - init_param!(approx, :log_std, 0.0) - approx_update = ParamUpdate(FixedStepGradientDescent(0.001), approx) - model_update = ParamUpdate(FixedStepGradientDescent(0.01), model) + init_parameter!((model, :theta), 0.0) + init_parameter!((approx, :mu_coeffs), zeros(2)) + init_parameter!((approx, :log_std), 0.0) + approx_optimizer = init_optimizer(FixedStepGradientDescent(0.001), approx) + model_optimizer = init_optimizer(FixedStepGradientDescent(0.01), model) @time 
(_, _, elbo_history, _) = - black_box_vimco!(model, (), model_update, observations, - approx, (xs,), approx_update, 10; + black_box_vimco!(model, (), model_optimizer, observations, + approx, (xs,), approx_optimizer, 10; iters=1000, samples_per_iter=10, verbose=false) - println("final theta: $(get_param(model, :theta))") + println("final theta: $(get_parameter_value((model, :theta)))") println("final elbo estimate: $(elbo_history[end])") @test isapprox(elbo_history[end], log_marginal_likelihood, rtol=0.1) end diff --git a/test/modeling_library/map.jl b/test/modeling_library/map.jl index 3c1f820fe..6f4476b31 100644 --- a/test/modeling_library/map.jl +++ b/test/modeling_library/map.jl @@ -5,8 +5,9 @@ z = @trace(normal(x + y, std), :z) return z end + register_parameters!(foo, [:std]) - set_param!(foo, :std, 1.) + init_parameter!((foo, :std), 1.0) bar = Map(foo) xs = [1.0, 2.0, 3.0, 4.0] @@ -393,13 +394,13 @@ # get gradients wrt xs and ys trace = get_initial_trace() - zero_param_grad!(foo, :std) + reset_gradient!((foo, :std)) input_grads = accumulate_param_gradients!(trace, retval_grad) @test isapprox(input_grads[1], expected_xs_grad) @test isapprox(input_grads[2], expected_ys_grad) expected_std_grad = (logpdf_grad(normal, z1, 4., 1.)[3] + logpdf_grad(normal, z2, 6., 1.)[3]) - @test isapprox(get_param_grad(foo, :std), expected_std_grad) + @test isapprox(get_gradient((foo, :std)), expected_std_grad) end end diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index 8c183aa16..520c9879a 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -112,13 +112,15 @@ z = @trace(normal(x + y, std), :z) return z end - init_param!(bang1, :std, 3.0) + register_parameters!(bang1, [:std]) + init_parameter!((bang1, :std), 3.0) @gen (grad) function fuzz1((grad)(x::Float64), (grad)(y::Float64)) @param(std::Float64) z = @trace(normal(x + 2 * y, std), :z) return z end - init_param!(fuzz1, :std, 3.0) + register_parameters!(fuzz1, [:std]) + init_parameter!((fuzz1, :std), 3.0) sc = Switch(bang1, fuzz1) @gen (grad) function bam(s::Int) x ~ sc(s, 5.0, 3.0) @@ -195,15 +197,15 @@ for z in [1.0, 3.0, 5.0, 10.0] chm = choicemap((:z, z)) tr, _ = generate(bam, (1, ), chm) - zero_param_grad!(bang1, :std) + reset_gradient!((bang1, :std)) input_grads = accumulate_param_gradients!(tr, 1.0) expected_std_grad = logpdf_grad(normal, tr[:x => :z], 5.0 + 3.0, 3.0)[3] - @test isapprox(get_param_grad(bang1, :std), expected_std_grad) + @test isapprox(get_gradient((bang1, :std)), expected_std_grad) tr, _ = generate(bam, (2, ), chm) - zero_param_grad!(fuzz1, :std) + reset_gradient!((fuzz1, :std)) input_grads = accumulate_param_gradients!(tr, 1.0) expected_std_grad = logpdf_grad(normal, tr[:x => :z], 5.0 + 2 * 3.0, 3.0)[3] - @test isapprox(get_param_grad(fuzz1, :std), expected_std_grad) + @test isapprox(get_gradient((fuzz1, :std)), expected_std_grad) end end diff --git a/test/modeling_library/unfold.jl b/test/modeling_library/unfold.jl index 3d8c1b54a..2b4aaec7a 100644 --- a/test/modeling_library/unfold.jl +++ b/test/modeling_library/unfold.jl @@ -487,11 +487,12 @@ x = @trace(normal(x_prev * alpha + beta, std), :x) return x end + register_parameters!(kernel, [:std]) foo = Unfold(kernel) std = 1. 
- set_param!(kernel, :std, std) + init_parameter!((kernel, :std), std) x_init = 0.1 alpha = 0.2 @@ -504,7 +505,7 @@ constraints[2 => :x] = x2 (trace, _) = generate(foo, (2, x_init, alpha, beta), constraints) - zero_param_grad!(kernel, :std) + reset_gradient!((kernel, :std)) input_grads = accumulate_param_gradients!(trace, nothing) @test input_grads[1] == nothing # length @test input_grads[2] == nothing # inital state @@ -512,7 +513,7 @@ #@test isapprox(input_grads[4], expected_ys_grad) # beta expected_std_grad = (logpdf_grad(normal, x1, x_init * alpha + beta, std)[3] + logpdf_grad(normal, x2, x1 * alpha + beta, std)[3]) - @test isapprox(get_param_grad(kernel, :std), expected_std_grad) + @test isapprox(get_gradient((kernel, :std)), expected_std_grad) end @gen (grad) function ker(t, (grad)(x::Float64)) diff --git a/test/optimization.jl b/test/optimization.jl new file mode 100644 index 000000000..6e531e71b --- /dev/null +++ b/test/optimization.jl @@ -0,0 +1,97 @@ +@testset "optimization" begin + +@testset "Julia parameter store" begin + + store = JuliaParameterStore() + + @gen function foo() + @param theta::Float64 + @param phi::Vector{Float64} + end + register_parameters!(foo, [:theta, :phi]) + + # before the parameters are initialized in the store + + @test Gen.get_local_parameters(store, foo) == Dict{Symbol,Any}() + + @test_throws KeyError get_gradient((foo, :theta), store) + @test_throws KeyError get_parameter_value((foo, :theta), store) + @test_throws KeyError increment_gradient!((foo, :theta), 1.0, store) + @test_throws KeyError reset_gradient!((foo, :theta), store) + @test_throws KeyError Gen.set_parameter_value!((foo, :theta), 1.0, store) + @test_throws KeyError Gen.get_gradient_accumulator((foo, :theta), store) + + @test_throws KeyError get_gradient((foo, :phi), store) + @test_throws KeyError get_parameter_value((foo, :phi), store) + @test_throws KeyError increment_gradient!((foo, :phi), [1.0, 1.0], store) + @test_throws KeyError reset_gradient!((foo, :phi), store) + @test_throws KeyError Gen.set_parameter_value!((foo, :phi), [1.0, 1.0], store) + @test_throws KeyError Gen.get_gradient_accumulator((foo, :phi), store) + + # after the parameters are initialized in the store + + init_parameter!((foo, :theta), 1.0, store) + init_parameter!((foo, :phi), [1.0, 2.0], store) + + dict = Gen.get_local_parameters(store, foo) + @test length(dict) == 2 + @test dict[:theta] == 1.0 + @test dict[:phi] == [1.0, 2.0] + + @test get_gradient((foo, :theta), store) == 0.0 + @test get_parameter_value((foo, :theta), store) == 1.0 + increment_gradient!((foo, :theta), 1.1, store) + @test get_gradient((foo, :theta), store) == 1.1 + increment_gradient!((foo, :theta), 1.1, 2.0, store) + @test get_gradient((foo, :theta), store) == (1.1 + 2.2) + reset_gradient!((foo, :theta), store) + @test get_gradient((foo, :theta), store) == 0.0 + Gen.set_parameter_value!((foo, :theta), 2.0, store) + @test get_parameter_value((foo, :theta), store) == 2.0 + @test get_value(Gen.get_gradient_accumulator((foo, :theta), store)) == 0.0 + + @test get_gradient((foo, :phi), store) == [0.0, 0.0] + @test get_parameter_value((foo, :phi), store) == [1.0, 2.0] + increment_gradient!((foo, :phi), [1.1, 1.2], store) + @test get_gradient((foo, :phi), store) == [1.1, 1.2] + increment_gradient!((foo, :phi), [1.1, 1.2], 2.0, store) + @test get_gradient((foo, :phi), store) == ([1.1, 1.2] .+ (2.0 * [1.1, 1.2])) + reset_gradient!((foo, :phi), store) + @test get_gradient((foo, :phi), store) == [0.0, 0.0] + Gen.set_parameter_value!((foo, :phi), [2.0, 
3.0], store) + @test get_parameter_value((foo, :phi), store) == [2.0, 3.0] + @test Gen.get_value(Gen.get_gradient_accumulator((foo, :phi), store)) == [0.0, 0.0] + + # check that the default global Julia store was unaffected + @test_throws KeyError get_parameter_value((foo, :theta)) + @test_throws KeyError get_gradient((foo, :theta)) + @test_throws KeyError increment_gradient!((foo, :theta), 1.0) + + # FixedStepGradientDescent + init_parameter!((foo, :theta), 1.0, store) + init_parameter!((foo, :phi), [1.0, 2.0], store) + increment_gradient!((foo, :theta), 2.0, store) + increment_gradient!((foo, :phi), [1.0, 3.0], store) + optimizer = init_optimizer(FixedStepGradientDescent(1e-2), [(foo, :theta)], store) + apply_update!(optimizer) # update just theta + @test get_gradient((foo, :theta), store) == 0.0 + @test get_parameter_value((foo, :theta), store) == 1.0 + (2.0 * 1e-2) + @test get_gradient((foo, :phi), store) == [1.0, 3.0] # unchanged + @test get_parameter_value((foo, :phi), store) == [1.0, 2.0] # unchanged + optimizer = init_optimizer(FixedStepGradientDescent(1e-2), [(foo, :phi)], store) + apply_update!(optimizer) # update just phi + @test get_gradient((foo, :phi), store) == [0.0, 0.0] + @test get_parameter_value((foo, :phi), store) == ([1.0, 2.0] .+ 1e-2 * [1.0, 3.0]) + + # DecayStepGradientDescent + # TODO + + # default_parameter_context and default_julia_parameter_store +end + +@testset "composite optimizer" begin + +end + + +end diff --git a/test/optional_args.jl b/test/optional_args.jl index fd6c4ea71..9595c5706 100644 --- a/test/optional_args.jl +++ b/test/optional_args.jl @@ -8,9 +8,10 @@ using Gen b = @trace(normal(a+theta, 1), :b) return (x, y, z, x+y+z) end + register_parameters!(foo, [:theta]) # initialize theta to zero for non-gradient tests - init_param!(foo, :theta, 0.) + init_parameter!((foo, :theta), 0.0) # test directly calling with varying args @test foo(1) == (1, 2, 3, 6) diff --git a/test/runtests.jl b/test/runtests.jl index 1317b19ee..fd8f9ebd4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -76,6 +76,7 @@ end const dx = 1e-6 +include("optimization.jl") include("autodiff.jl") include("diff.jl") include("selection.jl") @@ -87,5 +88,3 @@ include("static_ir/static_ir.jl") include("tilde_sugar.jl") include("inference/inference.jl") include("modeling_library/modeling_library.jl") - - diff --git a/test/static_ir/gradients.jl b/test/static_ir/gradients.jl new file mode 100644 index 000000000..13cd15dc8 --- /dev/null +++ b/test/static_ir/gradients.jl @@ -0,0 +1,117 @@ +@testset "backprop" begin + + #@gen (static) function bar(mu_z::Float64) + #z = @trace(normal(mu_z, 1), :z) + #return z + mu_z + #end + + # bar + builder = StaticIRBuilder() + mu_z = add_argument_node!(builder, name=:mu_z, typ=:Float64, compute_grad=true) + one = add_constant_node!(builder, 1.) 
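+    # add the random choice z ~ normal(mu_z, 1) at address :z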
+ z = add_addr_node!(builder, normal, inputs=[mu_z, one], addr=:z, name=:z) + retval = add_julia_node!(builder, (z, mu_z) -> z + mu_z, inputs=[z, mu_z], name=:retval) + set_return_node!(builder, retval) + ir = build_ir(builder) + bar = eval(generate_generative_function(ir, :bar, track_diffs=false, cache_julia_nodes=false)) + + #@gen (static) function foo(mu_a::Float64) + #param theta::Float64 + #a = @trace(normal(mu_a, 1), :a) + #b = @trace(normal(a, 1), :b) + #bar = @trace(bar(a), :bar) + #c = a * b * bar * theta + #out = @trace(normal(c, 1), :out) + #return out + #end + + # foo + builder = StaticIRBuilder() + mu_a = add_argument_node!(builder, name=:mu_a, typ=:Float64, compute_grad=true) + theta = add_trainable_param_node!(builder, :theta, typ=QuoteNode(Float64)) + one = add_constant_node!(builder, 1.) + a = add_addr_node!(builder, normal, inputs=[mu_a, one], addr=:a, name=:a) + b = add_addr_node!(builder, normal, inputs=[a, one], addr=:b, name=:b) + bar_val = add_addr_node!(builder, bar, inputs=[a], addr=:bar, name=:bar_val) + c = add_julia_node!(builder, (a, b, bar, theta) -> (a * b * bar * theta), + inputs=[a, b, bar_val, theta], name=:c) + retval = add_addr_node!(builder, normal, inputs=[c, one], addr=:out, name=:out) + set_return_node!(builder, retval) + ir = build_ir(builder) + foo = eval(generate_generative_function(ir, :foo, track_diffs=false, cache_julia_nodes=false)) + + Gen.load_generated_functions() + + # test get_parameters + store_to_ids = Gen.get_parameters(foo, Gen.default_parameter_context) + @test length(store_to_ids) == 1 + @test length(store_to_ids[Gen.default_julia_parameter_store]) == 1 + @test (foo, :theta) in store_to_ids[Gen.default_julia_parameter_store] + + function f(mu_a, theta, a, b, z, out) + lpdf = 0. + mu_z = a + lpdf += logpdf(normal, z, mu_z, 1) + lpdf += logpdf(normal, a, mu_a, 1) + lpdf += logpdf(normal, b, a, 1) + c = a * b * (z + mu_z) * theta + lpdf += logpdf(normal, out, c, 1) + return lpdf + 2 * out + end + + mu_a = 1. + theta = -0.5 + a = 2. + b = 3. + z = 4. + out = 5. + + # initialize the trainable parameter + init_parameter!((foo, :theta), theta) + + # get the initial trace + constraints = choicemap() + constraints[:a] = a + constraints[:b] = b + constraints[:out] = out + constraints[:bar => :z] = z + (trace, _) = generate(foo, (mu_a,), constraints) + + # compute gradients with choice_gradients + selection = select(:bar => :z, :a, :out) + selection = StaticSelection(selection) + retval_grad = 2. 
+ ((mu_a_grad,), value_trie, gradient_trie) = choice_gradients(trace, selection, retval_grad) + + # check input gradient + @test isapprox(mu_a_grad, finite_diff(f, (mu_a, theta, a, b, z, out), 1, dx)) + + # check value trie + @test get_value(value_trie, :a) == a + @test get_value(value_trie, :out) == out + @test get_value(value_trie, :bar => :z) == z + @test !has_value(value_trie, :b) # was not selected + @test length(get_submaps_shallow(value_trie)) == 1 + @test length(get_values_shallow(value_trie)) == 2 + + # check gradient trie + @test length(get_submaps_shallow(gradient_trie)) == 1 + @test length(get_values_shallow(gradient_trie)) == 2 + @test !has_value(gradient_trie, :b) # was not selected + @test isapprox(get_value(gradient_trie, :a), finite_diff(f, (mu_a, theta, a, b, z, out), 3, dx)) + @test isapprox(get_value(gradient_trie, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx)) + @test isapprox(get_value(gradient_trie, :bar => :z), finite_diff(f, (mu_a, theta, a, b, z, out), 5, dx)) + + # compute gradients with accumulate_param_gradients! + retval_grad = 2. + (mu_a_grad,) = accumulate_param_gradients!(trace, retval_grad) + + # check input gradient + @test isapprox(mu_a_grad, finite_diff(f, (mu_a, theta, a, b, z, out), 1, dx)) + + # check trainable parameter gradient + @test isapprox( + get_gradient((foo, :theta)), + finite_diff(f, (mu_a, theta, a, b, z, out), 2, dx)) + +end diff --git a/test/static_ir/static_ir.jl b/test/static_ir/static_ir.jl index f773b6853..61a14a58c 100644 --- a/test/static_ir/static_ir.jl +++ b/test/static_ir/static_ir.jl @@ -49,7 +49,7 @@ foo = eval(generate_generative_function(ir, :foo, track_diffs=false, cache_julia @test occursin("== Static IR ==", repr("text/plain", ir)) theta_val = rand() -set_param!(foo, :theta, theta_val) +init_parameter!((foo, :theta), theta_val) #@gen (static, nojuliacache) function const_fn() #return 1 @@ -276,122 +276,8 @@ end end -@testset "backprop" begin - - #@gen (static) function bar(mu_z::Float64) - #z = @trace(normal(mu_z, 1), :z) - #return z + mu_z - #end - - # bar - builder = StaticIRBuilder() - mu_z = add_argument_node!(builder, name=:mu_z, typ=:Float64, compute_grad=true) - one = add_constant_node!(builder, 1.) - z = add_addr_node!(builder, normal, inputs=[mu_z, one], addr=:z, name=:z) - retval = add_julia_node!(builder, (z, mu_z) -> z + mu_z, inputs=[z, mu_z], name=:retval) - set_return_node!(builder, retval) - ir = build_ir(builder) - bar = eval(generate_generative_function(ir, :bar, track_diffs=false, cache_julia_nodes=false)) - - #@gen (static) function foo(mu_a::Float64) - #param theta::Float64 - #a = @trace(normal(mu_a, 1), :a) - #b = @trace(normal(a, 1), :b) - #bar = @trace(bar(a), :bar) - #c = a * b * bar * theta - #out = @trace(normal(c, 1), :out) - #return out - #end - - # foo - builder = StaticIRBuilder() - mu_a = add_argument_node!(builder, name=:mu_a, typ=:Float64, compute_grad=true) - theta = add_trainable_param_node!(builder, :theta, typ=QuoteNode(Float64)) - one = add_constant_node!(builder, 1.) 
- a = add_addr_node!(builder, normal, inputs=[mu_a, one], addr=:a, name=:a) - b = add_addr_node!(builder, normal, inputs=[a, one], addr=:b, name=:b) - bar_val = add_addr_node!(builder, bar, inputs=[a], addr=:bar, name=:bar_val) - c = add_julia_node!(builder, (a, b, bar, theta) -> (a * b * bar * theta), - inputs=[a, b, bar_val, theta], name=:c) - retval = add_addr_node!(builder, normal, inputs=[c, one], addr=:out, name=:out) - set_return_node!(builder, retval) - ir = build_ir(builder) - foo = eval(generate_generative_function(ir, :foo, track_diffs=false, cache_julia_nodes=false)) - - Gen.load_generated_functions() - - function f(mu_a, theta, a, b, z, out) - lpdf = 0. - mu_z = a - lpdf += logpdf(normal, z, mu_z, 1) - lpdf += logpdf(normal, a, mu_a, 1) - lpdf += logpdf(normal, b, a, 1) - c = a * b * (z + mu_z) * theta - lpdf += logpdf(normal, out, c, 1) - return lpdf + 2 * out - end - - mu_a = 1. - theta = -0.5 - a = 2. - b = 3. - z = 4. - out = 5. - - init_param!(foo, :theta, theta) - - @test get_param(foo, :theta) == theta - @test get_param_grad(foo, :theta) == 0. - - # get the initial trace - constraints = choicemap() - constraints[:a] = a - constraints[:b] = b - constraints[:out] = out - constraints[:bar => :z] = z - (trace, _) = generate(foo, (mu_a,), constraints) - - # compute gradients with choice_gradients - selection = select(:bar => :z, :a, :out) - selection = StaticSelection(selection) - retval_grad = 2. - ((mu_a_grad,), value_trie, gradient_trie) = choice_gradients(trace, selection, retval_grad) - - # check input gradient - @test isapprox(mu_a_grad, finite_diff(f, (mu_a, theta, a, b, z, out), 1, dx)) - - # check value trie - @test get_value(value_trie, :a) == a - @test get_value(value_trie, :out) == out - @test get_value(value_trie, :bar => :z) == z - @test !has_value(value_trie, :b) # was not selected - @test length(get_submaps_shallow(value_trie)) == 1 - @test length(get_values_shallow(value_trie)) == 2 - - # check gradient trie - @test length(get_submaps_shallow(gradient_trie)) == 1 - @test length(get_values_shallow(gradient_trie)) == 2 - @test !has_value(gradient_trie, :b) # was not selected - @test isapprox(get_value(gradient_trie, :a), finite_diff(f, (mu_a, theta, a, b, z, out), 3, dx)) - @test isapprox(get_value(gradient_trie, :out), finite_diff(f, (mu_a, theta, a, b, z, out), 6, dx)) - @test isapprox(get_value(gradient_trie, :bar => :z), finite_diff(f, (mu_a, theta, a, b, z, out), 5, dx)) - - # reset the trainable parameter gradient - zero_param_grad!(foo, :theta) - @test get_param(foo, :theta) == theta - @test get_param_grad(foo, :theta) == 0. - - # compute gradients with accumulate_param_gradients! - retval_grad = 2. - (mu_a_grad,) = accumulate_param_gradients!(trace, retval_grad) - - # check input gradient - @test isapprox(mu_a_grad, finite_diff(f, (mu_a, theta, a, b, z, out), 1, dx)) - - # check trainable parameter gradient - @test isapprox(get_param_grad(foo, :theta), finite_diff(f, (mu_a, theta, a, b, z, out), 2, dx)) +include("gradients.jl") -end # functions to test tracked diffs
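
Taken together, the changes above migrate the test suite to the reworked trainable-parameter API: parameters are registered per generative function with `register_parameters!`, addressed as `(gen_fn, name)` tuples, and updated through an optimizer object rather than a `ParamUpdate`. The sketch below illustrates that pattern with a made-up model (not one from the test suite), using the default Julia parameter store; the old calls it replaces are noted in the comments.

```julia
using Gen

@gen function demo(x::Float64)
    @param theta::Float64
    y = @trace(normal(x + theta, 1.0), :y)
    return y
end
register_parameters!(demo, [:theta])      # declare the trainable parameters

init_parameter!((demo, :theta), 0.0)      # was: init_param!(demo, :theta, 0.)

# accumulate gradients from a few simulated traces, then take one gradient step
optimizer = init_optimizer(FixedStepGradientDescent(1e-3), demo)
for _ in 1:10
    trace = simulate(demo, (1.0,))
    accumulate_param_gradients!(trace)
end
apply_update!(optimizer)                  # was: apply!(update)

get_parameter_value((demo, :theta))       # was: get_param(demo, :theta)
get_gradient((demo, :theta))              # was: get_param_grad(demo, :theta)
reset_gradient!((demo, :theta))           # was: zero_param_grad!(demo, :theta)
```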