
FX tracing improvement #14321

Merged

Conversation

@michaelbenayoun (Member) commented Nov 8, 2021

What does this PR do?

This PR significantly improves the way transformers models are traced by the HFTracer (torch.fx).
This has two major consequences:

  • More model architectures can be supported
  • When a model can be traced, the resulting GraphModule can take any input shape out of the box (previously, a lot of work was needed to enable dynamic axes for a given model), which is both easier and less bug-prone

Because of these changes, the symbolic_trace signature becomes simpler:
symbolic_trace(model: PreTrainedModel, input_names: Optional[List[str]] = None) -> GraphModule

There is no need to specify the batch size, the sequence length, or the number of choices (for multiple choice) anymore.
The same goes for the HFTracer, which can now be instantiated exactly the same way as the regular torch.fx.Tracer.
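A minimal usage sketch of the simplified API (the checkpoint and input names below are illustrative, and the transformers.utils.fx module path is assumed):

from transformers import BertModel, BertTokenizer
from transformers.utils.fx import symbolic_trace

model = BertModel.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# No batch size, sequence length, or number of choices to specify anymore.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])

# The resulting GraphModule accepts any input shape out of the box.
inputs = tokenizer(["Hello world!", "A somewhat longer example sentence."], padding=True, return_tensors="pt")
outputs = traced(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])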

@michaelbenayoun michaelbenayoun marked this pull request as ready for review November 9, 2021 15:20
@michaelbenayoun michaelbenayoun changed the title Fx tracing enhancement FX tracing improvement Nov 9, 2021
@michaelbenayoun michaelbenayoun requested review from sgugger and LysandreJik and removed request for sgugger November 9, 2021 15:21
@LysandreJik (Member) commented:

Hey, thanks for your PR @michaelbenayoun! It seems there are a few failing tests (1096 😄), could you take a look?

@michaelbenayoun (Member, Author) commented:

Currently looking into it!
Sorry about that.

@michaelbenayoun (Member, Author) commented:

Fixed!

@sgugger (Collaborator) left a comment:

I'm not too comfortable with some of the changes in the models, especially XLNet. Apart from that, the PR looks good.

In the tests, fx_ready_model_classes seems to always be set to all_model_classes, so if we always test all classes, maybe it's time to use a boolean flag instead of a list of classes?

Resolved review threads:
  • src/transformers/modeling_utils.py (two threads)
  • src/transformers/models/xlnet/modeling_xlnet.py
  • tests/test_modeling_bart.py
  • tests/test_modeling_layoutlm.py
@thomasw21 (Contributor) left a comment:

Haven't had the time to look in depth. I'll review more when I have more bandwidth.

Resolved review thread: src/transformers/modeling_utils.py
Comment on lines +247 to +246:

seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]

Contributor:

Suggested change:
- seq_ids = torch.arange(seq_length, device=device)
- causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+ causal_mask = torch.tril(torch.ones(batch_size, seq_length, seq_length, dtype=torch.bool, device=device))

Unrelated to this PR, but constructing a triangular matrix should be a bit simpler IMO (unless I'm missing something)...
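A quick check (not part of the PR) that the two constructions agree, using small illustrative sizes:

import torch

batch_size, seq_length, device = 2, 5, "cpu"

# Original construction: broadcasted comparison, True where column j <= row i
seq_ids = torch.arange(seq_length, device=device)
mask_a = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]

# Suggested construction: lower-triangular matrix of boolean ones
mask_b = torch.tril(torch.ones(batch_size, seq_length, seq_length, dtype=torch.bool, device=device))

assert torch.equal(mask_a, mask_b)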

Contributor:

It would be nice to keep the code as is for now, to make sure we don't accidentally break anything here. Could you also run T5's and Bart's SLOW tests to be sure nothing is broken with the attention mask?

Resolved review thread: src/transformers/models/xlnet/modeling_xlnet.py
if self.cache is not None:
    return self.cache == other
elif isinstance(other, HFProxy):
    return True

Contributor:

This shouldn't be true, no?

elif isinstance(self.cache, (torch.Size, list, tuple)):
    return len(self.cache)
else:
    return super().__len__(self)

Contributor:

Suggested change:
- return super().__len__(self)
+ return super().__len__(self.cache)

Shouldn't that be something along these lines?

def __len__(self):
    if self.cache is not None:
        if isinstance(self.cache, int):
            return self.cache

Contributor:

I'm not sure why this is here?
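Putting the fragments above together, a hypothetical, self-contained sketch of a cache-backed __len__ (not the actual transformers implementation):

class CachedProxy:
    """Toy stand-in for a proxy that records a concrete value in `cache`."""

    def __init__(self, cache=None):
        self.cache = cache  # concrete value recorded during a dry run, if any

    def __len__(self):
        # An int cache is taken to be a recorded length itself
        if isinstance(self.cache, int):
            return self.cache
        # Sized containers report their own length
        if isinstance(self.cache, (list, tuple)):
            return len(self.cache)
        raise TypeError("no concrete length was recorded for this proxy")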


def _function_to_leaf(func: Callable[..., Any]) -> Callable[..., Any]:
    """Wrapper that marks func as a leaf function, meaning that it will not be traced through by HFTracer."""

Contributor:

I'm not sure I understand how that does what it says it does?
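For context, torch.fx's own mechanism for this (not the PR's helper): a free function can be kept as a single call_function node, rather than traced through, with torch.fx.wrap:

import torch
import torch.fx

def pad_amount(x):
    # Data-dependent Python arithmetic we don't want traced through
    return x.shape[-1] % 8

# Registers pad_amount (by name, at module scope) as a leaf function
torch.fx.wrap("pad_amount")

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + pad_amount(x)

gm = torch.fx.symbolic_trace(MyModule())
print(gm.graph)  # pad_amount shows up as a single call_function node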

@github-actions (bot) commented Dec 8, 2021:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@michaelbenayoun (Member, Author) commented:

Unstale comment

@michaelbenayoun (Member, Author) commented Jan 10, 2022:

I am planning to try another approach to make both the code simpler and the tracing process cleaner; this will allow other models to be added and limit the number of bugs.
In the meantime, I think this can be merged, because a few issues were posted asking for symbolic_trace to work with PyTorch 1.10, which this PR enables.

@@ -1189,6 +1189,13 @@ def create_new_model_like(
        if "tokenization" not in str(f) and "processor" not in str(f) and "feature_extraction" not in str(f)
    ]

def disable_fx_test(filename: Path):
@michaelbenayoun (Member, Author):

What do you think of this @sgugger?

The reason I added this is that symbolic_trace checks the model class before trying to trace the model, to make sure it is supported.
Because the tests are copied, if a new model is created from a model supported for symbolic tracing, the test file will contain something like fx_ready = True, which will trigger the torch.fx tests, all of them failing because the new model class is not in the list of supported models.
I do not think automatically adding the new model class to the supported models is a good approach, because the model implementation can be changed, so I thought that disabling the test and printing a message was a better option.

Collaborator:

Works for me!

with open(filename) as fp:
    content = fp.read()
with open(filename, "w") as fp:
    new_content = re.sub(r"fx_ready\s*=\s*True", "fx_ready = False", content)
Collaborator:

Nit: this line should go before the second with.

@michaelbenayoun (Member, Author):

Done!
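For reference, a hedged sketch of the helper after the nit is applied (the name and signature come from the diff above; the write-back and exact body are assumptions):

import re
from pathlib import Path

def disable_fx_test(filename: Path):
    with open(filename) as fp:
        content = fp.read()
    # Flip the flag before reopening the file for writing (the nit above)
    new_content = re.sub(r"fx_ready\s*=\s*True", "fx_ready = False", content)
    with open(filename, "w") as fp:
        fp.write(new_content)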

@LysandreJik (Member) left a comment:

This looks good to me as long as it's 100% backwards compatible.

Pinging @patrickvonplaten and @patil-suraj for a quick look, as it touches a lot of different models.

- TORCH_FX_REQUIRED_VERSION = version.parse("1.9")
+ TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
Member:

Out of curiosity, is it possible to support many different versions, or are there breaking changes in torch.fx that mean we have to support one version at a time?

@michaelbenayoun (Member, Author):

I can check for torch 1.9, but the plan from now on is to support torch 1.10+, as fx became stable starting at this version (I still need to validate that with the PyTorch team).

Member:

Sure, sounds good to me.

Follow-up comment:

And you probably need to change this line from == to >=.
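A hedged sketch of the check being discussed (the constant comes from the diff above; the helper name and exact comparison are assumptions):

import torch
from packaging import version

TORCH_FX_REQUIRED_VERSION = version.parse("1.10")

def is_torch_fx_available() -> bool:
    # ">=" rather than "==", so newer torch releases keep fx support enabled
    torch_version = version.parse(torch.__version__.split("+")[0])
    return torch_version >= TORCH_FX_REQUIRED_VERSION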

Comment on lines +1217 to +1220:

print(
    "The tests for symbolic tracing with torch.fx were disabled, you can add those once symbolic tracing works "
    "for your new model."
)
Member:

Ideally this would use the logger.

@michaelbenayoun (Member, Author):

I followed what was done in the script, but I can definitely change that to logger if needed.

- attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
+ attn_weights = attn_weights / (value.size(-1) ** 0.5)
Member:

Is this backwards compatible?

Contributor:

In my opinion, this doesn't cause any problems.

When we do tracing, Python values cause several problems. I don't think there is any reason to convert this value to a Python value.
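To illustrate the tracing concern (a sketch, not from the PR; with plain torch.fx the float() call fails with a TypeError or TraceError depending on the torch version):

import torch
import torch.fx

class ScaledAttn(torch.nn.Module):
    def forward(self, attn_weights, value):
        # float(value.size(-1)) would fail during symbolic tracing, because
        # size(-1) yields a Proxy rather than a concrete int; leaving the
        # value un-cast keeps the whole expression traceable.
        return attn_weights / (value.size(-1) ** 0.5)

gm = torch.fx.symbolic_trace(ScaledAttn())
print(gm.code)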

@JingyaHuang (Contributor) commented Jul 4, 2022:

This change seems to cause the failure of mixed-precision training of GPT-2 with the ONNX Runtime backend. Link to the reported issue: #11279.

@patil-suraj (Contributor) left a comment:

Went through all the modeling changes and it looks good to me!

@@ -1410,7 +1410,7 @@ def forward(
     f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
 )

-pooled_logits = logits[range(batch_size), sequence_lengths]
+pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
Contributor:

Suggested change:
- pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
+ pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]

We need to make sure the tensor is on the same device, no?
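A small illustration (not from the PR; logits and sequence_lengths are toy stand-ins) of why the index tensor's device matters:

import torch

batch_size = 4
logits = torch.randn(batch_size, 10, 3)        # (batch, seq_len, num_labels)
sequence_lengths = torch.tensor([9, 5, 7, 2])  # last non-padding position per row

# torch.arange(batch_size) lives on the CPU by default; creating it on the
# same device as logits keeps the advanced indexing on one device when the
# model runs on a GPU.
index = torch.arange(batch_size, device=logits.device)
pooled_logits = logits[index, sequence_lengths]
print(pooled_logits.shape)  # torch.Size([4, 3])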

@@ -945,7 +945,7 @@ def forward(
     f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
 )

-pooled_logits = logits[range(batch_size), sequence_lengths]
+pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
Contributor:

Same here.

@@ -444,8 +444,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     else ()
 )
 all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
-fx_ready_model_classes = all_model_classes
-fx_dynamic_ready_model_classes = all_model_classes
+fx_ready = True
Contributor:

(nit) Not a huge fan of the name fx_ready; does that mean fx_compatible?

@patrickvonplaten (Contributor) left a comment:

Left some comments, but in general this looks good to me as well.

@michaelbenayoun michaelbenayoun merged commit 0fe17f3 into huggingface:master Feb 7, 2022
stevhliu pushed a commit to stevhliu/transformers that referenced this pull request Feb 18, 2022
* Change the way tracing happens, enabling dynamic axes out of the box

* Update the tests and modeling xlnet

* Add the non-recording of leaf modules, to avoid recording more values for the methods to record than will be seen at tracing time (which would otherwise desynchronize the recorded values from the values that need to be given to the proxies during tracing, causing errors)

* Comments and making tracing work for gpt-j and xlnet

* Refactor things related to num_choices (and batch_size, sequence_length)

* Update fx to work on PyTorch 1.10

* Postpone autowrap_function feature usage for later

* Add copyrights

* Remove unnecessary file

* Fix issue with add_new_model_like

* Apply suggestions
ManuelFay pushed a commit to ManuelFay/transformers that referenced this pull request Mar 31, 2022
9 participants