
Enable customizing activation functions in the CLIP vision encoder #1385

Merged: 3 commits merged into pytorch:main on Aug 21, 2024

Conversation

@Gasoonjia (Contributor) commented Aug 21, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?

This PR enables users to customize the activation function in the CLIP vision encoder.
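For illustration, a minimal sketch of the resulting call site. The import path and the activation choice here are assumptions for this example; only the builder arguments visible in this PR's diff are shown, and the remaining required arguments are elided.

```python
import torch
from torchtune.models.clip import clip_vision_encoder  # import path assumed

# Only arguments visible in this PR's diff are shown; the remaining
# required builder arguments (patch size, embed dim, etc.) are elided.
common = dict(max_num_tiles=4, in_channels=3)

# Default behavior is unchanged: the builder falls back to torch.nn.SiLU().
encoder = clip_vision_encoder(**common)

# After this PR, any torch.nn.Module activation can be injected instead.
encoder = clip_vision_encoder(**common, hidden_act=torch.nn.GELU())
```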

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:


Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models

  • I did not change any public API;
  • I have added an example to docs or docstrings;

pytorch-bot (bot) commented Aug 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1385

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e4a9cb5 with merge base 9e65fa9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Aug 21, 2024 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed).
@RdoubleA (Contributor) left a comment

Hey @Gasoonjia, thanks for opening this PR. This generally makes sense to me. Out of curiosity, what is your goal for changing the activation function? Do you intend to retrain the CLIP encoder from scratch?

@@ -19,6 +19,7 @@ def clip_vision_encoder(
output_cls_projection: bool = False,
max_num_tiles: int = 4,
in_channels: int = 3,
hidden_act: torch.nn.Module = torch.nn.SiLU(),

nit: maybe call this intermediate_activation to be more clear which activation in the transformer it is changing?

@@ -49,6 +50,7 @@ def clip_vision_encoder(
max_num_tiles (int): The maximum number of tiles that can be processed. This is used to
determine the size of the positional embeddings.
in_channels (int): The number of image input channels.
hidden_act (torch.nn.Module): The activation function used in the transformer layers.

Suggested change:
- hidden_act (torch.nn.Module): The activation function used in the transformer layers.
+ hidden_act (torch.nn.Module): The activation function used in the intermediate layers in the transformer encoder
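For context, a minimal sketch of how such a parameter is typically threaded into each transformer layer's MLP. The class and attribute names below are illustrative, not torchtune's actual internals.

```python
import torch
from torch import nn

class FeedForward(nn.Module):
    """Illustrative transformer MLP; `activation` plays the role of `hidden_act`."""

    def __init__(self, embed_dim: int, hidden_dim: int, activation: nn.Module):
        super().__init__()
        self.w1 = nn.Linear(embed_dim, hidden_dim)
        self.w2 = nn.Linear(hidden_dim, embed_dim)
        self.activation = activation  # injected instead of hardcoding SiLU

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Project up, apply the configurable activation, project back down.
        return self.w2(self.activation(self.w1(x)))
```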

@felipemello1 (Contributor) commented:

Hey @Gasoonjia, you may be interested in our upcoming Flamingo PR: https://github.com/pytorch/torchtune/pull/1357/files

Feel free to ping me on Workplace: https://fb.workplace.com/profile.php?id=61556984579937

@Gasoonjia (Contributor, Author) commented Aug 21, 2024

Hey @RdoubleA, thanks for the comment! I've updated the PR; please take a look!
I'm working on leveraging torchtune's (tt) modules to reproduce LLaVA-1.5 based on the Hugging Face (hf) implementation, and I realized that hf uses QuickGELU instead of tt's default activation (SiLU), so I'd like a way to customize the activation function I use.
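For reference, QuickGELU is the sigmoid-based GELU approximation used by the original CLIP and by Hugging Face's CLIP implementation: x * sigmoid(1.702 * x). A minimal module one could pass as the activation (the usage line is a hypothetical example):

```python
import torch
from torch import nn

class QuickGELU(nn.Module):
    """QuickGELU: x * sigmoid(1.702 * x), a fast approximation of GELU."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(1.702 * x)

# Hypothetical usage: clip_vision_encoder(..., hidden_act=QuickGELU())
```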

@felipemello1 (Contributor) commented:

> reproduce llava1.5 based on huggingface

FYI, we had some work on LLaVA transforms: #1057

@Gasoonjia (Contributor, Author) commented:

Hey @felipemello1, glad to see you on GitHub! Yes, I'm keeping an eye on your wonderful PR, and I'll definitely need your help when I start working on Flamingo!
I noticed you are updating the CLIP module in that PR; will there be a major update to it, especially on the modeling side?

@Gasoonjia (Contributor, Author) commented:

> reproduce llava1.5 based on huggingface
>
> FYI, we had some work on llava transforms: #1057

Thanks for sharing! Will try to leverage your work!

@codecov-commenter commented:
Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 72.81%. Comparing base (9e65fa9) to head (e4a9cb5).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1385      +/-   ##
==========================================
+ Coverage   70.57%   72.81%   +2.24%     
==========================================
  Files         272      272              
  Lines       12895    12895              
==========================================
+ Hits         9101     9390     +289     
+ Misses       3794     3505     -289     


@RdoubleA RdoubleA merged commit e568b67 into pytorch:main Aug 21, 2024
20 checks passed