GPTDoubleHeadsModel Unexpected node type: onnx::Sub #4950

Closed
2 of 4 tasks
mihail911 opened this issue Jun 12, 2020 · 3 comments

@mihail911

🐛 Bug

Information

Model I am using (Bert, XLNet ...): GPT2DoubleHeadsModel

Language I am using the model on (English, Chinese ...): English

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQuAD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:
I've been following the IPython notebook provided here

  1. Take an off-the-shelf pretrained gpt model and export to onnx format using the following script:
import torch
from transformers import (GPT2Config, GPT2Model, GPT2Tokenizer, GPT2DoubleHeadsModel)

# use_cache is True by default in GPT2Model. Here we wrap a class to disable past state output.
class GPT2DoubleHeadsModelNoPastState(GPT2DoubleHeadsModel):
    def __init__(self, config):
        super().__init__(config)

    def forward(self, input_ids, token_type_ids):
        return super().forward(input_ids, past=None, attention_mask=None, token_type_ids=token_type_ids, use_cache=False)

model_name="gpt2"
config = GPT2Config.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2DoubleHeadsModelNoPastState.from_pretrained(model_name)

example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt")
del example_inputs["attention_mask"]
example_outputs = model(**example_inputs)

input_names = ['input_ids',  'token_type_ids']
output_names=['output_1', 'output_2']
dynamic_axes={'input_ids': {0: 'batch_size', 1: 'num_choices', 2: 'seq_len'}, 
     'token_type_ids': {0: 'batch_size', 1: 'num_choices', 2: 'seq_len'},
     'output_1': {0: 'batch_size', 1: 'num_choices', 2: 'seq_len', 3: 'vocab_size'}, 
     'output_2': {0: 'batch_size', 1: 'num_choices'}
}
output_path='gpt2.onnx'
torch.onnx.export(model=model,
                  args=(example_inputs[input_names[0]].unsqueeze(0), example_inputs[input_names[1]].unsqueeze(0)),
                  f=output_path,
                  input_names=input_names,
                  output_names=output_names,
                  example_outputs=example_outputs,
                  dynamic_axes=dynamic_axes,
                  do_constant_folding=True,
                  opset_version=11,
                  use_external_data_format=False)

This script is based on #4805
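For context, GPT2DoubleHeadsModel expects inputs shaped (batch_size, num_choices, seq_len), which is why the export args call unsqueeze(0) on the (1, seq_len) tensors that encode_plus returns. A minimal shape sketch (plain torch, no model required; the token ids are made up for illustration):

```python
import torch

# encode_plus(..., return_tensors="pt") yields tensors of shape (1, seq_len)
input_ids = torch.tensor([[50256, 318, 257]])  # hypothetical token ids, shape (1, 3)

# unsqueeze(0) prepends the num_choices axis that GPT2DoubleHeadsModel expects,
# giving (batch_size=1, num_choices=1, seq_len=3)
model_input = input_ids.unsqueeze(0)
print(model_input.shape)  # torch.Size([1, 1, 3])
```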

After invoking the above, I get the error:

.../torch/onnx/symbolic_helper.py", line 87...
RuntimeError: Unexpected node type: onnx::Sub

Expected behavior

I would expect the export to succeed. Unfortunately, I'm not sure how to interpret this error, and there isn't much documentation about it online.

Environment info

  • transformers version: Commit 0e1869c
  • Onnxruntime: 1.3.0
  • Python version: 3.6.10
  • PyTorch version (GPU?): 1.5.0+cu101
  • Using GPU in script?: Yes

Thanks for your help! @mfuntowicz @tianleiwu

@mfuntowicz
Member

Hi @mihail911, thanks for reporting the issue and the script to reproduce.

I can confirm the issue. It seems to happen on the PyTorch side, so I suspect it's a bug in their ONNX exporter. @tianleiwu, should we forward the issue to the PyTorch issue tracker?

For reference, here is the corrected script:

import torch
from transformers import (GPT2Config, GPT2Model, GPT2Tokenizer, GPT2DoubleHeadsModel)

# use_cache is True by default in GPT2Model. Here we wrap a class to disable past state output.
class GPT2DoubleHeadsModelNoPastState(GPT2DoubleHeadsModel):
    def __init__(self, config):
        super().__init__(config)

    def forward(self, input_ids, token_type_ids):
        return super().forward(input_ids, past=None, attention_mask=None, token_type_ids=token_type_ids, use_cache=False)

model_name="gpt2"
config = GPT2Config.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2DoubleHeadsModelNoPastState.from_pretrained(model_name)

example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt")
del example_inputs["attention_mask"]
example_outputs = model(**example_inputs)

input_names = ['input_ids',  'token_type_ids']
output_names=['output_1', 'output_2']
dynamic_axes={'input_ids': {0: 'batch_size', 1: 'num_choices', 2: 'seq_len'}, 
     'token_type_ids': {0: 'batch_size', 1: 'num_choices', 2: 'seq_len'},
     'output_1': {0: 'batch_size', 1: 'num_choices', 2: 'seq_len', 3: 'vocab_size'}, 
     'output_2': {0: 'batch_size', 1: 'num_choices'}
}
output_path='gpt2.onnx'
torch.onnx.export(model=model,
                  args=(example_inputs[input_names[0]].unsqueeze(0), example_inputs[input_names[1]].unsqueeze(0)),
                  f=output_path,
                  input_names=input_names,
                  output_names=output_names,
                  example_outputs=example_outputs,
                  dynamic_axes=dynamic_axes,
                  do_constant_folding=True,
                  opset_version=11,
                  use_external_data_format=False)

@tianleiwu
Contributor

tianleiwu commented Jun 13, 2020

@mfuntowicz, I've forwarded the issue to the developers of the PyTorch ONNX exporter.

I narrowed the issue down to one line. A workaround is to cast the shape value with int() so the fill value is a plain Python integer:

Before:

cls_index = torch.full_like(hidden_states[..., :1, :], hidden_states.shape[-2] - 1, dtype=torch.long,)

After:

cls_index = torch.full_like(hidden_states[..., :1, :], int(hidden_states.shape[-2]) - 1, dtype=torch.long,)
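In eager mode the int() cast changes nothing numerically; it matters under tracing, where hidden_states.shape[-2] is recorded as a graph value rather than a constant, so the - 1 becomes an onnx::Sub node and full_like receives a torch._C.Value as its fill value. A minimal sketch of the patched line, with made-up shapes:

```python
import torch

def cls_index_for(hidden_states):
    # int() forces shape[-2] to a Python number, so full_like gets a
    # constant fill value instead of a traced graph value under export.
    return torch.full_like(hidden_states[..., :1, :],
                           int(hidden_states.shape[-2]) - 1,
                           dtype=torch.long)

x = torch.zeros(2, 5, 4)           # (batch, seq_len, hidden) -- illustrative
out = cls_index_for(x)
print(out.shape)                   # torch.Size([2, 1, 4])
print(out[0, 0, 0].item())         # 4  (= seq_len 5 - 1)
```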

@mihail911, could you try this (you'll need to install transformers from source) to see whether you can export the model?

@mihail911
Author

Yes, I am now able to export the model. Thanks @tianleiwu @mfuntowicz!

facebook-github-bot pushed a commit to pytorch/pytorch that referenced this issue Jul 7, 2020
Summary:
Fix export of full_like when fill_value is of type torch._C.Value.

This PR fixes a bug when exporting GPT2DoubleHeadsModel huggingface/transformers#4950

Pull Request resolved: #40063

Reviewed By: hl475

Differential Revision: D22398353

Pulled By: houseroad

fbshipit-source-id: 6980a61211fe571c2e4a57716970f474851d811e
csarofeen pushed a commit to csarofeen/pytorch that referenced this issue Jul 7, 2020