In Pyro MCMC `predictive`, we have a `parallel` keyword to resolve the plate issue pyro-ppl/pyro#1995. JAX supports batching out of the box. However, using `vmap` with a large `num_samples` and a large model consumes much more memory, because all samples are evaluated in one batched computation. I ran into memory issues when training BNAF on the covtype dataset, and switching to `lax.map`, which processes the samples one at a time, solved the problem.
Probably we should change the keyword to `sequential` to avoid confusion.
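A minimal sketch of the two strategies, assuming a hypothetical per-sample function `predict_one` (a toy linear model) and a hypothetical helper `predict_all`; neither is NumPyro's actual `Predictive` implementation. Both paths return the same result; they differ only in how the work over `num_samples` is scheduled.

```python
import jax
import jax.numpy as jnp
from jax import lax


def predict_one(sample, x):
    # Hypothetical per-sample predictive: a linear model with weight `w` and bias `b`.
    return sample["w"] * x + sample["b"]


def predict_all(samples, x, parallel=False):
    """Apply `predict_one` to every posterior sample (hypothetical helper).

    parallel=True  -> jax.vmap: one batched computation, so intermediates for
                      all samples are held in memory at once.
    parallel=False -> lax.map: a sequential loop over the leading axis, so only
                      one sample's intermediates are live at a time.
    """
    fn = lambda s: predict_one(s, x)
    if parallel:
        return jax.vmap(fn)(samples)
    # lax.map maps over the leading axis of every leaf in the `samples` dict.
    return lax.map(fn, samples)


# Usage: 1000 fake posterior samples for each latent site.
num_samples = 1000
key_w, key_b = jax.random.split(jax.random.PRNGKey(0))
samples = {
    "w": jax.random.normal(key_w, (num_samples,)),
    "b": jax.random.normal(key_b, (num_samples,)),
}
x = jnp.linspace(0.0, 1.0, 5)

out_vmap = predict_all(samples, x, parallel=True)
out_map = predict_all(samples, x, parallel=False)
print(out_vmap.shape)  # (1000, 5)
```

With a toy model like this the memory difference is negligible; it only matters when each sample's forward pass is large, which is the situation described in the issue.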