GPT2DoubleHeadsModel Unexpected node type: onnx::Sub #4950
Hi @mihail911, thanks for reporting the issue and for the script to reproduce it. I can confirm the issue; it seems to happen on the PyTorch side, and I suspect it's a bug there. @tianleiwu, should we forward the issue to the PyTorch issue tracker? I slightly updated the script to avoid errors:

```python
import torch
from transformers import (GPT2Config, GPT2Model, GPT2Tokenizer, GPT2DoubleHeadsModel)

# use_cache is True by default in GPT2Model. Here we wrap the class to
# disable past state output.
class GPT2DoubleHeadsModelNoPastState(GPT2DoubleHeadsModel):
    def __init__(self, config):
        super().__init__(config)

    def forward(self, input_ids, token_type_ids):
        return super().forward(input_ids, past=None, attention_mask=None,
                               token_type_ids=token_type_ids, use_cache=False)

model_name = "gpt2"
config = GPT2Config.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2DoubleHeadsModelNoPastState.from_pretrained(model_name)

example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt")
del example_inputs["attention_mask"]
example_outputs = model(**example_inputs)

input_names = ['input_ids', 'token_type_ids']
output_names = ['output_1', 'output_2']
dynamic_axes = {
    'input_ids': {0: 'batch_size', 1: 'num_choices', 2: 'seq_len'},
    'token_type_ids': {0: 'batch_size', 1: 'num_choices', 2: 'seq_len'},
    'output_1': {0: 'batch_size', 1: 'num_choices', 2: 'seq_len', 3: 'vocab_size'},
    'output_2': {0: 'batch_size', 1: 'num_choices'},
}

output_path = 'gpt2.onnx'
torch.onnx.export(model=model,
                  args=(example_inputs[input_names[0]].unsqueeze(0),
                        example_inputs[input_names[1]].unsqueeze(0)),
                  f=output_path,
                  input_names=input_names,
                  output_names=output_names,
                  example_outputs=example_outputs,
                  dynamic_axes=dynamic_axes,
                  do_constant_folding=True,
                  opset_version=11,
                  use_external_data_format=False)
```
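As a quick sanity check, the exported file can be run with ONNX Runtime. A minimal sketch, assuming the onnxruntime package is installed and reusing `example_inputs` and `gpt2.onnx` from the script above:

```python
# Minimal sketch: sanity-check the exported model with ONNX Runtime.
# Reuses `example_inputs` and gpt2.onnx from the export script above.
import onnxruntime as ort

session = ort.InferenceSession("gpt2.onnx")
# Match the (batch_size, num_choices, seq_len) layout used at export time.
feeds = {
    "input_ids": example_inputs["input_ids"].unsqueeze(0).numpy(),
    "token_type_ids": example_inputs["token_type_ids"].unsqueeze(0).numpy(),
}
lm_logits, mc_logits = session.run(["output_1", "output_2"], feeds)
print(lm_logits.shape, mc_logits.shape)
```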
@mfuntowicz, I've forwarded the issue to the developers of the PyTorch ONNX exporter. I narrowed the issue down to one line; a workaround is to add int() to cast the data type.

Before:

After:
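Roughly, the pattern is the following; this is a hypothetical illustration with invented variable names, not the actual transformers line:

```python
import torch

hidden_states = torch.rand(2, 5, 8)  # hypothetical stand-in tensor

# Before: under the tracing that torch.onnx.export performs, shape
# arithmetic like this is recorded as a traced value, so full_like's
# fill_value reaches the exporter as a torch._C.Value.
cls_index = torch.full_like(hidden_states[..., :1, :],
                            hidden_states.shape[-2] - 1,
                            dtype=torch.long)

# After: int() forces the fill value to a plain Python scalar, so the
# exporter sees a constant instead.
cls_index = torch.full_like(hidden_states[..., :1, :],
                            int(hidden_states.shape[-2] - 1),
                            dtype=torch.long)
```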
@mihail911, could you try this (you need to install transformers from source) to see whether you can export the model?
Yes, I am able to export the model. Thanks @tianleiwu @mfuntowicz!
Summary: Fix export of full_like when fill_value is of type torch._C.Value. This PR fixes a bug when exporting GPT2DoubleHeadsModel (huggingface/transformers#4950).
Pull Request resolved: pytorch#40063
Reviewed By: hl475
Differential Revision: D22398353
Pulled By: houseroad
fbshipit-source-id: 6980a61211fe571c2e4a57716970f474851d811e
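For background on how a fill value can end up as a torch._C.Value at all, here is a small sketch of the assumed tracing behavior (not taken from the thread):

```python
import torch

def shape_math(x):
    # In eager mode, x.shape[-1] - 1 is a plain Python int.
    return x.shape[-1] - 1

print(shape_math(torch.rand(2, 5)))  # eager: 4

# torch.onnx.export traces the model with torch.jit; under tracing the
# same shape arithmetic is recorded into the graph, so downstream ops
# (such as full_like) receive a traced value instead of a constant.
traced = torch.jit.trace(shape_math, torch.rand(2, 5))
print(traced.graph)
```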
🐛 Bug
Information
Model I am using (Bert, XLNet ...): GPT2DoubleHeadsModel
Language I am using the model on (English, Chinese ...): English
To reproduce
Steps to reproduce the behavior:
I've been following the IPython notebook provided here. I load a GPT model and export it to ONNX format using the following script, which is based off of #4805:
After invoking the above, I get the error:
Expected behavior
I would expect the export to succeed. Unfortunately, I'm not exactly sure how to interpret this error, and there isn't much documentation online.
Environment info
transformers version: commit 0e1869c

Thanks for your help! @mfuntowicz @tianleiwu