
Traced models serialization and torchscripting fix #17206

Merged
16 commits merged into huggingface:main from the fx_issues branch on May 23, 2022

Conversation

michaelbenayoun
Member

@michaelbenayoun michaelbenayoun commented May 12, 2022

What does this PR do?

  • Fixes the issue that was preventing traced models from being TorchScripted
  • Fixes the issue that was preventing traced model serialization
  • Fixes get_attr issues

Fixes #15974
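For reference, a minimal sketch of the workflow this PR is meant to unblock (the checkpoint, input names and pickling round-trip below are illustrative, not taken from the PR):

import pickle
import torch
from transformers import BertForSequenceClassification
from transformers.utils.fx import symbolic_trace

# Trace the model with the HFTracer-based helper.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask", "token_type_ids"])

# TorchScript the traced GraphModule (this used to fail)...
scripted = torch.jit.script(traced)

# ...and serialize / deserialize the traced module (this used to fail as well).
restored = pickle.loads(pickle.dumps(traced))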

@jamesr66a Can you try on your end and validate that it solves your issues?

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented May 12, 2022

The documentation is not available anymore as the PR was closed or merged.

@michaelbenayoun michaelbenayoun requested a review from sgugger May 13, 2022 09:36
Collaborator

@sgugger sgugger left a comment


Thanks for the fixes!

Contributor

@jamesr66a jamesr66a left a comment


Solves our issue, thanks!

@LysandreJik
Member

There seem to remain a few failures in the torch.fx tests; the PR can be merged once those are solved!

@jamesr66a
Contributor

jamesr66a commented May 17, 2022

@michaelbenayoun I'd like to propose some additional fixes that we discovered were needed to properly trace T5ForConditionalGeneration:

jamesr66a@1a75148

Can these be added?

@michaelbenayoun
Member Author

Yes, for some reason the tests do not pass for torch 1.11 (I tested locally on torch 1.10).
I will add those changes too.

@jamesr66a
Contributor

@michaelbenayoun Can I propose one final change to switch the graph surgery workaround to only trigger on older PyTorch versions where it's relevant?

jamesr66a@5ac7bb7

Otherwise, when we're working on PyTorch nightly, the code linked above breaks because it's trying to remove nodes that still have uses.

@michaelbenayoun
Member Author

@jamesr66a I added the gating, but only starting from version 1.12, as it was failing otherwise.
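A minimal sketch of what such version gating could look like (the cut-off version and the surgery loop below are placeholders, not the PR's exact implementation):

import torch
from packaging import version

def maybe_remove_concrete_arg_nodes(graph_module):
    # Only apply the graph-surgery workaround on older torch releases; newer FX
    # versions handle the placeholder nodes created for concrete args themselves.
    if version.parse(torch.__version__) >= version.parse("1.12.0"):
        return graph_module
    for node in list(graph_module.graph.nodes):
        # On old versions these placeholders have no remaining users,
        # so they can be erased safely before recompiling the module.
        if node.op == "placeholder" and len(node.users) == 0:
            graph_module.graph.erase_node(node)
    graph_module.recompile()
    return graph_module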

@@ -187,7 +187,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
Member Author


Needed for copy consistency.

@@ -211,7 +211,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
- scores = scores.masked_fill(mask, -float("inf")) # (bs, n_heads, q_length, k_length)
+ scores = scores.masked_fill(mask, torch.tensor(-float("inf"))) # (bs, n_heads, q_length, k_length)
Member Author


Needed to be able to TorchScript the traced model.

@@ -198,7 +198,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
Member Author


Needed to be able to TorchScript the traced model. This should not break anything because the two are equivalent; from the docs:

self.bool() is equivalent to self.to(torch.bool). See to().
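A quick illustrative check of that equivalence:

import torch

bias = torch.tril(torch.ones(1, 1, 8, 8, dtype=torch.uint8))
# Per the docs, Tensor.bool() and Tensor.to(torch.bool) give the same result.
assert torch.equal(bias.bool(), bias.to(torch.bool))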

@@ -1410,7 +1410,7 @@ def forward(
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)

- pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
Member Author


Same reason.
This should not break things because the tensor should be on the same device as logits anyway, right?

Collaborator


Yes, and it's better not to rely on self.device anyway for model parallelism (I've made a few PRs to hunt down most of those).
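To illustrate the model-parallel concern (a hypothetical layout, not code from the PR): the module's first parameter, which is what self.device usually resolves to, can live on a different device than the logits, so index tensors are best created on the tensor's own device.

import torch

# Hypothetical layout: the lm head (and hence the logits) sits on a second GPU
# while self.device would report the device of the first parameter.
head_device = "cuda:1" if torch.cuda.device_count() > 1 else "cpu"

logits = torch.randn(4, 128, 2, device=head_device)
sequence_lengths = torch.tensor([127, 63, 12, 90], device=head_device)
batch_size = logits.shape[0]

# Building the index on logits.device keeps the advanced indexing on one device.
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]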

@@ -188,7 +188,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))

query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
Member Author


Same as for GPT-2

@@ -147,8 +147,8 @@ def __init__(self, config, attention_type):
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9))

- self.attn_dropout = nn.Dropout(config.attention_dropout)
- self.resid_dropout = nn.Dropout(config.resid_dropout)
+ self.attn_dropout = nn.Dropout(float(config.attention_dropout))
+ self.resid_dropout = nn.Dropout(float(config.resid_dropout))
Member Author


TorchScripting fails otherwise; this should not change anything.

@@ -69,7 +69,7 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
- x = torch.stack((-x2, x1), axis=-1)
+ x = torch.stack((-x2, x1), dim=-1)
Member Author


Needed for TorchScript. This should be fine since torch.stack has accepted the dim argument from the very beginning.

@@ -163,7 +163,7 @@ def _attn(

# compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
Member Author


Same thing as GPT-2.

@@ -971,7 +971,7 @@ def forward(
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)

- pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
Member Author


Same thing as GPT-2.

@michaelbenayoun michaelbenayoun requested a review from sgugger May 18, 2022 10:06

@@ -126,45 +128,45 @@ def _generate_supported_model_classes(
)


- def embedding_override(self, input):
+ def torch_nn_embedding(self, input):
Collaborator


Why rename all of those? I liked the fact it was clear they were "fake" functions.

Member Author


Yes, but that was the case for all of them. I think it is clear enough since they are mapped in the _MANUAL_META_OVERRIDES dictionary.
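For context, a simplified sketch of how that override registry can work (the registration below is illustrative, not the exact code from transformers.utils.fx):

import torch

# Maps a real module type to a "meta" replacement used while tracing,
# so no real computation runs and only shapes are propagated.
_MANUAL_META_OVERRIDES = {}

def torch_nn_embedding(self, input):
    # Shape-only stand-in for nn.Embedding.forward on the meta device.
    return torch.empty(*input.shape, self.weight.shape[-1], device="meta")

_MANUAL_META_OVERRIDES[torch.nn.Embedding] = torch_nn_embedding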

@@ -133,6 +134,7 @@ class ModelTesterMixin:
all_model_classes = ()
all_generative_model_classes = ()
fx_compatible = False
+ fx_trace_can_be_torchscripted = True
Collaborator


No to the new flag though. I'd rather reuse the fx_compatible flag.

Member Author


It's a separate flag. Basically, all models but Swin can be TorchScripted for now, so I use this flag to enable/disable the TorchScript part of the torch_fx test.

Collaborator


Yes, I see it's a new flag. That's what I would like to avoid.

Member Author


Then what do you suggest?

Collaborator


Not adding a new flag.

Collaborator

@sgugger sgugger May 18, 2022


This is just for Swin, so you can override the new test in the Swin testing file and skip that part there. My concern is that this gives several flags to learn for every contributor of new models, and they won't know the difference between all of them (right now we already have two, with fx and torchscript; that's why I don't want a third).
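A minimal sketch of that override approach (class, mixin, and test names below are illustrative, not the actual test suite code):

import unittest

class SwinModelTest(unittest.TestCase):  # in practice this also mixes in ModelTesterMixin
    fx_compatible = True

    # Overriding/skipping only the TorchScript step for this one model keeps the
    # common tester free of extra flags; the test name here is hypothetical.
    @unittest.skip("Traced Swin models cannot be TorchScripted yet")
    def test_torch_fx_torchscript(self):
        pass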

Member Author


I don't see how it's bad?

I am testing here that traced models can be TorchScripted; without the flag, I cannot really know when this should be tested.
Instead of using the flag, I could hardcode the classes to ignore, but that seems worse IMO.

@jamesr66a
Contributor

@michaelbenayoun Unfortunately, bumping the version check up to 1.12 breaks us. Actually, that was indirectly working around a semantic issue with deleting the concrete arg node. Do you mind augmenting the patch with this:

pbelevich@e3fce52

@michaelbenayoun
Member Author

@sgugger I replaced the new flag with a special value (-1) of the existing fx_compatible flag for models that can be traced but not TorchScripted.
This flag should take a boolean value 99% of the time anyway.

@sgugger
Collaborator

sgugger commented May 19, 2022

@michaelbenayoun This doesn't really work either. I'm not trying to be gratuitously painful here, but the common model tester is at the core of our test suite for new model addition PRs. Those PRs are huge, and it is only thanks to a robust CI that we can make sure the models added actually work with the whole API Transformers offers.

Adding a new flag, or a magic value for an existing flag, just because there is one model that needs different testing is not something we usually do or allow. In both cases, either the contributor or the reviewer will have no idea what your new flag/magic value does, especially since there is no documentation of it anywhere.

As I said before, in those instances where we need to adapt a common test to a specific model, we override it in the tester of said model. cc @LysandreJik and @patrickvonplaten

@michaelbenayoun michaelbenayoun requested a review from sgugger May 23, 2022 09:27
Collaborator

@sgugger sgugger left a comment


Looks good on the test side now, thanks a lot!

@michaelbenayoun michaelbenayoun merged commit 2e7e428 into huggingface:main May 23, 2022
@michaelbenayoun michaelbenayoun deleted the fx_issues branch May 23, 2022 15:50
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
* Fix torch.jit.script and pickling issues

* Fix get_attr issues

* Fix import in function

* Fix GPT-J and T5 tracing for torch=1.11

* Gate graph surgery on torch version

* Modeling minor changes to enable TorchScripting

* Model serialization / deserialization test

* Remove _assert_is_none users

Successfully merging this pull request may close these issues.

Models traced with HFTracer cannot be TorchScripted or serialized