FX tracing improvement #14321
Conversation
Hey, thanks for your PR @michaelbenayoun! It seems there are a few failing tests (1096 😄), could you take a look at it?
Currently looking into it!
Fixed!
I'm not too comfortable with some of the changes in the models, especially XLNet; apart from that, the PR looks good.
In the tests, `fx_ready_model_classes` seems to always be set to `all_model_classes`, so maybe it's time to use a boolean flag instead of a list of classes, if we always test all classes?
Haven't had the time to look in depth. I'll review more when I have some more bandwidth.
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
Suggested change:
- seq_ids = torch.arange(seq_length, device=device)
- causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+ causal_mask = torch.tril(torch.ones(batch_size, seq_length, seq_length, dtype=torch.bool, device=device))
Unrelated to this PR, but constructing a triangular matrix could be a bit simpler IMO (unless I'm missing something)...
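As a quick sanity check of the reviewer's claim, here is a pure-Python sketch (no torch, so the shapes and dtypes are only mirrored, not real tensors) showing that the broadcast-and-compare construction and a direct lower-triangular (`tril`-style) construction yield the same causal mask:

```python
# Pure-Python stand-in for the two torch constructions discussed above.
def causal_mask_broadcast(batch_size, seq_length):
    seq_ids = list(range(seq_length))
    # mirrors: seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
    return [[[col <= row for col in seq_ids] for row in seq_ids]
            for _ in range(batch_size)]

def causal_mask_tril(batch_size, seq_length):
    # mirrors: torch.tril(torch.ones(batch_size, seq_length, seq_length, dtype=torch.bool))
    return [[[col <= row for col in range(seq_length)] for row in range(seq_length)]
            for _ in range(batch_size)]

assert causal_mask_broadcast(2, 4) == causal_mask_tril(2, 4)
```

Both produce a per-batch lower-triangular boolean mask, which supports the simplification suggestion, though the torch versions may differ in dtype and device handling.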
It would be nice to keep the code as is for now, to make sure we don't accidentally break anything here. Could you also run T5's and Bart's SLOW tests to be sure nothing is broken with the attention mask?
if self.cache is not None:
    return self.cache == other
elif isinstance(other, HFProxy):
    return True
This shouldn't be true, no?
elif isinstance(self.cache, (torch.Size, list, tuple)):
    return len(self.cache)
else:
    return super().__len__(self)
Suggested change:
- return super().__len__(self)
+ return super().__len__(self.cache)
Shouldn't that be something along these lines?
def __len__(self):
    if self.cache is not None:
        if isinstance(self.cache, int):
            return self.cache
I'm not sure why this is here?
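To make the discussion concrete, here is a minimal stand-in for the cache-backed `__len__` logic under review (an assumption-laden simplification, not the actual `HFProxy` implementation): an `int` cache records a concrete length directly, while a sequence cache delegates to `len()`.

```python
# Simplified sketch of the proxy cache behavior discussed in this thread.
class CachedProxy:
    def __init__(self, cache=None):
        # `cache` holds a concrete value recorded ahead of tracing, if any.
        self.cache = cache

    def __len__(self):
        if isinstance(self.cache, int):
            # An int cache is interpreted as the recorded length itself.
            return self.cache
        if isinstance(self.cache, (list, tuple)):
            return len(self.cache)
        raise TypeError("no concrete length recorded for this proxy")

assert len(CachedProxy(3)) == 3
assert len(CachedProxy((10, 20))) == 2
```

This also illustrates the reviewer's point: calling `super().__len__(self)` would ignore the cache entirely, so delegating to the cached value (or raising) seems more consistent.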
def _function_to_leaf(func: Callable[..., Any]) -> Callable[..., Any]:
    """Wrapper that marks func as a leaf function, meaning that it will not be traced through by HFTracer."""
I'm not sure I understand how that does what it says it does?
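One common way such a wrapper can work (a hypothetical sketch, not the PR's actual mechanism, and `LEAF_FUNCTIONS` is an invented name) is to register the wrapped function in a set that the tracer consults, emitting a `call_function` node instead of tracing into the body:

```python
import functools

# Hypothetical registry a tracer could check before tracing into a call.
LEAF_FUNCTIONS = set()

def function_to_leaf(func):
    """Register func so a tracer consulting LEAF_FUNCTIONS would treat it as opaque."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    LEAF_FUNCTIONS.add(wrapper)
    return wrapper

@function_to_leaf
def add_one(x):
    return x + 1

assert add_one in LEAF_FUNCTIONS
assert add_one(2) == 3
```

The wrapper alone does nothing special; the tracer must cooperate by checking the registry, which may be the part the reviewer found unclear in the docstring.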
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Unstale comment
I am planning to try another approach to make both the code simpler and the tracing process cleaner; this will allow adding other models as well as limiting the number of bugs.
@@ -1189,6 +1189,13 @@ def create_new_model_like(
    if "tokenization" not in str(f) and "processor" not in str(f) and "feature_extraction" not in str(f)
]

def disable_fx_test(filename: Path):
What do you think of this @sgugger?
The reason I added that is because `symbolic_trace` checks the model class before trying to trace the model, to make sure it is supported.
Because the tests are copied, if a new model is created from a model supported for symbolic tracing, the test file will contain something like `fx_ready = True`, which will trigger the torch.fx tests, all of them failing because the model class is not in the list of supported models.
I do not think automatically adding the new model class to the supported models is a good approach, because the model implementation can be changed, so I thought that disabling the test and printing some message was a better option.
Works for me!
with open(filename) as fp:
    content = fp.read()
with open(filename, "w") as fp:
    new_content = re.sub(r"fx_ready\s*=\s*True", "fx_ready = False", content)
Nit: this line should go before the second `with`.
Done!
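The substitution in `disable_fx_test` can be exercised in isolation (a sketch with made-up file contents, not the repo's actual test):

```python
import re

# Simulate the copied test file contents that disable_fx_test would rewrite.
content = (
    "class NewModelTest(ModelTesterMixin, unittest.TestCase):\n"
    "    fx_ready = True\n"
)
new_content = re.sub(r"fx_ready\s*=\s*True", "fx_ready = False", content)

assert "fx_ready = False" in new_content
assert "fx_ready = True" not in new_content
```

The `\s*` on both sides of `=` means the pattern also matches spacing variants like `fx_ready=True`, which is what makes the regex robust to formatting differences in copied test files.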
This looks good to me as long as it's 100% backwards compatible.
Pinging @patrickvonplaten and @patil-suraj for a quick look as it touches a lot of different models.
- TORCH_FX_REQUIRED_VERSION = version.parse("1.9")
+ TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
Out of curiosity, is it possible to support many different versions, or are there breaking changes in torch.fx that we have to support one version at a time?
I can check for torch 1.9; the plan from now on is to support torch 1.10+, since fx became stable starting from this version (still need to validate that with the PyTorch team).
Sure, sounds good to me
And you probably need to change this line from `==` to `>=`.
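The reviewer's point about `==` vs `>=` can be illustrated with a toy version gate (a sketch with a naive dotted-integer parser, not `packaging.version`):

```python
# Naive version parsing: "1.10.0" -> (1, 10, 0); tuples compare lexicographically.
def parse(v):
    return tuple(int(p) for p in v.split("."))

TORCH_FX_REQUIRED_VERSION = parse("1.10")

def fx_available(torch_version):
    # ">=" keeps fx support enabled on 1.11, 1.13, 2.x, ...;
    # "==" would silently disable it on every later release.
    return parse(torch_version) >= TORCH_FX_REQUIRED_VERSION

assert fx_available("1.10.0") and fx_available("1.13.1")
assert not fx_available("1.9.0")
```

An exact-equality check against `"1.10"` would reject `"1.10.0"` and all later versions, which is exactly the failure mode the comment is warning about.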
print(
    "The tests for symbolic tracing with torch.fx were disabled, you can add those once symbolic tracing works "
    "for your new model."
)
Ideally this would use the logger
I followed what was done in the script, but can definitely change that to logger if needed.
- attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
+ attn_weights = attn_weights / (value.size(-1) ** 0.5)
Is this backwards compatible?
In my opinion, this doesn't cause any problems.
When we do tracing, Python values cause several problems.
I don't think there is any reason to cast this value to a Python value.
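For a concrete head dimension the two scalings are numerically identical; the difference only matters to the tracer, where `float(...)` bakes a Python constant into the graph while the unconverted form stays symbolic. A minimal numeric check (with a hypothetical head dimension standing in for `value.size(-1)`):

```python
# Hypothetical last dimension of `value`; 64 is a common per-head size.
head_dim = 64

scale_old = float(head_dim) ** 0.5  # old line: explicit Python float
scale_new = head_dim ** 0.5         # new line: no explicit conversion

assert scale_old == scale_new == 8.0
```

So the change is numerically backwards compatible for eager execution; the ONNX Runtime mixed-precision report below suggests the graph-level difference can still matter downstream.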
This change seems to cause the failure of mixed-precision training of GPT-2 with the ONNX Runtime backend. Link to the reported issue: #11279.
Went through all the modeling changes and it looks good to me!
@@ -1410,7 +1410,7 @@ def forward(
        f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
    )

-   pooled_logits = logits[range(batch_size), sequence_lengths]
+   pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
Suggested change:
- pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
+ pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
We need to make sure the tensor is on the same device, no?
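Independent of the device question, the indexing itself selects one row of logits per batch element, at that element's last non-padding position. A pure-Python sketch of the same selection (nested lists standing in for a `[batch, seq, num_labels]` tensor; the values are made up):

```python
# Toy logits of shape [batch_size=2, seq_length=3, num_labels=2].
logits = [
    [[0.1, 0.9], [0.4, 0.6], [0.8, 0.2]],  # batch element 0
    [[0.3, 0.7], [0.5, 0.5], [0.0, 1.0]],  # batch element 1
]
sequence_lengths = [2, 0]  # index of the last real token per batch element

# Mirrors logits[torch.arange(batch_size), sequence_lengths]: pairwise indexing.
pooled_logits = [logits[b][s] for b, s in enumerate(sequence_lengths)]

assert pooled_logits == [[0.8, 0.2], [0.3, 0.7]]
```

In torch, `range(batch_size)` and `torch.arange(batch_size)` index identically here; the switch matters because the latter produces a traceable tensor, at the cost of needing the right `device`.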
@@ -945,7 +945,7 @@ def forward(
        f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
    )

-   pooled_logits = logits[range(batch_size), sequence_lengths]
+   pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
same here
tests/test_modeling_bert.py (Outdated)
@@ -444,8 +444,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
        else ()
    )
    all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
-   fx_ready_model_classes = all_model_classes
-   fx_dynamic_ready_model_classes = all_model_classes
+   fx_ready = True
(nit) Not a huge fan of the name `fx_ready`. Does that mean `fx_compatible`?
Left some comments, but in general this looks good to me as well
* Change the way tracing happens, enabling dynamic axes out of the box
* Update the tests and modeling xlnet
* Add the non-recording of leaf modules to avoid recording more values for the methods to record than what will be seen at tracing time (which would otherwise desynchronize the recorded values and the values that need to be given to the proxies during tracing, causing errors)
* Comments and making tracing work for gpt-j and xlnet
* Refactor things related to num_choices (and batch_size, sequence_length)
* Update fx to work on PyTorch 1.10
* Postpone autowrap_function feature usage for later
* Add copyrights
* Remove unnecessary file
* Fix issue with add_new_model_like
* Apply suggestions
What does this PR do?
This PR significantly improves the way transformers models are traced by the `HFTracer` (`torch.fx`). This has 2 major consequences:
Because of these changes, the `symbolic_trace` signature becomes simpler:
`symbolic_trace(model: PreTrainedModel, input_names: Optional[List[str]] = None) -> GraphModule`
There is no need to specify the batch size, the sequence length, or the number of choices (for multiple choice) anymore.
The same thing can be said about the `HFTracer`, which can be instantiated exactly the same way as the regular `torch.fx.Tracer` can.
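A usage sketch based on the signature quoted above (the import path `transformers.utils.fx` is an assumption about where the helper lives, and the block guards against torch/transformers being absent or too old, so it is illustrative rather than authoritative):

```python
# Hedged sketch: trace a small BERT with the simplified symbolic_trace signature.
try:
    import torch  # noqa: F401 -- torch >= 1.10 is required per the PR discussion
    from transformers import BertConfig, BertModel
    from transformers.utils.fx import symbolic_trace  # import path is an assumption

    model = BertModel(BertConfig(num_hidden_layers=1))
    # No batch size / sequence length / num_choices arguments anymore:
    traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])
    graph_ok = len(list(traced.graph.nodes)) > 0
except Exception:
    # torch/transformers missing or incompatible version; nothing to demonstrate.
    graph_ok = None
```

If tracing succeeds, `traced` is a `GraphModule` whose graph has dynamic batch and sequence axes, which is the headline improvement of this PR.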