Add ORTModelForVision2Seq for VisionEncoderDecoder models inference #742
Conversation
The documentation is not available anymore as the PR was closed or merged.
LGTM as long as the tests pass, thank you for the addition! Could you also make sure that the code snippet in the documentation is valid?
Indices of decoder input sequence tokens in the vocabulary of shape `(batch_size, decoder_sequence_length)`.
encoder_outputs (`torch.FloatTensor`):
    The encoder `last_hidden_state` of shape `(batch_size, encoder_sequence_length, hidden_size)`.
past_key_values (`tuple(tuple(torch.FloatTensor), *optional*)`
past_key_values (`tuple(tuple(torch.FloatTensor), *optional*)`
We should not put *optional* in Optimum, following this discussion: https://huggingface.slack.com/archives/C02P0559X9S/p1669628754896019
Could you rather specify what the default is? Also, for the type: `Tuple[Tuple[torch.FloatTensor]]`.
(Edit: this is copy-pasted, so could you fix it in the rest as well?)
Added the default as `None`, matching the function signature.
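The resulting docstring entry presumably ends up along these lines (a sketch inferred from the discussion above, not the exact merged text):

```
past_key_values (`Tuple[Tuple[torch.FloatTensor]]`, defaults to `None`):
    Contains precomputed key and value hidden states of the attention blocks, used to speed up decoding.
```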
```python
def exclude_trocr_with_cache(params):
    if params[0] == "trocr" and params[1] == True:
        return None
    return params
```
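For context, a filter like this is typically applied when building parameterized test combinations; a minimal sketch (the architecture grid and use of `itertools.product` are illustrative, not from this PR):

```python
from itertools import product

def exclude_trocr_with_cache(params):
    # Drop the unsupported combination: TrOCR exported with use_cache=True
    if params[0] == "trocr" and params[1] == True:
        return None
    return params

# Build (model_arch, use_cache) test combinations, filtering out unsupported ones
grid = product(["trocr", "vision-encoder-decoder"], [True, False])
supported = [p for p in map(exclude_trocr_with_cache, grid) if p is not None]
print(supported)  # ("trocr", True) is filtered out
```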
Just for my knowledge, why is this not supported? What will happen if a user tries to use `ORTModelForVision2Seq` with `trocr` with `use_cache=True`?
The model currently does not output any `past_key_values` when `use_cache=True`, so I need to figure out why. Probably something is wrong in the modeling code.

> What will happen if a user tries to use ORTModelForVision2Seq with trocr with use_cache=True?

For this case I have added an error message during export.
LGTM!
What's the status for the past key / values?
(force-pushed 61dd641 to 994a61e)
Past key values is unsupported only for TrOCR; I will look into this after Donut.
What does this PR do?
This PR enables inference of `VisionEncoderDecoder` models using ONNX Runtime. It adds `ORTModelForVision2Seq` for doing inference similar to `AutoModelForVision2Seq` by changing just a few lines.
Usage
Limitations
- `Donut` model not supported.
- `TrOCR` model not supported with `use_cache=True`.
Before submitting