
Add ORTModelForVision2Seq for VisionEncoderDecoder models inference #742

Merged

Conversation

Contributor

@mht-sharma mht-sharma commented Feb 3, 2023

What does this PR do?

This PR enables inference of VisionEncoderDecoder models with ONNX Runtime. It adds ORTModelForVision2Seq, which can be used like AutoModelForVision2Seq by changing just a few lines.

Usage

>>> from PIL import Image
>>> from transformers import GPT2TokenizerFast, ViTImageProcessor
->>> from transformers import AutoModelForVision2Seq
+>>> from optimum.onnxruntime import ORTModelForVision2Seq
>>> import requests

>>> model_name = "nlpconnect/vit-gpt2-image-captioning"
->>> model = AutoModelForVision2Seq.from_pretrained(model_name)
+>>> model = ORTModelForVision2Seq.from_pretrained(model_name, from_transformers=True)
>>> tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
>>> image_processor = ViTImageProcessor.from_pretrained(model_name)

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> pixel_values = image_processor(image, return_tensors="pt").pixel_values

>>> generated_ids = model.generate(pixel_values)
>>> generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> print(generated_text)
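As a possible follow-up (a minimal sketch; the save path is illustrative): ORTModel classes expose save_pretrained, so the exported ONNX model can be written to disk once and reloaded later without re-exporting from transformers.

>>> model.save_pretrained("vit-gpt2-image-captioning-onnx")  # illustrative path
>>> model = ORTModelForVision2Seq.from_pretrained("vit-gpt2-image-captioning-onnx")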

Limitations

  • Donut model is not supported.
  • TrOCR model is not supported with use_cache=True.
  • IOBinding is not supported.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?


HuggingFaceDocBuilderDev commented Feb 3, 2023

The documentation is not available anymore as the PR was closed or merged.

@mht-sharma mht-sharma marked this pull request as ready for review February 3, 2023 18:57
@mht-sharma mht-sharma changed the title add ORTModelForVision2Seq Add ORTModelForVision2Seq for VisionEncoderDecoder models inference Feb 3, 2023
Contributor
@fxmarty fxmarty left a comment

LGTM as long as the tests pass, thank you for the addition! Could you also make sure that the code snippet in the documentation is valid?

optimum/onnxruntime/modeling_seq2seq.py (4 resolved review threads)
Indices of decoder input sequence tokens in the vocabulary of shape `(batch_size, decoder_sequence_length)`.
encoder_outputs (`torch.FloatTensor`):
The encoder `last_hidden_state` of shape `(batch_size, encoder_sequence_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor), *optional*)`
Contributor

Suggested change
-past_key_values (`tuple(tuple(torch.FloatTensor), *optional*)`
+past_key_values (`tuple(tuple(torch.FloatTensor))`

We should not put *optional* in Optimum following this discussion: https://huggingface.slack.com/archives/C02P0559X9S/p1669628754896019

Could you rather specify what the default is? Also, for the type: `Tuple[Tuple[torch.FloatTensor]]`

(edit: this is copy-pasted, so could you edit it in the rest as well?)

Contributor Author

Added the default of `None`, as in the function signature.
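For illustration, a minimal sketch of how the entry could then read (assuming the default is `None`; the wording is illustrative, not the exact text merged in the PR):

past_key_values (`Tuple[Tuple[torch.FloatTensor]]`, defaults to `None`):
    Precomputed key and value hidden states of the attention blocks, used to
    speed up sequential decoding.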

Comment on lines +2763 to +2766
def exclude_trocr_with_cache(params):
    # Skip the unsupported (TrOCR, use_cache=True) combination by returning None.
    if params[0] == "trocr" and params[1] == True:
        return None
    return params
Contributor

Just for my knowledge, why is this not supported? What will happen if a user tries to use ORTModelForVision2Seq with TrOCR with use_cache=True?

Contributor Author

The model currently does not output any past_key_values when use_cache=True, so I need to figure out why. It is probably something wrong in the modeling code.

| What will happen if a user tries to use ORTModelForVision2Seq with TrOCR with use_cache=True?

For this I have added an error message during export.
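For illustration, a minimal sketch of how such a filter drops the unsupported combination from a test parameter grid (the grid and the filtering loop here are illustrative, not the PR's actual test harness):

grid = [("vit", False), ("vit", True), ("trocr", False), ("trocr", True)]
kept = [p for p in map(exclude_trocr_with_cache, grid) if p is not None]
# kept == [("vit", False), ("vit", True), ("trocr", False)]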

optimum/utils/normalized_config.py (resolved review thread)
Member
@michaelbenayoun michaelbenayoun left a comment

LGTM!

What's the status for the past key / values?

optimum/onnxruntime/base.py (outdated, resolved review thread)
@mht-sharma mht-sharma force-pushed the add_ort_support_vision_encoder_decoder branch from 61dd641 to 994a61e on February 6, 2023 13:54
@mht-sharma (Contributor Author)

| LGTM!
|
| What's the status for the past key / values?

Past key values are unsupported only for TrOCR; I would look into this after Donut.

@mht-sharma mht-sharma merged commit 69764f1 into huggingface:main Feb 7, 2023