Skip to content

Commit

Permalink
add flag to svi
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Feb 7, 2024
1 parent 449df7c commit 346d71b
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions numpyro/infer/svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,14 @@ def get_params(self, svi_state):
params = self.constrain_fn(self.optim.get_params(svi_state.optim_state))
return params

def update(self, svi_state, *args, **kwargs):
def update(self, svi_state, forward_mode_differentiation=False, *args, **kwargs):
"""
Take a single step of SVI (possibly on a batch / minibatch of data),
using the optimizer.
:param svi_state: current state of SVI.
:param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
Defaults to False.
:param args: arguments to the model / guide (these can possibly vary during
the course of fitting).
:param kwargs: keyword arguments to the model / guide (these can possibly vary
Expand All @@ -281,16 +283,18 @@ def update(self, svi_state, *args, **kwargs):
mutable_state=svi_state.mutable_state,
)
(loss_val, mutable_state), optim_state = self.optim.eval_and_update(
loss_fn, svi_state.optim_state
loss_fn, svi_state.optim_state, forward_mode_differentiation
)
return SVIState(optim_state, mutable_state, rng_key), loss_val

def stable_update(self, svi_state, *args, **kwargs):
def stable_update(self, svi_state, forward_mode_differentiation=False, *args, **kwargs):
"""
Similar to :meth:`update` but returns the current state if the
the loss or the new state contains invalid values.
:param svi_state: current state of SVI.
:param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
Defaults to False.
:param args: arguments to the model / guide (these can possibly vary during
the course of fitting).
:param kwargs: keyword arguments to the model / guide (these can possibly vary
Expand All @@ -310,7 +314,7 @@ def stable_update(self, svi_state, *args, **kwargs):
mutable_state=svi_state.mutable_state,
)
(loss_val, mutable_state), optim_state = self.optim.eval_and_stable_update(
loss_fn, svi_state.optim_state
loss_fn, svi_state.optim_state, forward_mode_differentiation
)
return SVIState(optim_state, mutable_state, rng_key), loss_val

Expand All @@ -321,6 +325,7 @@ def run(
*args,
progress_bar=True,
stable_update=False,
forward_mode_differentiation=False,
init_state=None,
init_params=None,
**kwargs,
Expand All @@ -342,6 +347,8 @@ def run(
``True``.
:param bool stable_update: whether to use :meth:`stable_update` to update
the state. Defaults to False.
:param bool forward_mode_differentiation: flag indicating whether to use forward mode differentiation.
Defaults to False.
:param SVIState init_state: if not None, begin SVI from the
final state of previous SVI run. Usage::
Expand All @@ -365,9 +372,9 @@ def run(

def body_fn(svi_state, _):
if stable_update:
svi_state, loss = self.stable_update(svi_state, *args, **kwargs)
svi_state, loss = self.stable_update(svi_state, forward_mode_differentiation, *args, **kwargs)
else:
svi_state, loss = self.update(svi_state, *args, **kwargs)
svi_state, loss = self.update(svi_state, forward_mode_differentiation, *args, **kwargs)
return svi_state, loss

if init_state is None:
Expand Down

0 comments on commit 346d71b

Please sign in to comment.