CLIP Text Encoder #1969
base: main
Conversation
Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1969.
✅ No failures as of commit 5aa7c9f with merge base 9bafd16. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
self.sot_token = self.encoder["<|startoftext|>"]
self.eot_token = self.encoder["<|endoftext|>"]
self.pad_token = self.eot_token
should these be configurable in the constructor like we do for other tokenizers? I don't really think it's needed imo
They might not need to be, but I think it's more readable when all of the special tokens are listed at the top. I'm going to ask @RdoubleA to review this file specifically as he has the most experience with how tokenizers are used in torchtune.
I think keeping these here is fine since this is operating as a base tokenizer and not a model tokenizer with tokenize_messages
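For illustration, a minimal sketch of the two options discussed in this thread (hard-coded ids at the top of __init__ vs. constructor arguments); the class and parameter names here are hypothetical, not the PR's actual signature.

```python
class CLIPTokenizerSketch:
    def __init__(
        self,
        encoder: dict,
        sot_token: str = "<|startoftext|>",  # option B: overridable via constructor
        eot_token: str = "<|endoftext|>",
    ):
        self.encoder = encoder
        # Option A (as in the PR): resolve the special token ids once, up front,
        # so they read as a block at the top of __init__.
        self.sot_token = self.encoder[sot_token]
        self.eot_token = self.encoder[eot_token]
        self.pad_token = self.eot_token  # CLIP pads with the end-of-text id
```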
Codecov Report
@@            Coverage Diff             @@
##             main    #1969      +/-   ##
==========================================
+ Coverage   67.51%   67.82%   +0.30%
==========================================
  Files         318      323       +5
  Lines       17684    17897     +213
==========================================
+ Hits        11940    12139     +199
- Misses       5744     5758      +14

View full report in Codecov by Sentry.
This is a great PR! I left some comments around some standard patterns we try to follow but aside from that, this looks very solid and clean.
torchtune/modules/activations.py
Outdated
""" | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
return x * torch.sigmoid(1.702 * x) |
How does this affect parity values and speed (as opposed to classic GeLU)? I'd prefer not to add this extra module if it's only ever going to be used by CLIP.
the speedup of QuickGELU is negligible, but it's needed because using the default GELU results in significantly worse parity (output MSE: 0.16 vs 0.00003)
Let's move this to the clip components builder file for now then. Just a function "def quick_gelu"
moved it to the clip components file but left it as a module because the activation arg of clip_mlp/FeedForward expects a module (I could move it to the clip/_text_encoder.py file instead if that's better)
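For reference, a minimal sketch of such a module and how it might be passed to an MLP builder; the clip_mlp call in the trailing comment is an assumption about the wiring, not code from the PR.

```python
import torch
from torch import nn


class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU used by the original CLIP weights."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(1.702 * x)


# Because the MLP builder expects an activation *module*, it would be passed as
# an instance, e.g. clip_mlp(..., activation=QuickGELU())  # hypothetical call
```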
@@ -488,6 +489,10 @@ def load_checkpoint(self) -> Dict[str, Any]:
            "supported_aspect_ratios", None
        ),
    )
elif self._model_type == ModelType.CLIP_TEXT:
I'd probably prefer the model type to just be CLIP and the convert function to be clip_hf_to_tune. We don't support the vision model yet but we plan to, and I think it'd be simpler to have just one CLIP version. The clip_hf_to_tune function can raise an error for now if someone attempts to load in the vision model.
how would this work? the clip_hf_to_tune function needs to know whether it's creating torchtune parameters for the text or the vision model, and the only relevant information that the checkpointer has access to is the model_type, right? so the model type needs to specify text or vision?
We'd just have a single convert mapping dictionary and map what happens to be there. I don't think the function actually needs to know which fraction of the weights it's converting.
Every param will be there since HF combines the vision and text params. So the function needs to know what to ignore.
I'd still like the model type to be unified to just CLIP to keep the list shorter. But I'll approve when @RdoubleA signs off on the tokenizer.
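A hedged sketch of the unified-conversion pattern being discussed: map whatever text keys are present and skip the vision half. The HF key prefixes are standard for CLIP checkpoints, but the mapping entries and function name are illustrative, not the PR's actual conversion table.

```python
from typing import Dict

import torch

# Illustrative only: a tiny subset of an HF -> torchtune key mapping.
_CLIP_TEXT_FROM_HF = {
    "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
    "text_model.embeddings.position_embedding.weight": "position_embedding",
}


def clip_hf_to_tune(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    converted = {}
    for key, value in state_dict.items():
        if key.startswith("vision_model."):
            # HF ships vision + text weights together; skip (or raise on) the
            # vision half until torchtune supports it.
            continue
        if key in _CLIP_TEXT_FROM_HF:
            converted[_CLIP_TEXT_FROM_HF[key]] = value
    return converted
```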
    truncate (bool): whether to truncate the text when longer than max_seq_len
"""

def __init__(self, path: PathLike, max_seq_len: int = 77, truncate: bool = True):
nit: does PathLike include both str and pathlib? leaning towards keeping it just str, which is consistent with our other tokenizers
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

merges = []
with open(path, encoding="utf-8") as f:
nit: let's move this logic to a method
    .replace("</w>", " ")
)

def __call__(self, texts: List[str]) -> torch.Tensor:
we should return token ids as List[int] instead of a tensor; the collater usually converts to tensor. Also, the __call__ method would usually be for a single sample (see example: https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama3/_tokenizer.py#L330) instead of a batch of texts. The tokenizer will live in the dataset, which will automatically call it on each sample individually. In this case, maybe the __call__ method would only call encode, or you don't need the __call__ method at all. It depends on whether you'll be using the same SFTDataset or a different text + image dataset abstraction that won't rely on the Message structure.
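Roughly what the suggested single-sample shape could look like, assuming an existing encode method; the method bodies and field names here are placeholders, not the PR's implementation.

```python
from typing import Any, List, Mapping


class CLIPTokenizerCallSketch:
    def encode(self, text: str) -> List[int]:
        # Placeholder body; the real implementation runs BPE and adds the
        # start-of-text / end-of-text ids.
        return [ord(c) for c in text]

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        # Operates on one sample; the dataset calls this per example and the
        # collater converts the List[int] to a tensor later.
        tokens = self.encode(sample["text"])
        return {**sample, "tokens": tokens}
```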
if token in self.cache:
    return self.cache[token]

word = tuple(token[:-1]) + (token[-1] + "</w>",)
maybe </w> should be a class attribute or constant since you are using it everywhere
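For example, something like the following; the constant name is hypothetical.

```python
class CLIPTokenizerConstantSketch:
    WORD_BOUNDARY = "</w>"  # hypothetical name for the end-of-word marker

    def _append_boundary(self, token: str) -> tuple:
        # e.g. "cat" -> ("c", "a", "t</w>")
        return tuple(token[:-1]) + (token[-1] + self.WORD_BOUNDARY,)
```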
    return result

def _bpe(self, token):
    if token in self.cache:
will keeping a cache lead to exploding memory usage as we tokenize more and more in the same script? an alternative could be to use the @lru_cache decorator, but not sure if that will degrade performance too much
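A sketch of that lru_cache alternative, with a stand-in method body; the maxsize value is arbitrary and the instance-method caveat is noted in the comment.

```python
from functools import lru_cache


class BPECacheSketch:
    # Note: decorating an instance method caches on (self, token) and keeps
    # `self` alive for the cache's lifetime; maxsize bounds memory, unlike an
    # unbounded dict cache.
    @lru_cache(maxsize=10_000)
    def _bpe(self, token: str) -> str:
        # Stand-in body; the real method iteratively merges BPE bigrams.
        return token + "</w>"
```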
first, second = bigram
new_word = []
i = 0
while i < len(word):
I know this is straight from the official implementation, but could you add some high level comments to explain what's happening in these loops?
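For example, a commented version of that merge pass, runnable standalone with dummy inputs (the dummy word and bigram are illustrative, not taken from the PR's tests).

```python
# Standalone illustration of one merge pass; in the tokenizer, `word` and
# `bigram` come from the enclosing _bpe loop.
word = ("l", "o", "w", "e", "r</w>")   # current symbols of the token
bigram = ("l", "o")                    # highest-priority pair to merge this pass
first, second = bigram

new_word = []
i = 0
while i < len(word):
    # Copy symbols over unchanged until the next occurrence of `first`.
    try:
        j = word.index(first, i)
    except ValueError:
        new_word.extend(word[i:])
        break
    new_word.extend(word[i:j])
    i = j
    # Merge `first` + `second` into one symbol when they are adjacent;
    # otherwise keep `first` on its own and continue scanning.
    if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
        new_word.append(first + second)
        i += 2
    else:
        new_word.append(word[i])
        i += 1

word = tuple(new_word)  # -> ("lo", "w", "e", "r</w>")
```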
    return dict(zip(bs, cs))


def _get_pairs(word):
type annotations
    return word


def _bytes_to_unicode():
return type annotations
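For the two module-level helpers, the annotated signatures might look like this; the container types are inferred from how the helpers are used, so treat them as a guess.

```python
from typing import Dict, Set, Tuple


def _bytes_to_unicode() -> Dict[int, str]:
    """Map every byte value to a printable unicode character."""
    ...


def _get_pairs(word: Tuple[str, ...]) -> Set[Tuple[str, str]]:
    """Return the set of adjacent symbol pairs in `word`."""
    ...
```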
result[i, : len(tokens)] = torch.tensor(tokens)
return result

def _bpe(self, token):
type annotations
Context
What is the purpose of this PR?
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
Test plan
Minimal code to run the CLIP text encoder e2e (first download the CLIP weights with: tune download openai/clip-vit-large-patch14 --output-dir /tmp/clip-vit-large-patch14 --ignore-patterns None).
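A rough, hypothetical sketch of what such an e2e run might look like; the builder names, vocab path, and checkpoint handling below are assumptions rather than the PR's actual API.

```python
import torch

# Hypothetical entry points; the PR's actual builder names may differ.
from torchtune.models.clip import clip_tokenizer, clip_text_encoder_large

tokenizer = clip_tokenizer("/tmp/clip-vit-large-patch14/merges.txt")  # assumed vocab path
encoder = clip_text_encoder_large()
# (Load the converted CLIP weights into `encoder` via the torchtune
# checkpointer here; omitted for brevity.)

tokens = torch.tensor([tokenizer.encode("a photo of a cat")])
with torch.no_grad():
    out = encoder(tokens)
print(out.shape)
```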
Checked parity with the HF CLIP tokenizer and text encoder as implemented here: MSE between the encoder outputs on a batch of 32 test strings = 3.55e-5.
Tokenization speed for 32 test strings
Encoding speed for a single batch of 32 test strings:
Encoding speed for 1000 batches of 32 test strings:
Checklist
pre-commit install
pytest tests
pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.