
Correct way to integrate tf2jax output with a hk.Module #582

Open
rdilip opened this issue Dec 11, 2022 · 1 comment

@rdilip

rdilip commented Dec 11, 2022

I'm looking at the tf2jax project. Being able to take pretrained TensorFlow modules and convert them to Haiku would be really useful, since there aren't many Haiku checkpoints available. A typical use looks something like:

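```python
import jax.numpy as jnp
import tensorflow as tf
import tf2jax

# Trace the Keras ResNet50 (wrapped in a tf.function) with a dummy batch of one 224x224 RGB image.
jax_func, jax_params = tf2jax.convert(
    tf.function(tf.keras.applications.resnet50.ResNet50()),
    jnp.zeros((1, 224, 224, 3)),
)
```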

This gives me a function and its parameters, but I need to use them inside a Haiku module. How should I do this? Ideally I'd eventually be able to write something like:

```python
import haiku as hk

class MyModule(hk.Module):
    def __call__(self, x):
        x = ResNet50Jax()(x)  # hypothetical Haiku module wrapping the tf2jax output
        # ... some other module-specific layers ...
        return x
```

which I could then pass to hk.transform as usual. I couldn't find an obvious way to do this. Any thoughts?

More broadly, is it a bad idea to rely on tf2jax for checkpoints, compared with defining the model directly in Haiku and manually copying over the weights from PyTorch/TensorFlow?

@tomhennigan
Collaborator

Hi @rdilip, here is an example of integrating Haiku and tf2jax: https://colab.research.google.com/gist/tomhennigan/5a6a264bccbbe8ecac1b475ad8049c72/example-of-using-tf2jax-with-haiku.ipynb

tf2jax does introduce quite a bit of complexity and indirection (for example, if you want to fine-tune the model and take gradients through the converted TF code). I think it might be worth taking an existing checkpoint and adapting it to work with Haiku instead.

Alternatively, if you have access to GPUs/TPUs for training, we provide a training script for training a ResNet-50 model on ImageNet.
