From 346d71bbd64fc6ada8e8fee19cfc710a61044ae2 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Wed, 7 Feb 2024 21:58:02 +0100 Subject: [PATCH] add flag to svi --- numpyro/infer/svi.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py index 4b99302cc..6636ad118 100644 --- a/numpyro/infer/svi.py +++ b/numpyro/infer/svi.py @@ -256,12 +256,14 @@ def get_params(self, svi_state): params = self.constrain_fn(self.optim.get_params(svi_state.optim_state)) return params - def update(self, svi_state, *args, **kwargs): + def update(self, svi_state, forward_mode_differentiation=False, *args, **kwargs): """ Take a single step of SVI (possibly on a batch / minibatch of data), using the optimizer. :param svi_state: current state of SVI. + :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation. + Defaults to False. :param args: arguments to the model / guide (these can possibly vary during the course of fitting). :param kwargs: keyword arguments to the model / guide (these can possibly vary @@ -281,16 +283,18 @@ def update(self, svi_state, *args, **kwargs): mutable_state=svi_state.mutable_state, ) (loss_val, mutable_state), optim_state = self.optim.eval_and_update( - loss_fn, svi_state.optim_state + loss_fn, svi_state.optim_state, forward_mode_differentiation ) return SVIState(optim_state, mutable_state, rng_key), loss_val - def stable_update(self, svi_state, *args, **kwargs): + def stable_update(self, svi_state, forward_mode_differentiation=False, *args, **kwargs): """ Similar to :meth:`update` but returns the current state if the the loss or the new state contains invalid values. :param svi_state: current state of SVI. + :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation. + Defaults to False. :param args: arguments to the model / guide (these can possibly vary during the course of fitting). :param kwargs: keyword arguments to the model / guide (these can possibly vary @@ -310,7 +314,7 @@ def stable_update(self, svi_state, *args, **kwargs): mutable_state=svi_state.mutable_state, ) (loss_val, mutable_state), optim_state = self.optim.eval_and_stable_update( - loss_fn, svi_state.optim_state + loss_fn, svi_state.optim_state, forward_mode_differentiation ) return SVIState(optim_state, mutable_state, rng_key), loss_val @@ -321,6 +325,7 @@ def run( *args, progress_bar=True, stable_update=False, + forward_mode_differentiation=False, init_state=None, init_params=None, **kwargs, @@ -342,6 +347,8 @@ def run( ``True``. :param bool stable_update: whether to use :meth:`stable_update` to update the state. Defaults to False. + :param bool forward_mode_differentiation: flag indicating whether to use forward mode differentiation. + Defaults to False. :param SVIState init_state: if not None, begin SVI from the final state of previous SVI run. Usage:: @@ -365,9 +372,9 @@ def run( def body_fn(svi_state, _): if stable_update: - svi_state, loss = self.stable_update(svi_state, *args, **kwargs) + svi_state, loss = self.stable_update(svi_state, forward_mode_differentiation, *args, **kwargs) else: - svi_state, loss = self.update(svi_state, *args, **kwargs) + svi_state, loss = self.update(svi_state, forward_mode_differentiation, *args, **kwargs) return svi_state, loss if init_state is None: