Traced models serialization and torchscripting fix #17206
Conversation
The documentation is not available anymore as the PR was closed or merged.
Thanks for the fixes!
Solves our issue, thanks!
There seem to remain a few failures in
@michaelbenayoun I'd like to propose some additional fixes that we discovered were needed to properly trace. Can these be added?
Yes, for some reason the tests do not pass for torch 1.11 (I tested locally on torch 1.10).
Force-pushed from 373feb6 to 9309f27.
@michaelbenayoun Can I propose one final change to switch the graph surgery workaround to only trigger on older PyTorch versions where it's relevant? Otherwise, when we're working on PyTorch nightly, this code breaks because it's trying to remove nodes that still have uses.
@jamesr66a I added the gating, but only from version 1.12 as it was failing otherwise.
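For context, a minimal sketch of what such version gating can look like (not the exact code from this PR; the helper name and the dead-node cleanup are illustrative, and it assumes the packaging library is installed):

import torch
from packaging import version

TORCH_VERSION = version.parse(torch.__version__.split("+")[0])

def maybe_clean_graph(graph_module):
    # On torch >= 1.12 the nodes in question still have users, so the
    # surgery would fail; skip the workaround entirely there.
    if TORCH_VERSION >= version.parse("1.12.0"):
        return graph_module
    # Older versions: drop nodes that ended up without any users.
    for node in reversed(list(graph_module.graph.nodes)):
        if node.op not in ("placeholder", "output") and len(node.users) == 0:
            graph_module.graph.erase_node(node)
    graph_module.recompile()
    return graph_module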
@@ -187,7 +187,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
-            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
Needed for copy consistency.
@@ -211,7 +211,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
         q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
         scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
         mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
-        scores = scores.masked_fill(mask, -float("inf")) # (bs, n_heads, q_length, k_length)
+        scores = scores.masked_fill(mask, torch.tensor(-float("inf"))) # (bs, n_heads, q_length, k_length)
Needed to be able to TorchScript the traced model.
@@ -198,7 +198,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
-            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
@@ -1410,7 +1410,7 @@ def forward(
                     "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                 )
 
-        pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
Same reason.
This should not break things because the tensor should be on the same device as logits anyway, right?
Yes, and it's better not to rely on self.device anyway for model parallelism (I've made a few PRs to hunt down most of those).
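As a tiny illustration of this point (shapes are made up), building the index tensor on logits.device keeps the gather on whatever device the logits ended up on under model parallelism:

import torch

logits = torch.randn(4, 10, 8)                  # (batch_size, seq_len, num_labels)
sequence_lengths = torch.tensor([9, 5, 7, 3])   # index of the last non-padding token per sample
batch_size = logits.shape[0]

# Same pattern as the change above: the arange lives on logits.device,
# not on whatever self.device happens to be.
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([4, 8])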
@@ -188,7 +188,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
         query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
Same as for GPT-2
@@ -147,8 +147,8 @@ def __init__(self, config, attention_type):
         self.register_buffer("bias", bias)
         self.register_buffer("masked_bias", torch.tensor(-1e9))
 
-        self.attn_dropout = nn.Dropout(config.attention_dropout)
-        self.resid_dropout = nn.Dropout(config.resid_dropout)
+        self.attn_dropout = nn.Dropout(float(config.attention_dropout))
TorchScripting fails otherwise; this should not change anything.
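For illustration, a small sketch of the pattern (the value here is hypothetical): a config attribute may not be a plain Python float (e.g. an int parsed from JSON), and casting it before handing it to nn.Dropout keeps the probability a Python float, which avoids surprising TorchScript while being a numerical no-op.

import torch.nn as nn

attention_dropout = 0  # e.g. deserialized from a JSON config as an int

# Casting keeps nn.Dropout's probability a Python float regardless of how
# the config stored it; the dropout behaviour itself is unchanged.
attn_dropout = nn.Dropout(float(attention_dropout))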
@@ -69,7 +69,7 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
 def rotate_every_two(x):
     x1 = x[:, :, :, ::2]
     x2 = x[:, :, :, 1::2]
-    x = torch.stack((-x2, x1), axis=-1)
+    x = torch.stack((-x2, x1), dim=-1)
Needed for TorchScript. This should be fine since torch.stack has accepted the dim argument from the very beginning.
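A quick check of that equivalence (eager mode only; the NumPy-style axis alias is exactly what TorchScript does not understand):

import torch

x1 = torch.randn(2, 3)
x2 = torch.randn(2, 3)

a = torch.stack((-x2, x1), dim=-1)   # canonical keyword, TorchScript-friendly
b = torch.stack((-x2, x1), axis=-1)  # NumPy-compatible alias, eager mode only
assert torch.equal(a, b)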
@@ -163,7 +163,7 @@ def _attn(
 
         # compute causal mask from causal mask buffer
         query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
Same thing as GPT-2.
@@ -971,7 +971,7 @@ def forward(
                     "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                 )
 
-        pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
Same thing as GPT-2.
@@ -126,45 +128,45 @@ def _generate_supported_model_classes(
 )
 
 
-def embedding_override(self, input):
+def torch_nn_embedding(self, input):
Why rename all of those? I liked the fact it was clear they were "fake" functions.
Yes, but that was the case for all of them; I think it is still clear enough since they are mapped in the _MANUAL_META_OVERRIDES dictionary.
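Roughly, the pattern looks like the sketch below (simplified; the real table in transformers.utils.fx contains many more entries): each override is a shape-only stand-in that runs on the meta device, and the dictionary is what marks it as a "fake" implementation.

import torch
import torch.nn as nn

def torch_nn_embedding(self, input):
    # Shape-only stand-in for nn.Embedding.forward used during symbolic tracing:
    # no real lookup happens, only the output shape is produced on the meta device.
    return torch.empty(*input.shape, self.weight.shape[-1], device="meta")

_MANUAL_META_OVERRIDES = {
    nn.Embedding: torch_nn_embedding,
}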
tests/test_modeling_common.py (outdated)
@@ -133,6 +134,7 @@ class ModelTesterMixin:
     all_model_classes = ()
     all_generative_model_classes = ()
     fx_compatible = False
+    fx_trace_can_be_torchscripted = True
No to the new flag though. I'd rather reuse the fx_compatible flag.
It's a separate flag. Basically all models but Swin can be TorchScripted for now, so I use this flag to enable/disable the TorchScript part of the torch_fx test.
Yes, I see it's a new flag. That's what I would like to avoid.
Then what do you suggest?
Not adding a new flag.
This is just for Swin, so you can override the new test in the Swin testing file and skip that part. My concern is that this adds several flags for every contributor of new models to learn, and they won't know the difference between all of them (right now we already have two, for fx and torchscript, which is why I don't want a third).
I don't see how it's bad?
I am testing here that traced models can be TorchScripted; without the flag I cannot really know when this should be tested.
Instead of using the flag, I could hardcode the classes to ignore, but that seems worse IMO.
@michaelbenayoun Unfortunately, bumping the version check up to
@sgugger Replaced the creation of a new flag by setting a special value for the
@michaelbenayoun This doesn't really work either. I'm not trying to be gratuitously painful here, but the common model tester is at the core of our test suite for the new model addition PRs. Those PRs are huge, and it is only thanks to a robust CI that we can make sure the models added actually work with the whole API Transformers offers. Adding a new flag, or a magic value for an existing flag, just because there is one model that needs different testing is not something we usually do or allow. In both cases, either the contributor or the reviewer will have no idea what your new flag/magic value does, especially since there is no documentation of it anywhere. As I said before, in those instances where we need to adapt a common test to a specific model, we override it in the tester of said model. cc @LysandreJik and @patrickvonplaten
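For reference, a hypothetical sketch of that suggestion (class and method names are illustrative, not the exact ones in the repository): the Swin test file overrides the shared test and skips it, keeping the common tester untouched.

import unittest

class SwinModelTest(unittest.TestCase):  # in practice this also inherits ModelTesterMixin
    fx_compatible = True

    @unittest.skip(reason="Traced Swin models cannot be TorchScripted yet")
    def test_torch_fx_torchscript(self):
        pass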
Looks good on the test side now, thanks a lot!
* Fix torch.jit.script and pickling issues
* Fix get_attr issues
* Fix import in function
* Fix GPT-J and T5 tracing for torch=1.11
* Gate graph surgery on torch version
* Modeling minor changes to enable TorchScripting
* Model serialization / deserialization test
* Remove _assert_is_none users
What does this PR do?
Fixes #15974
@jamesr66a Can you try on your end and validate that it solves your issues?
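For anyone wanting to try this locally, here is a rough sketch of the workflow the fixes target (not the exact test added in this PR; it assumes the public transformers.utils.fx.symbolic_trace helper and a small GPT-2 checkpoint):

import torch
from transformers import GPT2LMHeadModel
from transformers.utils.fx import symbolic_trace

model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

# Symbolically trace the model with torch.fx...
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])

# ...then TorchScript the traced GraphModule and round-trip it through
# serialization, which is what this PR is about.
scripted = torch.jit.script(traced)
torch.jit.save(scripted, "gpt2_traced.pt")
restored = torch.jit.load("gpt2_traced.pt")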