Refactor CLIP to a functional model #5

@tirthasheshpatel tirthasheshpatel commented Apr 2, 2024

This PR updates CLIP to use the functional way of writing models in Keras.

It's currently a rough patch since I worked on it a few weeks ago. I will have to polish the overall diff and make sure I am not touching the numerics.

Refactor includes:

  • CLIPProcessor is now a Keras layer and uses some utilities from KerasNLP to support both plain Python inputs and array inputs.
  • CLIPImageEncoder, CLIPTextEncoder, and CLIPEncoder now implement a .compute_output_shape method (required for CLIP to work with the functional API).
  • CLIPHead has been added to remove raw variables from the CLIP task models; having variables directly in a keras.Model class is tricky since the functional API doesn't allow state there.
  • The CLIP checkpointing script has been updated to work with the new API; new weights will be uploaded to Kaggle.

TODO: attribute KerasNLP wherever relevant
TODO: upload new weights to Kaggle
TODO: refactor the CLIPProcessor class and the CLIP class to also pull tokenizer vocab and merges from Kaggle.

@divyashreepathihalli

Divyashree Sreepathihalli and others added 5 commits April 2, 2024 20:01
@tirthasheshpatel tirthasheshpatel changed the title from "[WIP] Refactor CLIP to a functional model" to "Refactor CLIP to a functional model" Apr 8, 2024

tirthasheshpatel commented Apr 8, 2024

The weights load properly now, but the numerics are still significantly off for the same inputs. I will make some changes to the tests. Otherwise, the numerics haven't been affected by this PR. If you are happy with it anyway, @divyashreepathihalli, feel free to merge this and then merge keras-team#2393.


tirthasheshpatel commented Apr 8, 2024

One more thing: we need to update the Kaggle page with the new weights. I will share the presets with you so we can get the Kaggle page updated.

@divyashreepathihalli I noticed that the vocab and merges files need to be fetched manually from external URLs. We should change that so we can fetch and load the preset for the processor directly from Kaggle. What do you think?

@divyashreepathihalli divyashreepathihalli marked this pull request as ready for review April 8, 2024 16:22
@divyashreepathihalli divyashreepathihalli merged commit 083481f into divyashreepathihalli:fix_clip_jax Apr 8, 2024
4 of 5 checks passed