Onnx fix test #10663
Conversation
LGTM, thanks for fixing!
```diff
@@ -38,19 +38,23 @@ def forward(self, input_ids, some_other_args, token_type_ids, attention_mask):


 class OnnxExportTestCase(unittest.TestCase):
-    MODEL_TO_TEST = ["bert-base-cased", "gpt2", "roberta-base"]
+    MODEL_TO_TEST = ["bert-base-cased", "gpt2"]
```
Removing `roberta-base` is on purpose here?
Yeah, speeding things up a bit: Roberta and Bert share the exact same graph, so basically it's testing the same things twice.
Merging now to rebase the slow tests and re-run them.
* Allow to pass kwargs to model's from_pretrained when using pipeline.
* Disable the use of past_keys_values for GPT2 when exporting to ONNX.
* style
* Remove comment.
* Appease the documentation gods
* Fix style

Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
The GPT2 `past_keys_values` format seems to have changed since the last time I checked: it now exports, for each layer, a tuple with 2 elements. PyTorch's ONNX exporter doesn't seem to handle this format, so it was crashing with an error.
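For reference, a rough sketch of the nested structure being described, one 2-tuple of key/value tensors per layer; the layer count and tensor shapes below are illustrative assumptions based on GPT2's defaults, not values taken from this PR:

```python
import torch

# Illustrative GPT2-like dimensions (12 layers, 12 heads, head_dim 64).
num_layers, batch, heads, seq_len, head_dim = 12, 1, 12, 5, 64

# One (key, value) pair per layer: a tuple of 2-tuples of tensors.
# torch.onnx.export has to flatten this nesting, which is roughly
# where the crash described above would occur.
past_key_values = tuple(
    (
        torch.zeros(batch, heads, seq_len, head_dim),  # key tensor for this layer
        torch.zeros(batch, heads, seq_len, head_dim),  # value tensor for this layer
    )
    for _ in range(num_layers)
)
```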
The PR assumes we don't currently support exporting `past_keys_values` for GPT2 and disables the return of such values when constructing the model. In order to support this behavior, `pipeline()` now has a `model_kwargs: Dict[str, Any]` parameter which forwards the dict of parameters to the model's `from_pretrained(..., **model_kwargs)`.
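A minimal usage sketch of that forwarding; the task name is arbitrary, and `use_cache` is my assumption for the config flag that controls whether GPT2 returns past key/value states:

```python
from transformers import pipeline

# model_kwargs is forwarded verbatim to the model's from_pretrained call,
# i.e. effectively from_pretrained("gpt2", use_cache=False) here.
# With use_cache=False the model no longer returns past key/value states,
# sidestepping the ONNX exporter crash described above.
nlp = pipeline("feature-extraction", model="gpt2", model_kwargs={"use_cache": False})
features = nlp("Hello world")
```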