System Info
Transformers: 4.20.1.dev0 (master branch as of 2022-07-21)
Platform: Windows-10-10.0.19044-SP0
Python version: 3.8.13
Huggingface_hub version: 0.8.1
PyTorch version (GPU?): 1.12.0+cu113
Tensorflow version (GPU?): not installed (NA)
Flax version (CPU?/GPU?/TPU?): not installed (NA)
Jax version: not installed
JaxLib version: not installed
Using GPU in script?: No
The issue occurs both on a Linux notebook with a GPU (Databricks platform) and on Windows without a GPU.
Note that I use the latest development version of transformers, i.e. the current master branch of this repo. This is necessary because there are changes to the symbolic ops in the DeBERTa V3 model that have not made it into a stable release yet.

Who can help?
@LysandreJik

Reproduction
I am trying to make an ONNX export of a fine-tuned DeBERTa sequence classification model. Below are the steps to create such a model and export it to ONNX.

First, initialize a DeBERTa sequence classification model. This example just uses random weights, as there is no need for actual fine-tuning in this minimal example.
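The exact code for this step did not survive the page capture, so the following is a minimal sketch of it, assuming the transformers.onnx export API; the opset value and the results/deberta_v3_seq.onnx output path are assumptions inferred from the rest of the thread.

from pathlib import Path
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from transformers.onnx import FeaturesManager, export
import onnxruntime

model_id = "microsoft/deberta-v3-xsmall"
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Randomly initialized weights; no fine-tuning needed for this minimal example
model = AutoModelForSequenceClassification.from_config(config)

# Resolve the ONNX config for the sequence-classification feature and export
_, onnx_config_cls = FeaturesManager.check_supported_model_or_raise(model, feature="sequence-classification")
onnx_config = onnx_config_cls(model.config)
output_path = Path("results/deberta_v3_seq.onnx")
output_path.parent.mkdir(exist_ok=True)
export(preprocessor=tokenizer, model=model, config=onnx_config, opset=13, output=output_path)

# Load the exported model with ONNX Runtime for inference
session = onnxruntime.InferenceSession(str(output_path))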
I would expect outputs from the inference model. However, the error I am getting is:
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running Expand node. Name:'Expand_674' Status Message: invalid expand shape
Expected behavior
Surprisingly, this model doesn't seem to work when the sequence length is anything other than 8. For example:
# Anything with a sequence length of 8 runs fine:
inputs = tokenizer(["Using Deberta V3!"], return_tensors="np", return_token_type_ids=False)
inputs1 = {k: v.astype('int64') for k, v in inputs.items()}
outputs = session.run(output_names=['logits'], input_feed=inputs1)

# Anything else doesn't:
inputs = tokenizer(["Using Deberta V3 with ONNX Runtime!"], return_tensors="np", return_token_type_ids=False)
inputs2 = {k: v.astype('int64') for k, v in inputs.items()}
outputs = session.run(output_names=['logits'], input_feed=inputs2)

# Multiples of 8 will also not work:
inputs = tokenizer(["Hello world. This is me. I will crash this model now!"], return_tensors="np", return_token_type_ids=False)
inputs3 = {k: v.astype('int64') for k, v in inputs.items()}
outputs = session.run(output_names=['logits'], input_feed=inputs3)
I was wondering whether it has anything to do with the dynamic axes. However, when I check the graph, it seems correct:
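For completeness, here is a sketch of how the dynamic axes can be inspected with the onnx package; this is an assumed reconstruction, not the exact check from the original report.

import onnx

onnx_model = onnx.load("results/deberta_v3_seq.onnx")
for graph_input in onnx_model.graph.input:
    dims = [d.dim_param or d.dim_value for d in graph_input.type.tensor_type.shape.dim]
    print(graph_input.name, dims)
# A correct export should show symbolic (dynamic) batch and sequence
# dimensions for input_ids and attention_mask, e.g. ['batch', 'sequence']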
Hi @iiLaurens, thanks for the PR on fixing the export of DeBERTa!
In terms of your use case, another possibility to simplify all the code would be to use the optimum library, which is an extension of transformers. You can directly use the ORTModel classes and pipelines for inference, which are natively integrated with transformers.
Here is a snippet adapted to your case:
from optimum.onnxruntime.modeling_ort import ORTModelForSequenceClassification
from transformers import AutoTokenizer

ort_model = ORTModelForSequenceClassification.from_pretrained(model_id="results", file_name="deberta_v3_seq.onnx")
# Or download directly from the hub once your fix makes its way to the main of transformers
# ort_model = ORTModelForSequenceClassification.from_pretrained('microsoft/deberta-v3-xsmall')
tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-xsmall', use_fast=True)

inputs = tokenizer("Using DeBERTa with ONNX Runtime!", return_tensors="pt", return_token_type_ids=False)
pred = ort_model(**inputs)
>>> pred
SequenceClassifierOutput(loss=None, logits=tensor([[-0.0199, 0.1397]]), hidden_states=None, attentions=None)
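The ORTModel can also be dropped into a standard transformers pipeline. A short sketch (the exact pipeline integration depends on your optimum version, so treat this as an assumption):

from transformers import pipeline

# ORTModels are designed to plug into transformers pipelines directly
onnx_classifier = pipeline("text-classification", model=ort_model, tokenizer=tokenizer)
print(onnx_classifier("Using DeBERTa with ONNX Runtime!"))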
Besides, you can also leverage other tools in optimum (graph optimization, quantization, ...) to accelerate your inference.
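As an illustration, dynamic quantization with optimum could look roughly like this (a sketch only; the ORTQuantizer API has changed across optimum releases, so the exact calls below are an assumption for a recent version):

from optimum.onnxruntime import ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig

# Dynamic (weight-only) int8 quantization targeting AVX512-VNNI CPUs
qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
quantizer = ORTQuantizer.from_pretrained(ort_model)
quantizer.quantize(save_dir="results/quantized", quantization_config=qconfig)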