Loss of accuracy when Longformer for SequenceClassification model is exported to ONNX #776

Closed
SteffenHaeussler opened this issue Feb 14, 2023 · 8 comments

SteffenHaeussler commented Feb 14, 2023

Edit: This is a cross-post of pytorch #94810. I don't know where the issue lies.

System info

  • transformers version: 4.26.1
  • Platform: macOS-10.16-x86_64-i386-64bit
  • Python version: 3.9.12
  • PyTorch version (GPU?): 1.13.0 (False)
  • onnx: 1.13.0
  • onnxruntime: 1.13.1

Who can help?

I think @younesbelkada would be a great help :)

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

This model is trained on client data and I'm not allowed to share the data or the weights, which makes reproducing this issue much harder. Please let me know if you need more information.

Here is the code snippet for the ONNX conversion:

I followed this tutorial, but I also tried your tutorial. The ONNX export with optimum is not available for Longformer so far, and I haven't figured out how to add it yet.

conversion:

import numpy as np
from onnxruntime import InferenceSession
from tqdm.auto import tqdm
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("deployment/best_model/")
model = AutoModelForSequenceClassification.from_pretrained("deployment/best_model/")

model.to("cpu")
model.eval()

example_input = tokenizer(
    dataset["test"]["text"][0], max_length=512, truncation=True, return_tensors="pt"
)
_ = model(**example_input)

torch.onnx.export(
    model,
    tuple(example_input.values()),
    f="model.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence"},
        "attention_mask": {0: "batch_size", 1: "sequence"},
        "logits": {0: "batch_size", 1: "sequence"},
    },
    do_constant_folding=True,
    opset_version=16,
)
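
One quick sanity check here is to compare the PyTorch and ONNX logits on the very input used for the export (a rough sketch, assuming the export above wrote model.onnx to the working directory and reusing the tokenizer and model from the snippet):

# Sanity check: compare PyTorch and ONNX logits on the same example input
# that was used for the export.
import numpy as np
import torch
from onnxruntime import InferenceSession

check_session = InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

with torch.no_grad():
    torch_logits = model(**example_input).logits.numpy()

onnx_logits = check_session.run(
    None, input_feed={k: v.numpy() for k, v in example_input.items()}
)[0]

# If this difference is already large on the export input itself, the problem
# is not only a padding/sequence-length mismatch at inference time.
print(np.abs(torch_logits - onnx_logits).max())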

Calculating the accuracy:

session = InferenceSession("deployment/model.onnx", providers=["CPUExecutionProvider"])

y_hat_torch = []
y_hat_onnx = []

for text in dataset["test"]["text"]:
    tok_text = tokenizer(
        text, padding="max_length", max_length=512, truncation=True, return_tensors="np"
    )
    pred = session.run(None, input_feed=dict(tok_text))
    pred = np.argsort(pred[0][0])[::-1][0]
    y_hat_onnx.append(int(pred))

    tok_text = tokenizer(
        text, padding="max_length", max_length=512, truncation=True, return_tensors="pt"
    )
    pred = model(**tok_text)
    pred = torch.argsort(pred[0][0], descending=True)[0].numpy()
    y_hat_torch.append(int(pred))

print(
    f"Accuracy onnx: {sum(int(i) == int(j) for i, j in zip(y_hat_onnx, dataset['test']['label'])) / len(y_hat_onnx):.2f}"
)
print(
    f"Accuracy torch: {sum(int(i) == int(j) for i, j in zip(y_hat_torch, dataset['test']['label'])) / len(y_hat_torch):.2f}"
)

I also looked into the models' weights, and the weights for the attention layer differ between torch and ONNX. Here is an example:

import torch
import onnx
from onnx import numpy_helper

import numpy as np
from numpy.testing import assert_almost_equal

from transformers import AutoTokenizer, AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("deployment/best_model/")
onnx_model = onnx.load("deployment/model.onnx")

graph = onnx_model.graph

initializers = dict()
for init in graph.initializer:
    initializers[init.name] = numpy_helper.to_array(init).astype(np.float16)

model_init = dict()
for name, p in model.named_parameters():
    model_init[name] = p.detach().numpy().astype(np.float16)

assert len(initializers) == len(model_init) # 53 layers

assert_almost_equal(initializers['longformer.embeddings.word_embeddings.weight'],
                    model_init['longformer.embeddings.word_embeddings.weight'], decimal=5)

assert_almost_equal(initializers['classifier.dense.weight'],
                    model_init['classifier.dense.weight'], decimal=5)

For the layer longformer.encoder.layer.0.output.dense.weight, which aligns with onnx::MatMul_6692 in shape and position:

assert_almost_equal(initializers['onnx::MatMul_6692'],
                    model_init['longformer.encoder.layer.0.output.dense.weight'], decimal=4)

I get

AssertionError: 
Arrays are not almost equal to 4 decimals

Mismatched elements: 2356293 / 2359296 (99.9%)
Max absolute difference: 1.776
Max relative difference: inf
 x: array([[ 0.0106,  0.1076,  0.0801, ...,  0.0425,  0.1548,  0.0123],
       [-0.0399, -0.1415,  0.0916, ...,  0.0181, -0.1277, -0.1335],
       [-0.0961,  0.0013,  0.0558, ..., -0.1354, -0.0965,  0.0447],...
 y: array([[-0.0699,  0.0743,  0.0339, ...,  0.0564, -0.087 ,  0.0649],
       [-0.1315, -0.0967, -0.045 , ..., -0.0492,  0.0775,  0.0284],
       [-0.1094,  0.0364,  0.1263, ..., -0.0308, -0.0118,  0.1523],...
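
One caveat with this kind of comparison (an assumption on my side, not verified against this particular graph): torch.onnx.export often stores the weight of an nn.Linear as its transpose feeding a MatMul node, so an initializer can legitimately differ element-wise from the named parameter even when the export is correct. A hedged variant of the check also tries the transpose:

# Hypothetical follow-up check: a Linear weight is often exported as its
# transpose (the second input of MatMul), so compare against W.T as well
# before concluding that the weights themselves were corrupted.
onnx_w = initializers['onnx::MatMul_6692']
torch_w = model_init['longformer.encoder.layer.0.output.dense.weight']

print(onnx_w.shape, torch_w.shape)
try:
    assert_almost_equal(onnx_w, torch_w.T, decimal=4)
    print("initializer matches the transposed torch weight")
except AssertionError:
    print("initializer matches neither the weight nor its transpose")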

Model config:

{
  "_name_or_path": "/datadrive/model/onnx/",
  "architectures": [
    "LongformerForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "attention_window": [
    512,
    512
  ],
  "bos_token_id": 1,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 514,
  "model_type": "longformer",
  "num_attention_heads": 2,
  "num_hidden_layers": 2,
  "onnx_export": false,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "problem_type": "single_label_classification",
  "sep_token_id": 2,
  "torch_dtype": "float32",
  "transformers_version": "4.26.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 32768
}

Expected behavior

I would expect a similar accuracy for both models:

Accuracy onnx: 17 %
Accuracy torch: 70 %

on test data with 3800 samples.

I would like to know what went wrong, how I can fix it, or who can help me. I'm clueless at the moment.
Alternatively, I could move to the BigBird architecture, since it already has some support in optimum.

I trained a small Longformer language model from scratch and fine-tuned it on custom data with a sequence classification head. I used fp16 for training. The training ran on a GPU.

@michaelbenayoun
Member

Hi,
I think @fxmarty knows the origin of this issue; if I recall correctly, it most likely comes from the sequence length used.
This thread might also be helpful.

Also, note that the optimum library is now the main place to look for anything related to ONNX (export and inference).

@younesbelkada transferred this issue from huggingface/transformers Feb 14, 2023
@younesbelkada
Contributor

Great! Transferred the issue to optimum.

@fxmarty
Contributor

fxmarty commented Feb 14, 2023

Thank you. You can refer to the issue linked by @michaelbenayoun; the bottom-line issue is a bug in PyTorch: pytorch/pytorch#90607

As a dirty fix, you can pass the --sequence_length argument in the ONNX export to specify a longer sequence length to use during the export, so that the stride argument of as_strided() is the same during the ONNX export and at inference.

The issue is solely related to sequence length, so I would recommend checking which sequence lengths yield meaningful outputs, as I did here.
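
A rough sketch of that idea applied to the torch.onnx.export snippet from the issue, reusing the tokenizer, model and dataset objects from the conversion snippet and assuming inputs will always be padded to 512 tokens at inference, as in the accuracy loop above:

# Sketch of the workaround: export with a dummy input padded to the same
# sequence length that will be used at inference (here 512), so the strides
# captured from as_strided() match the shapes seen at inference time.
example_input = tokenizer(
    dataset["test"]["text"][0],
    padding="max_length",   # pad the export example exactly like the inference inputs
    max_length=512,
    truncation=True,
    return_tensors="pt",
)

torch.onnx.export(
    model,
    tuple(example_input.values()),
    f="model.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    # no dynamic sequence axis: the exported graph is only expected to be
    # correct for length-512 inputs
    dynamic_axes={
        "input_ids": {0: "batch_size"},
        "attention_mask": {0: "batch_size"},
        "logits": {0: "batch_size"},
    },
    do_constant_folding=True,
    opset_version=16,
)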

@fxmarty
Contributor

fxmarty commented Feb 14, 2023

@SteffenHaeussler Actually, I realize we disabled the ONNX export for longformer due to this bug.

@SteffenHaeussler
Author

SteffenHaeussler commented Feb 15, 2023

@fxmarty Thanks a lot for your help. So the problem lies in torch's ONNX exporter and is out of our hands.

I tried some cheap tricks (torch -> torchscript -> onnx conversion), but obviously they don't work.
I followed the Microsoft onnxruntime tutorial as well, with even less success.
Several other approaches also failed.

For the moment, I copied the trained weights to a BertModel with a similar architecture and fine-tuned the new model on my dataset. It looks promising, and I will share the code snippet when I'm done (a rough sketch of the idea follows below), so others in the same situation can hopefully benefit from it.
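
A rough sketch of what such a weight transfer could look like; the fresh BertConfig and the name mapping below are assumptions, not the actual snippet referred to above:

# Hypothetical sketch: copy the trained Longformer weights into a freshly
# initialized BERT-style classifier before fine-tuning it again. Only tensors
# whose names and shapes line up are copied; everything else keeps its random init.
from transformers import AutoModelForSequenceClassification, BertConfig, BertForSequenceClassification

longformer = AutoModelForSequenceClassification.from_pretrained("deployment/best_model/")
lf_cfg = longformer.config

bert = BertForSequenceClassification(
    BertConfig(
        vocab_size=lf_cfg.vocab_size,
        hidden_size=lf_cfg.hidden_size,
        num_hidden_layers=lf_cfg.num_hidden_layers,
        num_attention_heads=lf_cfg.num_attention_heads,
        max_position_embeddings=lf_cfg.max_position_embeddings,
        num_labels=lf_cfg.num_labels,
        pad_token_id=lf_cfg.pad_token_id,
    )
)

src = longformer.state_dict()
dst = bert.state_dict()

copied = []
for name, tensor in dst.items():
    # map e.g. "bert.encoder.layer.0..." -> "longformer.encoder.layer.0..."
    src_name = name.replace("bert.", "longformer.", 1)
    if src_name in src and src[src_name].shape == tensor.shape:
        dst[name] = src[src_name]
        copied.append(name)

bert.load_state_dict(dst)
print(f"copied {len(copied)} of {len(dst)} tensors")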

I don't see at the moment how I can support you with this bug. Let me know if there is anything to be done.

@fxmarty
Contributor

fxmarty commented Feb 15, 2023

One dirty solution would be to rewrite this operation https://github.com/huggingface/transformers/blob/762dda44deed29baab049aac5324b49f134e7536/src/transformers/models/longformer/modeling_longformer.py#L924 in a way that is handled correctly by the ONNX export. When I had a look at the time, I did not come up with an elegant solution though (i.e. one not using nested loops).
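
For reference, one loop-free candidate would be to build the overlapping chunks with Tensor.unfold instead of as_strided. This is an untested sketch of the operation, not the transformers implementation, and whether the exporter handles unfold correctly for dynamic shapes is exactly the open question:

# Untested sketch of a loop-free alternative to the as_strided() chunking:
# build overlapping chunks of size 2*w with stride w using Tensor.unfold.
import torch

def chunk_overlapping(hidden_states: torch.Tensor, window_overlap: int) -> torch.Tensor:
    # hidden_states: (batch, seq_len, hidden), with seq_len a multiple of window_overlap
    # returns: (batch, seq_len // window_overlap - 1, 2 * window_overlap, hidden)
    chunks = hidden_states.unfold(dimension=1, size=2 * window_overlap, step=window_overlap)
    # unfold puts the window dimension last: (batch, n_chunks, hidden, 2w) -> (batch, n_chunks, 2w, hidden)
    return chunks.transpose(2, 3)

x = torch.randn(1, 8, 6)
print(chunk_overlapping(x, window_overlap=2).shape)  # torch.Size([1, 3, 4, 6]) = (batch, chunks, 2*w, hidden)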

@SteffenHaeussler
Author

Great. 👍

I can't promise anything, since my work schedule is overloaded like everyone else's, but I will take a deeper look at it.

@fxmarty
Contributor

fxmarty commented Feb 17, 2023

I will close this for now, as Longformer is currently not supported in the ONNX export of Optimum for this exact reason.

@fxmarty closed this as completed Feb 17, 2023