
Model training fails with current numpyro #5

Closed
smodlich opened this issue Sep 2, 2021 · 3 comments

Comments

@smodlich

smodlich commented Sep 2, 2021

Hi,

Thanks for the great blog post and code. I tried to run the model training with numpyro==0.7.2. Preprocessing runs, but model training fails with "NotImplementedError: This ELBO objective does not support mutable state." The failing call is train_handler.fit in the training notebook. The error seems to originate from numpyro's introduction of mutable state (I think from here, line 57, the loss_with_mutable_state method). Maybe this commit is related?

Any idea how to fix this error?

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-8-1349482c40ae> in <module>
----> 1 train_handler.fit(X_train, n_epochs=5_000, log_freq=1_000, lr=0.1)

~/opt/anaconda3/envs/bhm-at-scale/lib/python3.8/site-packages/bhm_at_scale-1.1.post0.dev3+gea40dd8.dirty-py3.8.egg/bhm_at_scale/handler.py in fit(self, X, n_epochs, log_freq, lr, **kwargs)
    130             self._fit(X, n_epochs)
    131         else:
--> 132             loss = self.svi.evaluate(self.svi_state, X) / X.shape[0]
    133 
    134             curr_epoch = 0

~/opt/anaconda3/envs/bhm-at-scale/lib/python3.8/site-packages/numpyro/infer/svi.py in evaluate(self, svi_state, *args, **kwargs)
    363         _, rng_key_eval = random.split(svi_state.rng_key)
    364         params = self.get_params(svi_state)
--> 365         return self.loss.loss(
    366             rng_key_eval,
    367             params,

~/opt/anaconda3/envs/bhm-at-scale/lib/python3.8/site-packages/numpyro/infer/elbo.py in loss(self, rng_key, param_map, model, guide, *args, **kwargs)
     44         :return: negative of the Evidence Lower Bound (ELBO) to be minimized.
     45         """
---> 46         return self.loss_with_mutable_state(
     47             rng_key, param_map, model, guide, *args, **kwargs
     48         )["loss"]

~/opt/anaconda3/envs/bhm-at-scale/lib/python3.8/site-packages/numpyro/infer/elbo.py in loss_with_mutable_state(self, rng_key, param_map, model, guide, *args, **kwargs)
     66         :return: a tuple of ELBO loss and the mutable state
     67         """
---> 68         raise NotImplementedError("This ELBO objective does not support mutable state.")
     69 
     70 

NotImplementedError: This ELBO objective does not support mutable state.
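The traceback shows the mechanism: in numpyro 0.7.2, ELBO.loss simply delegates to loss_with_mutable_state, and the ELBO base class itself raises NotImplementedError there, so only a subclass that overrides that hook can be evaluated. The pure-Python sketch below mirrors that structure with hypothetical stand-in classes (ELBO and ConcreteELBO here are illustrations, not numpyro's real implementation).

```python
# Hypothetical stand-ins mirroring the call chain visible in the traceback;
# these are NOT numpyro's actual classes.

class ELBO:
    """Base class: loss() delegates to loss_with_mutable_state()."""

    def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
        # Mirrors elbo.py lines 46-48 in the traceback.
        return self.loss_with_mutable_state(
            rng_key, param_map, model, guide, *args, **kwargs
        )["loss"]

    def loss_with_mutable_state(self, rng_key, param_map, model, guide, *args, **kwargs):
        # Mirrors elbo.py line 68: the base class leaves this hook unimplemented.
        raise NotImplementedError("This ELBO objective does not support mutable state.")


class ConcreteELBO(ELBO):
    """A concrete objective overrides the hook, so loss() succeeds."""

    def loss_with_mutable_state(self, rng_key, param_map, model, guide, *args, **kwargs):
        # Placeholder value; a real objective would estimate the negative ELBO here.
        return {"loss": 0.0, "mutable_state": None}


def try_loss(objective):
    """Call loss() and report the NotImplementedError instead of crashing."""
    try:
        return objective.loss(None, {}, None, None)
    except NotImplementedError as err:
        return f"NotImplementedError: {err}"


print(try_loss(ELBO()))          # the error from the traceback
print(try_loss(ConcreteELBO()))  # 0.0
```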
@FlorianWilhelm
Owner

Hi, that's interesting. Thanks for letting me know.

Have you tried using Trace_ELBO instead of ELBO from numpyro.infer in handler.py? If it works, a PR would be cool ;-)

@smodlich
Author

smodlich commented Sep 3, 2021

Hi,
I tried it, and this indeed fixes it. I had to reinstall the package before the changes took effect. I created a PR with the fix, which also fixes a related bug when loading the model.

@FlorianWilhelm
Owner

Thanks a lot!
