Hi all,
I'm porting JAX code that I've been writing for research projects into more production-ready code in PyTorch. I want to use PyTorch Lightning to streamline the process, but I'm not sure how well it accommodates operators like torch.vmap, torch.func.grad, etc. Does it optimize through these operators correctly? In particular, I have an autoencoder where the encoder is based on Langevin diffusion and Monte Carlo estimation; a rough sketch of the kind of pattern I mean is below.

Thanks in advance!
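For concreteness, here is a minimal sketch of the pattern I'm asking about (the class name, architecture, and hyperparameters are illustrative placeholders, not my actual model): a LightningModule whose encoder runs a few Langevin steps in latent space, using torch.func.grad for the per-sample score and torch.vmap to batch it.

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.func import grad, vmap


class LangevinAutoencoder(pl.LightningModule):
    # Illustrative only: an autoencoder whose "encoder" is a short Langevin chain.
    def __init__(self, data_dim=784, latent_dim=16, n_steps=10, step_size=1e-2):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, data_dim)
        )
        self.latent_dim = latent_dim
        self.n_steps = n_steps
        self.step_size = step_size

    def log_density(self, z, x):
        # Unnormalised log p(x, z) for a *single* sample:
        # Gaussian reconstruction term plus a standard-normal prior on z.
        recon = self.decoder(z)
        return -0.5 * ((x - recon) ** 2).sum() - 0.5 * (z ** 2).sum()

    def encode(self, x):
        # Langevin dynamics in latent space; vmap(grad(...)) gives per-sample scores.
        z = torch.randn(x.shape[0], self.latent_dim, device=x.device)
        score_fn = vmap(grad(self.log_density), in_dims=(0, 0))
        for _ in range(self.n_steps):
            score = score_fn(z, x)  # d log p / d z, one row per sample
            noise = torch.randn_like(z)
            z = z + 0.5 * self.step_size * score + self.step_size ** 0.5 * noise
        return z

    def training_step(self, batch, batch_idx):
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        x = x.view(x.shape[0], -1)
        z = self.encode(x)
        loss = ((x - self.decoder(z)) ** 2).mean()
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```

The question is essentially whether Lightning's training loop backpropagates correctly through the vmap(grad(...)) call inside training_step, or whether these torch.func transforms need special handling.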