Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ORTTrainer failure on DeBERTa(base/v2/sew_d) fp16 training #18529

Closed
wants to merge 82 commits into from

Conversation

JingyaHuang
Copy link
Contributor

What does this PR do?

Context

It was reported in optimum huggingface/optimum#305 that the training on DeBERTa with optimum.onnxruntime.ORTTrainer is broken.
After investigation, the break comes from two causes:

However with those two fixes, the fp32 training will work, but the mixed-precision training will fail due to mismatched inputs dtype for some Matmul nodes. In #18272, some sqrt results are cast to fp32, and they need to be re-casted to fp16 before Matmul ops, and this PR is supposed to add the re-cast part.

Fixes #huggingface/optimum#305

Who can review?

@LysandreJik @patrickvonplaten @lewtun

pocca2048 and others added 15 commits August 11, 2022 14:31
* fix typos

* fix sequence_length docs of LayoutLMv3Model

* delete trailing white spaces

* fix layoutlmv3 docs more

* apply make fixup & quality

* change to two versions of input docstring

* apply make fixup & quality
…upport Opacus training (huggingface#18486)

* changing BartLearnedPositionalEmbedding forward signature and references to it

* removing debugging dead code (thanks style checker)

* blackened modeling_bart file

* removing copy inconsistencies via make fix-copies

* changing references to copied signatures in Bart variants

* make fix-copies once more

* using expand over repeat (thanks @michaelbenayoun)

* expand instead of repeat for all model copies

Co-authored-by: Daniel Jones <jonesdaniel@microsoft.com>
* Create _config.py

* Create _toctree.yml

* Create index.mdx

not sure about "du / ihr" oder "sie"

* Create quicktour.mdx

* Update _toctree.yml

* Update build_documentation.yml

* Update build_pr_documentation.yml

* fix build

* Update index.mdx

* Update quicktour.mdx

* Create installation.mdx

* Update _toctree.yml
…face#18272)

* Fix critical trace warnings to allow ONNX export

* Force input to `sqrt` to be float type

* Cleanup code

* Remove unused import statement

* Update model sew

* Small refactor

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Use broadcasting instead of repeat

* Implement suggestion

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Match deberta v2 changes in sew_d

* Improve code quality

* Update code quality

* Consistency of small refactor

* Match changes in sew_d

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
@JingyaHuang JingyaHuang changed the base branch from main to albertvillanova-patch-1 August 11, 2022 15:00
@JingyaHuang JingyaHuang changed the base branch from albertvillanova-patch-1 to main August 11, 2022 15:00
@JingyaHuang JingyaHuang changed the base branch from main to albertvillanova-patch-1 August 11, 2022 15:15
@JingyaHuang JingyaHuang changed the base branch from albertvillanova-patch-1 to main August 11, 2022 15:15
@JingyaHuang
Copy link
Contributor Author

close as it turned to be too messy even after rebasing.

@JingyaHuang JingyaHuang deleted the fix-deberta-tracing branch August 22, 2022 08:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.