ONNX runtime error after export of Deberta v3 SequenceClassification model #18237

Closed · Fixed by #18272
iiLaurens opened this issue Jul 21, 2022 · 1 comment

iiLaurens commented Jul 21, 2022

System Info

  • Transformers: 4.20.1.dev0 (master branch as of 2022-07-21)
  • Platform: Windows-10-10.0.19044-SP0
  • Python version: 3.8.13
  • Huggingface_hub version: 0.8.1
  • PyTorch version (GPU?): 1.12.0+cu113
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: No

The issue occurs both on a Linux notebook with a GPU (Databricks platform) and on Windows without a GPU.

Note that I use the latest development version of transformers, i.e. the current master branch of this repo. This is necessary because the changes to symbolic ops in the DeBERTa V3 model have not yet made it into a stable release.
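
For reference, the development version can be installed straight from the master branch (standard pip command, shown here for completeness):

pip install git+https://github.com/huggingface/transformers.git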

Who can help?

@LysandreJik

Information

  • My own modified scripts

Tasks

  • My own task or dataset (give details below)

Reproduction

I am trying to make an ONNX export of a fine-tuned DeBERTa sequence classification model. Below are the steps to create such a model and export it to ONNX.

  1. First instantiate a DeBERTa sequence classification model. This example will just use the random weights, as there is no need for actual fine-tuning in this minimal example.
  2. Export the model to ONNX.
  3. Run an inference with onnxruntime.
from pathlib import Path

from onnxruntime import InferenceSession
from transformers.models.deberta_v2 import DebertaV2OnnxConfig
from transformers.onnx import export

from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification

# Step 1
model_base = 'microsoft/deberta-v3-xsmall'
config = AutoConfig.from_pretrained(model_base)
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(model_base)

# Step 2
onnx_path = Path("deberta.onnx")
onnx_config = DebertaV2OnnxConfig(config, task="sequence-classification")

# Export with ONNX opset 15
export(tokenizer, model, onnx_config, 15, onnx_path)

# Step 3
session = InferenceSession(onnx_path.as_posix())

inputs = tokenizer("Using DeBERTa with ONNX Runtime!", return_tensors="np", return_token_type_ids=False)
input_feed = {k: v.astype('int64') for k, v in inputs.items()}

outputs = session.run(output_names=['logits'], input_feed=input_feed)

I would expect outputs from the inference session. However, the error I am getting is:

onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running Expand node. Name:'Expand_674' Status Message: invalid expand shape

Expected behavior

Surprisingly, the exported model doesn't seem to work when the sequence length is anything other than 8. For example:

# Anything with a sequence length of 8 runs fine:
inputs = tokenizer(["Using Deberta V3!"], return_tensors="np", return_token_type_ids=False)
inputs1 = {k: v.astype('int64') for k, v in inputs.items()}
outputs = session.run(output_names=['logits'], input_feed=inputs1)

# Anything else doesn't:
inputs = tokenizer(["Using Deberta V3 with ONNX Runtime!"], return_tensors="np", return_token_type_ids=False)
inputs2 = {k: v.astype('int64') for k, v in inputs.items()}
outputs = session.run(output_names=['logits'], input_feed=inputs2)

# Multiples of 8 will also not work:
inputs = tokenizer(["Hello world. This is me. I will crash this model now!"], return_tensors="np", return_token_type_ids=False)
inputs3 = {k: v.astype('int64') for k, v in inputs.items()}
outputs = session.run(output_names=['logits'], input_feed=inputs3)
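
As a quick sanity check (this snippet is an addition for illustration, not one of the runs above), the tokenized sequence lengths can be printed before calling session.run:

# Print the tokenized sequence length of each input feed; per the behaviour
# described above, only the feed whose second dimension is exactly 8 runs
# through the exported model without the Expand error.
for name, feed in [("inputs1", inputs1), ("inputs2", inputs2), ("inputs3", inputs3)]:
    print(name, feed["input_ids"].shape)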

I was wondering if it might have something to do with the dynamic axes. However, when I check the graph inputs, they seem correct:

import onnx
m = onnx.load(str(onnx_path))
print(m.graph.input)
[name: "input_ids"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_param: "batch"
      }
      dim {
        dim_param: "sequence"
      }
    }
  }
}
, name: "attention_mask"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_param: "batch"
      }
      dim {
        dim_param: "sequence"
      }
    }
  }
}
]
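
For completeness, here is an additional check (not part of the original report) that runs the exported file through the ONNX checker and shape inference, which can help spot shapes that were baked in statically during export:

import onnx
from onnx import shape_inference

m = onnx.load(str(onnx_path))
onnx.checker.check_model(m)                  # structural validation of the graph
inferred = shape_inference.infer_shapes(m)   # propagate (symbolic) shapes through the nodes
print(inferred.graph.value_info[:5])         # inspect a few intermediate tensor shapes
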
@iiLaurens added the bug label on Jul 21, 2022
@iiLaurens changed the title from "ONNX runtime error after export of DebertaV2ForSequenceClassification model" to "ONNX runtime error after export of Deberta v3 SequenceClassification model" on Jul 21, 2022
@JingyaHuang (Contributor) commented:

Hi @iiLaurens, thanks for the PR on fixing the export of DeBERTa!

For your use case, another possibility to simplify the code would be to use the optimum library, which is an extension of transformers. You can directly use the ORTModel classes and the pipeline for inference, which are natively integrated with transformers.

Here is a snippet adapted to your case:

from optimum.onnxruntime.modeling_ort import ORTModelForSequenceClassification
from transformers import AutoTokenizer

ort_model = ORTModelForSequenceClassification.from_pretrained(model_id="results", file_name="deberta_v3_seq.onnx")
# Or download directly from the hub once your fix makes its way to the main of transformers
# ort_model = ORTModelForSequenceClassification.from_pretrained('microsoft/deberta-v3-xsmall')
tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-xsmall', use_fast=True)
inputs = tokenizer("Using DeBERTa with ONNX Runtime!", return_tensors="pt", return_token_type_ids=False)
pred = ort_model(**inputs)
>>> pred
SequenceClassifierOutput(loss=None, logits=tensor([[-0.0199,  0.1397]]), hidden_states=None, attentions=None)
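
The ORT models also plug into the standard transformers pipeline; here is a minimal sketch reusing the model and tokenizer loaded above:

from transformers import pipeline

# The ONNX Runtime model is passed to the pipeline just like a regular transformers model
onnx_classifier = pipeline("text-classification", model=ort_model, tokenizer=tokenizer)
print(onnx_classifier("Using DeBERTa with ONNX Runtime!"))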

Besides, you can also leverage other tools in optimum (graph optimization, quantization, ...) to accelerate your inference.
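
As a rough sketch of the dynamic quantization workflow (the class and method names follow the optimum documentation and may differ between optimum versions, so treat this as an approximation rather than a definitive recipe):

from optimum.onnxruntime import ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig

# Dynamic (weight-only) quantization config targeting AVX512-VNNI CPUs
qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
quantizer = ORTQuantizer.from_pretrained("microsoft/deberta-v3-xsmall", feature="sequence-classification")
quantizer.export(
    onnx_model_path="deberta_v3_seq.onnx",
    onnx_quantized_model_output_path="deberta_v3_seq-quantized.onnx",
    quantization_config=qconfig,
)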

Cheers!
