I'm looking at the tf2jax project, and being able to take pretrained TensorFlow modules and convert them to Haiku would be really useful, since there aren't many Haiku checkpoints available. A typical application is something like
So now I have a function and parameters that do what I want, but I need to insert them into a Haiku module. How should I do this? I'm hoping to eventually be able to write something like
```python
import haiku as hk

class MyModule(hk.Module):
  def __call__(self, x):
    x = ResNet50Jax()(x)  # the converted TF model, used as a Haiku module
    # ... some other module-specific stuff ...
    return x
```
that I can then pass to `hk.transform` as usual. I couldn't find an obvious way to do this. Any thoughts?
More broadly, is it a bad idea to rely on tf2jax for checkpoints, rather than building the model directly in Haiku and manually copying over the weights from PyTorch/TensorFlow?
tf2jax does introduce quite a bit of complexity and indirection (for example, if you want to fine-tune the model and take gradients through the converted TF code). I think it might be worth taking a pre-existing checkpoint and adapting it to work with Haiku.
Alternatively, if you have access to GPUs/TPUs for training, we provide a training script to train a ResNet-50 model on ImageNet.