
Add ORTModel support for custom tasks #303

Merged (69 commits) on Aug 3, 2022
Conversation

@JingyaHuang (Collaborator) commented Jul 18, 2022

What does this PR do?

ORTModelForXXX.model determines the valid inputs and outputs of the model's forward method, so the creation of inputs in forward can be abstracted, as can the outputs. This makes the ORTModels more flexible.

E.g., in ORTTrainer, evaluation includes labels as an input and loss as an output. This PR will let us replace bare inference sessions with ORTModels more easily.
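The dynamic-forward idea described above can be sketched as follows. This is a minimal illustration, not the PR's actual implementation: `FakeSession` is a hypothetical stand-in for an `onnxruntime.InferenceSession`, and the class name is illustrative. The key point is that the forward method filters its keyword arguments against the input names the ONNX graph declares and returns whatever outputs the graph exposes.

```python
import numpy as np

class FakeSession:
    """Hypothetical stand-in for onnxruntime.InferenceSession (for illustration only)."""
    class _IO:
        def __init__(self, name):
            self.name = name

    def get_inputs(self):
        return [self._IO("input_ids"), self._IO("attention_mask")]

    def get_outputs(self):
        return [self._IO("logits")]

    def run(self, output_names, input_feed):
        # Pretend the model sums its inputs elementwise to produce "logits".
        return [sum(input_feed.values())]

class CustomTaskModelSketch:
    """Forward pass driven entirely by the session's declared input/output names."""
    def __init__(self, session):
        self.session = session
        self.input_names = {i.name for i in session.get_inputs()}
        self.output_names = [o.name for o in session.get_outputs()]

    def forward(self, **kwargs):
        # Keep only the inputs the ONNX graph actually declares; extras are dropped.
        feed = {k: np.asarray(v) for k, v in kwargs.items() if k in self.input_names}
        outputs = self.session.run(self.output_names, feed)
        return dict(zip(self.output_names, outputs))

model = CustomTaskModelSketch(FakeSession())
# "labels" is silently ignored here because this fake graph does not declare it;
# a graph exported with labels/loss would accept and return them the same way.
out = model.forward(input_ids=[1, 2], attention_mask=[1, 1], labels=[0, 1])
print(out["logits"])
```

Because nothing is hard-coded, the same class works whether the exported graph takes two inputs and returns logits, or also takes labels and returns a loss, at the cost of a little dynamic bookkeeping per call.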

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@HuggingFaceDocBuilderDev commented Jul 18, 2022

The documentation is not available anymore as the PR was closed or merged.

@regisss (Contributor) left a comment

Thanks for this @JingyaHuang!
Just one nit.

(Six review comments on optimum/onnxruntime/modeling_ort.py, all resolved)
@philschmid (Member) left a comment

I am not sure we should make this change to enable something the classes aren't designed for.
You said

e.g. In ORTTrainer, the evaluation includes labels as input and loss as output. With the PR, it will enable us to replace bare inference sessions with ORTModels more easily.

This sounds more like we should use evaluate and pipeline, or use the ORTModel in the trainer with the post-processing outside of the model.

(Two review comments on optimum/onnxruntime/modeling_ort.py, both resolved)
@philschmid (Member) commented Jul 18, 2022

I still do not understand the purpose of this change.

As mentioned before, the ORTModelForXX classes were introduced for inference, not training. The idea was to have API-compatible model classes that can be used with pipelines without the need to re-write the pre- & post-processing.
Also, the idea is to add those inference model classes to other packages, e.g. optimum-intel.
Additionally, we are looking into removing the copying and adding IOBindings to reduce latency in the future.

The changes you suggest:

  • slow down inference
  • add a lot of complex dynamic code, which we tried to avoid (that's why we have several ORTModelForXX classes rather than one)
  • add support for something specific to ORTTrainer training

The question I have is:

  • Is this change needed for the ORTTrainer? How are we currently doing it?
  • What is the benefit for the customer?

@philschmid (Member) left a comment

Can you add a test? Then we should be good. Good idea! ✅ And if a use case emerges out of it, we can add new task-specific model classes.

(Review comment on optimum/onnxruntime/modeling_ort.py, resolved)
@JingyaHuang (Collaborator, Author) replied:

> (quoting @philschmid's comment above)

Hi @philschmid, sorry for the late reply. I re-drafted the code; indeed, this shouldn't go into the other task-specific models, as it would slow them down. The basic idea behind the PR is to leave some flexibility to users. It's a fallback: when they use a more customized model, they can still benefit from the ORTModel foundation, with a small sacrifice in speed.
ORTTrainer.train() is independent of this PR, but for ORTTrainer inference, evaluate() and predict(), I am currently using InferenceSession directly and will replace it with ORTModel. For predict() it is pretty straightforward, but for evaluate() the model includes the loss, so I need something more customized, and a class like ORTModelForCustomTasks will be helpful.

@JingyaHuang JingyaHuang changed the title Refactoring ort model inputs and outputs Add ORTModel support for custom tasks Jul 22, 2022
@JingyaHuang JingyaHuang changed the base branch from main to doc-builder-habana-test August 1, 2022 20:05
@JingyaHuang JingyaHuang changed the base branch from doc-builder-habana-test to main August 1, 2022 20:06
@JingyaHuang JingyaHuang merged commit d3c0b75 into main Aug 3, 2022
@JingyaHuang JingyaHuang deleted the jingya-refactoring-ort-model branch August 3, 2022 09:22