Mingyuanm/add back fp8 support to sd (#9070)

* update branch

Signed-off-by: eharper <eharper@nvidia.com>

* Add dist ckpt support for regular optimizers (#7749)

* Add dist ckpt support for regular optimizers

Signed-off-by: Mikołaj Błaż <mblaz@nvidia.com>

* [tutorial] fixed missing RIR scripts file. (#8257)

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* fix imports

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* imports fix

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* ci imports fix

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert asr notebook

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* revert asr notebook

Signed-off-by: dimapihtar <dpihtar@gmail.com>

---------

Signed-off-by: Mikołaj Błaż <mblaz@nvidia.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: dimapihtar <dpihtar@gmail.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com>
Co-authored-by: dimapihtar <dpihtar@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Pin lhotse=1.19.2 in r1.23.0 (#8303)

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Cache Aware Streaming tutorial notebook (#8296)

* add notebook

Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>

* rename old notebook to Buffered_Streaming

Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>

* call setup_streaming_params in set_default_att_context_size method

Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>

* update links in docs

Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>

* update links to tutorials in docs

Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>

* remove hard-coding

Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>

* rename var

Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>

---------

Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>

* fix path location and branch (#8304)

* fix path location and branch

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* change to a floating point number

Signed-off-by: Nithin Rao Koluguri <nithinraok>

---------

Signed-off-by: Nithin Rao Koluguri <nithinraok>
Co-authored-by: Nithin Rao Koluguri <nithinraok>
Co-authored-by: Somshubra Majumdar <titu1994@gmail.com>

* add deallocate pipeline output optimization (#8279)

* add deallocate pipeline output optimization

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
Co-authored-by: Jimmy Zhang <jiemingz@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Fix memory leak caused by context parallelism hanging references by omegaconf (#8299)

* save cp_size to self

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

* use parallel_state instead of self

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>

---------

Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
Co-authored-by: Jimmy Zhang <jiemingz@nvidia.com>
Co-authored-by: Eric Harper <complex451@gmail.com>

* remove assertion (#8302)

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* Update PEFT Doc (#8262)

* update peft doc

Signed-off-by: Chen Cui <chcui@nvidia.com>

* remove old prompt learning doc and notebook

Signed-off-by: Chen Cui <chcui@nvidia.com>

* fix table

Signed-off-by: Chen Cui <chcui@nvidia.com>

* fix table

Signed-off-by: Chen Cui <chcui@nvidia.com>

* fix table

Signed-off-by: Chen Cui <chcui@nvidia.com>

* Merge branch 'r1.23.0' into chcui/update_peft_doc

Signed-off-by: Chen Cui <chcui@nvidia.com>

* revert accidental changes

Signed-off-by: Chen Cui <chcui@nvidia.com>

* revert accidental changes

Signed-off-by: Chen Cui <chcui@nvidia.com>

---------

Signed-off-by: Chen Cui <chcui@nvidia.com>

* Attention encoder-decoder models for multiple speech-to-text tasks  (#8242) (#8324)

* Rebasing canary changes at current main

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Move the changes from asr transformer to nlp transformer as originally intended

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* update eval to strip spaces before punctuations

Signed-off-by: stevehuang52 <heh@nvidia.com>

* update pc strip

Signed-off-by: stevehuang52 <heh@nvidia.com>

* [canary] Refactor: `PromptedAudioToTextLhotseDataset` and `EncDecMultiTaskModel` (#8247)

* Create a separate CanaryDataset and use it inside `transformer_bpe_models.py`. Ditches `token_sequence_format`.

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* [canary] Refactor: move changes in transformer_bpe_models.py to Canar… (#8252)

* [canary] Refactor: move changes in transformer_bpe_models.py to CanaryModel

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Rename `CanaryModel` to `EncDecMultiTaskModel` and remove inheritance from `EncDecTransfModelBPE`; add a separate config for this model

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

---------

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Rename `CanaryDataset` to `PromptedAudioToTextLhotseDataset`; add `prompt_format_fn` argument; clean-up the `_canary_prompt_format` function a bit

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Move tokenization into `prompt_format_fn`, fix usage, add docs

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Backward-compatible utterance validation

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Improve type annotations

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* config and prompt_fn registration changes from review

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

---------

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* fix transcribe config

Signed-off-by: stevehuang52 <heh@nvidia.com>

* Refactor Canary to follow schema of remaining ASR models (#8260)

* Initial draft of multi task beam decoding strategy

Signed-off-by: smajumdar <titu1994@gmail.com>

* Stabilize inference

Signed-off-by: smajumdar <titu1994@gmail.com>

* Update AED Multi Task model to mostly conform to Archetype-Type format. Update config

Signed-off-by: smajumdar <titu1994@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add change decoding strategy

Signed-off-by: smajumdar <titu1994@gmail.com>

* Remove redundant imports

Signed-off-by: smajumdar <titu1994@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cleanup

Signed-off-by: smajumdar <titu1994@gmail.com>

* Cleanup

Signed-off-by: smajumdar <titu1994@gmail.com>

* remove asr transformer dependency on nlp

Signed-off-by: stevehuang52 <heh@nvidia.com>

* clean up

Signed-off-by: stevehuang52 <heh@nvidia.com>

* copy token_classifier from nlp to asr

Signed-off-by: stevehuang52 <heh@nvidia.com>

* Address comments

Signed-off-by: smajumdar <titu1994@gmail.com>

* Add typing to beam decoding

Signed-off-by: smajumdar <titu1994@gmail.com>

* Make prompt format configurable

Signed-off-by: smajumdar <titu1994@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* drop asr dependency on nlp

Signed-off-by: stevehuang52 <heh@nvidia.com>

---------

Signed-off-by: smajumdar <titu1994@gmail.com>
Signed-off-by: stevehuang52 <heh@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: stevehuang52 <heh@nvidia.com>

* fix transcribe, update asr evaluator

Signed-off-by: stevehuang52 <heh@nvidia.com>

* Extend the docs for the canary prompt_fn

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Incorporate changes from Nithin's code review

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* training bug fix and adding launch script for speech_multitask (#8270)

* bug fix and adding launch script for speech_multitask

Signed-off-by: Krishna Puvvada <kpuvvada@nvidia.com>

* update launch script example in speech_to_text_aed.py

Signed-off-by: Krishna Puvvada <kpuvvada@nvidia.com>

---------

Signed-off-by: Krishna Puvvada <kpuvvada@nvidia.com>
Co-authored-by: Krishna Puvvada <kpuvvada@nvidia.com>

* Fix: drop_last must be true in validation/test otherwise the training will hang

Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>

* revert to current transcribe API

Signed-off-by: stevehuang52 <heh@nvidia.com>

* revert changes to NLP, update docs

Signed-off-by: stevehuang52 <heh@nvidia.com>

* update eval utils

Signed-off-by: stevehuang52 <heh@nvidia.com>

* update docs

Signed-off-by: stevehuang52 <heh@nvidia.com>

* Remove DALI; rename compute_audio_loss to compute_loss

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* set default use_model_transcribe=False

Signed-off-by: stevehuang52 <heh@nvidia.com>

* change os.path.dirname to pathlib

Signed-off-by: stevehuang52 <heh@nvidia.com>

* [canary] Test for CanaryTokenizer + refactoring (#8285)

* Test for CanaryTokenizer

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Attempt at refactor...

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

---------

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Update config for AED models (#8294)

Signed-off-by: smajumdar <titu1994@gmail.com>

* set default calculate_wer=False in transcribe_speech.py

Signed-off-by: stevehuang52 <heh@nvidia.com>

* Attention encoder-decoder models for multiple speech-to-text tasks

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Apply suggestions from code review, part 1

Co-authored-by: Nithin Rao <nithinrao.koluguri@gmail.com>
Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Apply suggestions from code review, part 2

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Document compute_loss

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* update transcribe_speech.py

Signed-off-by: stevehuang52 <heh@nvidia.com>

* add docstring

Signed-off-by: stevehuang52 <heh@nvidia.com>

* Attention encoder-decoder models for multiple speech-to-text tasks

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

---------

Signed-off-by: Piotr Żelasko <petezor@gmail.com>
Signed-off-by: stevehuang52 <heh@nvidia.com>
Signed-off-by: smajumdar <titu1994@gmail.com>
Signed-off-by: Krishna Puvvada <kpuvvada@nvidia.com>
Signed-off-by: Piotr Żelasko <pzelasko@nvidia.com>
Co-authored-by: stevehuang52 <heh@nvidia.com>
Co-authored-by: Somshubra Majumdar <titu1994@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Krishna Puvvada <93558329+krishnacpuvvada@users.noreply.github.com>
Co-authored-by: Krishna Puvvada <kpuvvada@nvidia.com>
Co-authored-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com>
Co-authored-by: Nithin Rao <nithinrao.koluguri@gmail.com>
(cherry picked from commit d10726d)

Co-authored-by: Piotr Żelasko <petezor@gmail.com>

* Multimodal r1.23.0 bug fix  (#8315)

* Rename quick-gelu

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* ddpm config guard

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Fix ddpm edit api

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Fix insert_image_token cfg issue

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* neva updates

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* reformat

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Add back jenkins

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix jenkins

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix bugs

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Update default neva template

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

---------

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Fixes for MoE parameter passing & use of AutoTokenizer/Model for mistral. (#8272)

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Keep max_seqlen and cu_seqlens_argmin for later micro-batches when PP>1 (#8334)

Signed-off-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Eric Harper <complex451@gmail.com>

* Remove asr webapp (#8347)

Signed-off-by: smajumdar <titu1994@gmail.com>

* remove _target_ at model level in aed config (#8351)

Signed-off-by: Krishna Puvvada <kpuvvada@nvidia.com>
Co-authored-by: Krishna Puvvada <kpuvvada@nvidia.com>

* Add change_vocabulary and save_tokenizers() support to Multitask ASR models (#8357)

* Add change_vocabulary and save_tokenizers() support

Signed-off-by: smajumdar <titu1994@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update nemo/collections/asr/models/aed_multitask_models.py

Co-authored-by: Piotr Żelasko <petezor@gmail.com>
Signed-off-by: Somshubra Majumdar <titu1994@gmail.com>

---------

Signed-off-by: smajumdar <titu1994@gmail.com>
Signed-off-by: Somshubra Majumdar <titu1994@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Piotr Żelasko <petezor@gmail.com>

* Change default (#8371)

Signed-off-by: smajumdar <titu1994@gmail.com>

* bug fix in fast-conformer-aed.yaml and adding jenkins test for speech_to_text_aed model (#8368)

Signed-off-by: Krishna Puvvada <kpuvvada@nvidia.com>
Co-authored-by: Krishna Puvvada <kpuvvada@nvidia.com>
Co-authored-by: Somshubra Majumdar <titu1994@gmail.com>

* Enable megatron core loggers for GPT pretraining (#8354)

* Logging changes tested for gpt_pretraining

Signed-off-by: Aishwarya Bhandare <abhandare@nvidia.com>

* Additional args

Signed-off-by: Aishwarya Bhandare <abhandare@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Aishwarya Bhandare <abhandare@nvidia.com>
Co-authored-by: Aishwarya Bhandare <abhandare@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <complex451@gmail.com>

* mcore ds fix (#8283)

* [tutorial] fixed missing RIR scripts file. (#8257)

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* add values to en tts dict (#7879)

Signed-off-by: Mariana Graterol Fuenmayor <marianag@nvidia.com>

* mcore ds fix

Signed-off-by: Dmytro Pykhtar <dpykhtar@login-eos01.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update mcore

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* revert asr files

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* add comments

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add support for mcore mock dataset

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* update mcore version

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update gpt cfg

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* update mcore commit

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* fix Bert unit tests

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* update bert tests

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* fix bert mcore test

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* fix gpt jenkins tests

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update apex & TE commits

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* revert apex installation

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* turn off the fusion for jenkins

Signed-off-by: dimapihtar <dpihtar@gmail.com>

---------

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Mariana Graterol Fuenmayor <marianag@nvidia.com>
Signed-off-by: Dmytro Pykhtar <dpykhtar@login-eos01.eos.clusters.nvidia.com>
Signed-off-by: dimapihtar <dpihtar@gmail.com>
Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Mariana <47233618+mgrafu@users.noreply.github.com>
Co-authored-by: Dmytro Pykhtar <dpykhtar@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Pablo Garay <palenq@gmail.com>

* Add Finetuning tutorial with HF Datasets (#8356)

* Add Finetuning tutorial with HF Datasets

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* update on Som comments

Signed-off-by: Nithin Rao Koluguri <nithinraok>

---------

Signed-off-by: Nithin Rao Koluguri <nithinraok>
Co-authored-by: Nithin Rao Koluguri <nithinraok>

* release updates (#8378)

* [tutorial] fixed missing RIR scripts file. (#8257)

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* add values to en tts dict (#7879)

Signed-off-by: Mariana Graterol Fuenmayor <marianag@nvidia.com>

* mcore ds fix

Signed-off-by: Dmytro Pykhtar <dpykhtar@login-eos01.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update mcore

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* revert asr files

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* add comments

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add support for mcore mock dataset

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* update mcore version

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update gpt cfg

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* update mcore commit

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* fix Bert unit tests

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* update bert tests

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* fix bert mcore test

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* fix gpt jenkins tests

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add support for dict data input type

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* add mock ds test

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* add test for dict data input type

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* mcore ds fix

Signed-off-by: dimapihtar <dpihtar@gmail.com>

* data input fix

Signed-off-by: dimapihtar <dpihtar@gmail.com>

---------

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Mariana Graterol Fuenmayor <marianag@nvidia.com>
Signed-off-by: Dmytro Pykhtar <dpykhtar@login-eos01.eos.clusters.nvidia.com>
Signed-off-by: dimapihtar <dpihtar@gmail.com>
Signed-off-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com>
Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Mariana <47233618+mgrafu@users.noreply.github.com>
Co-authored-by: Dmytro Pykhtar <dpykhtar@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Pablo Garay <palenq@gmail.com>

* MCore dataset compatibility for tokenizers (#8390)

* Add unique_identifiers for all tokenizers and eod for SentencePieceTokenizer

Signed-off-by: Valerie Sarge <vsarge@nvidia.com>

* Add generalized token aliases to TokenizerSpec to conform with MegatronTokenizer's interface. Remove now-redundant individual fixes from AutoTokenizer and SentencePieceTokenizer.

Signed-off-by: Valerie Sarge <vsarge@nvidia.com>

---------

Signed-off-by: Valerie Sarge <vsarge@nvidia.com>
Co-authored-by: Pablo Garay <palenq@gmail.com>

* Mcore customization doc (#8298)

* [tutorial] fixed missing RIR scripts file. (#8257)

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* add values to en tts dict (#7879)

Signed-off-by: Mariana Graterol Fuenmayor <marianag@nvidia.com>

* Add Bert HF checkpoint converter (#8088)

* Add Bert HF checkpoint converter

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Reformat

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Add BERT ONNX export

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add NeMo BERT to HF BERT script

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Clean code

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update argument names

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Update build_transformer_config in Bert

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

---------

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Bobby Chen <bobchen@nvidia.com>

* initial placeholder

Signed-off-by: Huiying Li <huiyingl@nvidia.com>

* add to intro/index.rst

Signed-off-by: Huiying Li <huiyingl@nvidia.com>

* initial content update

Signed-off-by: Huiying Li <willwin.lee@gmail.com>

* add diff images

Signed-off-by: Huiying Li <willwin.lee@gmail.com>

size

Signed-off-by: Huiying Li <willwin.lee@gmail.com>

* minor fixes

* minor language change

Signed-off-by: Chen Cui <chcui@nvidia.com>

* clean changes

---------

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Mariana Graterol Fuenmayor <marianag@nvidia.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Signed-off-by: Huiying Li <huiyingl@nvidia.com>
Signed-off-by: Huiying Li <willwin.lee@gmail.com>
Signed-off-by: Chen Cui <chcui@nvidia.com>
Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Mariana <47233618+mgrafu@users.noreply.github.com>
Co-authored-by: yaoyu-33 <54727607+yaoyu-33@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Bobby Chen <bobchen@nvidia.com>
Co-authored-by: Huiying Li <huiyingl@nvidia.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>

* wer fix (#8404)

Signed-off-by: Travis Bartley <tbartley@nvidia.com>

* updated link to pubmed (#8402)

Signed-off-by: Nithin Rao Koluguri <nithinraok>
Co-authored-by: Nithin Rao Koluguri <nithinraok>

* Update NFA video download link (#8406)

* update nfa nasa video link

Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>

* update link in markdown

Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>

---------

Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>

* revert changes (#8410)

Signed-off-by: Chen Cui <chcui@nvidia.com>

* Fix dreambooth data sampler issue (#8400)

* Turn on drop last

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* Some neva fixes

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Fixed errors in the CTM gen functions (#8416)

Signed-off-by: Taejin Park <tango4j@gmail.com>

* add ensemble decoding fix (#8427)

Signed-off-by: Nithin Rao Koluguri <nithinraok>
Co-authored-by: Nithin Rao Koluguri <nithinraok>

* SDE bugfix log (#8430)

Signed-off-by: George <gzelenfroind@nvidia.com>

* mcore customization doc minor fix (#8421)

Signed-off-by: Huiying Li <willwin.lee@gmail.com>

* NeMo-Mistral to HF converter bugfix. (#8353)

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Fixing mcore bert for TP, PP and SP (#8336)

* Fixing mcore bert for TP, PP and SP

* Fixing mcore bert for TP, PP and SP

* Fixing mcore version

* Fixing mcore version

* Update Jenkinsfile

Signed-off-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>

* Update Jenkinsfile

Signed-off-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>

* Update Jenkinsfile

Signed-off-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>

---------

Signed-off-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>
Co-authored-by: Shanmugam Ramasamy <shanmugamr@shanmugamr-mlt.client.nvidia.com>
Co-authored-by: Eric Harper <complex451@gmail.com>

* Add settings to suppress bf16 compile errors in CI on V100 (#8481)

* Add settings to suppress bf16 compile errors in CI on V100

Signed-off-by: Abhishree <abhishreetm@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Abhishree <abhishreetm@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* MoE parameter passing (#8255)

* MoE parameter passing

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Pass EP/MoE params in consumer scripts.

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* PR fixes

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Use latest commit of mcore-0.5

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* CI fix

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Co-authored-by: Alexandros Koumparoulis <akoumparouli@dgx1v-loki-21.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Update k2 version (#8478) (#8492)

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>

* Add fp8 support for SD/Update notebook paths (#8489)

* Add fp8 support for SD/Update notebook paths

Signed-off-by: Mingyuan Ma <mingyuanm@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Mingyuan Ma <mingyuanm@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <complex451@gmail.com>

* pin to 0.5.0 (#8465)

Signed-off-by: eharper <eharper@nvidia.com>

* Update NeMo Multimodal Requirements (#8515)

* Update requirements_multimodal.txt

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* update github raw content link (#8517)

Signed-off-by: Chen Cui <chcui@nvidia.com>

* Add dep notice for notebooks (#8522)

* add dep notice

Signed-off-by: eharper <eharper@nvidia.com>

* revert

Signed-off-by: eharper <eharper@nvidia.com>

---------

Signed-off-by: eharper <eharper@nvidia.com>

* Revert FP8 integration (#8520)

* Revert FP8 integration

Signed-off-by: Mingyuan Ma <mingyuanm@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Mingyuan Ma <mingyuanm@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Update data prep notebook (#8532)

Signed-off-by: Mingyuan Ma <mingyuanm@nvidia.com>

* Add back fp8 support

* SD-FP8: fix the bug of normalization location

Signed-off-by: Mingyuan Ma <mingyuanm@nvidia.com>

* map potential FP8 ckpt to FP16

Signed-off-by: Mingyuan Ma <mingyuanm@nvidia.com>

* Add TE fp8 training

Signed-off-by: Mingyuan Ma <mingyuanm@nvidia.com>

* Only overwrite unet precision when self.megatron_amp_O2 is true

Signed-off-by: Mingyuan Ma <mingyuanm@nvidia.com>

* New structure is now compatible with old ckpts

Signed-off-by: Mingyuan Ma <mingyuanm@nvidia.com>

* Add support on mapping old unet checkpoint to new structure and FP8 structure

Signed-off-by: Mingyuan Ma <mingyuanm@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Sync with main branch

Signed-off-by: Mingyuan Ma <mingyuanm@nvidia.com>

---------

Signed-off-by: eharper <eharper@nvidia.com>
Signed-off-by: Mikołaj Błaż <mblaz@nvidia.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: dimapihtar <dpihtar@gmail.com>
Signed-off-by: Piotr Żelasko <petezor@gmail.com>
Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>
Signed-off-by: Nithin Rao Koluguri <nithinraok>
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
Signed-off-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Signed-off-by: smajumdar <titu1994@gmail.com>
Signed-off-by: Krishna Puvvada <kpuvvada@nvidia.com>
Signed-off-by: Somshubra Majumdar <titu1994@gmail.com>
Signed-off-by: Aishwarya Bhandare <abhandare@nvidia.com>
Signed-off-by: Mariana Graterol Fuenmayor <marianag@nvidia.com>
Signed-off-by: Dmytro Pykhtar <dpykhtar@login-eos01.eos.clusters.nvidia.com>
Signed-off-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com>
Signed-off-by: Valerie Sarge <vsarge@nvidia.com>
Signed-off-by: Huiying Li <huiyingl@nvidia.com>
Signed-off-by: Huiying Li <willwin.lee@gmail.com>
Signed-off-by: Travis Bartley <tbartley@nvidia.com>
Signed-off-by: Taejin Park <tango4j@gmail.com>
Signed-off-by: George <gzelenfroind@nvidia.com>
Signed-off-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>
Signed-off-by: Abhishree <abhishreetm@gmail.com>
Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: Mingyuan Ma <mingyuanm@nvidia.com>
Co-authored-by: eharper <eharper@nvidia.com>
Co-authored-by: mikolajblaz <mikolajblaz@users.noreply.github.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com>
Co-authored-by: dimapihtar <dpihtar@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Piotr Żelasko <petezor@gmail.com>
Co-authored-by: Elena Rastorgueva <80532067+erastorgueva-nv@users.noreply.github.com>
Co-authored-by: Nithin Rao <nithinrao.koluguri@gmail.com>
Co-authored-by: Somshubra Majumdar <titu1994@gmail.com>
Co-authored-by: JimmyZhang12 <67203904+JimmyZhang12@users.noreply.github.com>
Co-authored-by: Jimmy Zhang <jiemingz@nvidia.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
Co-authored-by: yaoyu-33 <54727607+yaoyu-33@users.noreply.github.com>
Co-authored-by: akoumpa <153118171+akoumpa@users.noreply.github.com>
Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Krishna Puvvada <93558329+krishnacpuvvada@users.noreply.github.com>
Co-authored-by: Krishna Puvvada <kpuvvada@nvidia.com>
Co-authored-by: ashbhandare <ash.bhandare@gmail.com>
Co-authored-by: Aishwarya Bhandare <abhandare@nvidia.com>
Co-authored-by: Mariana <47233618+mgrafu@users.noreply.github.com>
Co-authored-by: Dmytro Pykhtar <dpykhtar@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: Pablo Garay <palenq@gmail.com>
Co-authored-by: Valerie Sarge <vsarge@nvidia.com>
Co-authored-by: Huiying <willwin.lee@gmail.com>
Co-authored-by: Bobby Chen <bobchen@nvidia.com>
Co-authored-by: Huiying Li <huiyingl@nvidia.com>
Co-authored-by: tbartley94 <90423858+tbartley94@users.noreply.github.com>
Co-authored-by: Taejin Park <tango4j@gmail.com>
Co-authored-by: George <37293288+Jorjeous@users.noreply.github.com>
Co-authored-by: Shanmugam Ramasamy <111910568+shanmugamr1992@users.noreply.github.com>
Co-authored-by: Shanmugam Ramasamy <shanmugamr@shanmugamr-mlt.client.nvidia.com>
Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com>
Co-authored-by: Alexandros Koumparoulis <akoumparouli@dgx1v-loki-21.nvidia.com>
Co-authored-by: Vladimir Bataev <vbataev@nvidia.com>
Co-authored-by: Mengdi Wang <didow@nvidia.com>
Signed-off-by: Ao Tang <aot@nvidia.com>
1 parent fb883ed commit d226017
Showing 6 changed files with 236 additions and 41 deletions.
@@ -49,8 +49,8 @@ model:
   precision: ${trainer.precision}
   # specify micro_batch_size, global_batch_size, and model parallelism
   # gradient accumulation will be done automatically based on data_parallel_size
-  micro_batch_size: 1 # limited by GPU memory
-  global_batch_size: 1 # will use more micro batches to reach global batch size
+  micro_batch_size: 16 # limited by GPU memory
+  global_batch_size: 16 # will use more micro batches to reach global batch size
   native_amp_init_scale: 65536.0 # Init scale for grad scaler used at fp16

@@ -97,15 +97,15 @@ model:
   unet_config:
     _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel.UNetModel
     from_pretrained: #/ckpts/nemo-v1-2.ckpt
-    from_NeMo: True #Must be specified when from pretrained is not None, False means loading unet from HF ckpt
+    from_NeMo: False #Must be specified when from pretrained is not None, False means loading unet from HF ckpt
     image_size: 32 # unused
     in_channels: 4
     out_channels: 4
     model_channels: 320
     attention_resolutions:
-    - 4
-    - 2
-    - 1
+    - 4
+    - 2
+    - 1
     num_res_blocks: 2
     channel_mult:
     - 1
@@ -121,6 +121,7 @@ model:
     use_flash_attention: True
     unet_precision: fp32
     resblock_gn_groups: 32
+    use_te_fp8: False

   first_stage_config:
     _target_: nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKL
@@ -140,30 +141,30 @@ model:
       - 4
       - 4
       num_res_blocks: 2
-      attn_resolutions: []
+      attn_resolutions: [ ]
       dropout: 0.0
     lossconfig:
       target: torch.nn.Identity

   cond_stage_config:
-    _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenMegatronCLIPEmbedder
-    restore_from_path: /ckpts/openai.nemo
+    _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder
+    version: openai/clip-vit-large-patch14
     device: cuda
     freeze: True
     layer: "last"
-    # For compatibility of history version that uses HF clip model
-    # _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder
-    # version: openai/clip-vit-large-patch14
-    # device: cuda
-    # max_length: 77
+    max_length: 77
+    # _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenMegatronCLIPEmbedder
+    # restore_from_path: /ckpts/openai-old.nemo
+    # device: cuda
+    # freeze: True
+    # layer: "last"


   # miscellaneous
   seed: 1234
   resume_from_checkpoint: null # manually set the checkpoint file to load from
   apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this
   gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
-  ddp_overlap: True # True for using PyTorch DDP overlap.
+  ddp_overlap: False # True for using PyTorch DDP overlap.

   optim:
     name: fused_adam
@@ -191,7 +192,7 @@ model:
     synthetic_data_length: 10000
     train:
       dataset_path:
-        - /datasets/coyo/test.pkl
+        - /datasets/coyo/wdinfo/coyo-700m/wdinfo-selene.pkl
       augmentations:
         resize_smallest_side: 512
         center_crop_h_w: 512, 512
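
The hunks above wire a new use_te_fp8 switch into the UNet config but leave it disabled. As a rough, hedged sketch (not part of this commit) of how the FP8 path would be turned on, the keys shown above can be overridden programmatically; the file name sd_train.yaml and the fp16 value are assumptions, and megatron_amp_O2 matters because of the precision guard changed further below:

from omegaconf import OmegaConf

cfg = OmegaConf.load('sd_train.yaml')          # config file name assumed, not taken from this diff
cfg.model.unet_config.use_te_fp8 = True        # switch the UNet to the Transformer Engine FP8 path
cfg.model.unet_config.unet_precision = 'fp16'  # assumed value; the guard below only reacts to 16-bit precisions
cfg.model.megatron_amp_O2 = True               # the unet precision override is applied only when O2 is enabled
print(OmegaConf.to_yaml(cfg.model.unet_config))
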
@@ -28,6 +28,9 @@ def model_cfg_modifier(model_cfg):
     model_cfg.unet_config.use_flash_attention = False
     model_cfg.unet_config.from_pretrained = None
     model_cfg.first_stage_config.from_pretrained = None
+    model_cfg.first_stage_config._target_ = (
+        'nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKL'
+    )

     torch.backends.cuda.matmul.allow_tf32 = True
     trainer, megatron_diffusion_model = setup_trainer_and_model_for_inference(

@@ -1674,7 +1674,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         # megatron_amp_O2 is not yet supported in diffusion models
         self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False)

-        if self.cfg.precision in ['16', 16, 'bf16']:
+        if self.megatron_amp_O2 and self.cfg.precision in ['16', 16, 'bf16']:
             self.model_parallel_config.enable_autocast = False
             if not hasattr(self.cfg.unet_config, 'unet_precision') or not '16' in str(
                 self.cfg.unet_config.unet_precision

66 changes: 53 additions & 13 deletions nemo/collections/multimodal/modules/stable_diffusion/attention.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+import os
 from inspect import isfunction

 import torch
@@ -21,6 +22,13 @@
 from torch import einsum, nn
 from torch._dynamo import disable

+if os.environ.get("USE_NATIVE_GROUP_NORM", "0") == "1":
+    from nemo.gn_native import GroupNormNormlization as GroupNorm
+else:
+    from apex.contrib.group_norm import GroupNorm
+
+from transformer_engine.pytorch.module import LayerNormLinear, LayerNormMLP
+
 from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import checkpoint
 from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import (
     AdapterName,
@@ -96,13 +104,19 @@ def forward(self, x):


 class FeedForward(nn.Module):
-    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0, use_te=False):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = default(dim_out, dim)
-        project_in = nn.Sequential(LinearWrapper(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
-
-        self.net = nn.Sequential(project_in, nn.Dropout(dropout), LinearWrapper(inner_dim, dim_out))
+        if use_te:
+            activation = 'gelu' if not glu else 'geglu'
+            # TODO: more parameters to be confirmed, dropout, seq_length
+            self.net = LayerNormMLP(hidden_size=dim, ffn_hidden_size=inner_dim, activation=activation,)
+        else:
+            norm = nn.LayerNorm(dim)
+            project_in = nn.Sequential(LinearWrapper(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+            self.net = nn.Sequential(norm, project_in, nn.Dropout(dropout), LinearWrapper(inner_dim, dim_out))

     def forward(self, x):
         return self.net(x)
@@ -225,10 +239,15 @@ def __init__(
         dropout=0.0,
         use_flash_attention=False,
         lora_network_alpha=None,
+        use_te=False,
     ):
         super().__init__()

         self.inner_dim = dim_head * heads
+        if context_dim is None:
+            self.is_self_attn = True
+        else:
+            self.is_self_attn = False  # cross-attention
         context_dim = default(context_dim, query_dim)
         # make attention part be aware of self-attention/cross-attention
         self.context_dim = context_dim
@@ -238,10 +257,19 @@ def __init__(
         self.scale = dim_head ** -0.5
         self.heads = heads

-        self.to_q = LinearWrapper(query_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
         self.to_k = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
         self.to_v = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)

+        self.use_te = use_te
+        if use_te:
+            return_layernorm_output = True if self.is_self_attn else False
+            self.norm_to_q = LayerNormLinear(
+                query_dim, self.inner_dim, bias=False, return_layernorm_output=return_layernorm_output
+            )
+        else:
+            self.norm = nn.LayerNorm(query_dim)
+            self.to_q = LinearWrapper(query_dim, self.inner_dim, bias=False)
+
         self.to_out = nn.Sequential(
             LinearWrapper(self.inner_dim, query_dim, lora_network_alpha=lora_network_alpha), nn.Dropout(dropout)
         )
@@ -262,8 +290,18 @@ def forward(self, x, context=None, mask=None, additional_tokens=None, n_times_cr
             # add additional token
             x = torch.cat([additional_tokens, x], dim=1)

-        q = self.to_q(x)
-        context = default(context, x)
+        if self.use_te:
+            q_out = self.norm_to_q(x)
+            if self.is_self_attn:
+                q, ln_out = q_out
+                context = default(context, ln_out)
+            else:
+                q = q_out
+                context = default(context, x)
+        else:
+            x = self.norm(x)
+            q = self.to_q(x)
+            context = default(context, x)
         k = self.to_k(context)
         v = self.to_v(context)

@@ -351,6 +389,7 @@ def __init__(
         use_flash_attention=False,
         disable_self_attn=False,
         lora_network_alpha=None,
+        use_te=False,
     ):
         super().__init__()
         self.disable_self_attn = disable_self_attn
@@ -362,8 +401,9 @@ def __init__(
             use_flash_attention=use_flash_attention,
             context_dim=context_dim if self.disable_self_attn else None,
             lora_network_alpha=lora_network_alpha,
+            use_te=use_te,
         ) # is a self-attention
-        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, use_te=use_te)
         self.attn2 = CrossAttention(
             query_dim=dim,
             context_dim=context_dim,
@@ -372,10 +412,8 @@ def __init__(
             dropout=dropout,
             use_flash_attention=use_flash_attention,
             lora_network_alpha=lora_network_alpha,
+            use_te=use_te,
         ) # is self-attn if context is none
-        self.norm1 = nn.LayerNorm(dim)
-        self.norm2 = nn.LayerNorm(dim)
-        self.norm3 = nn.LayerNorm(dim)
         self.use_checkpoint = use_checkpoint

     def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
@@ -397,15 +435,15 @@ def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_at
     def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
         x = (
             self.attn1(
-                self.norm1(x),
+                x,
                 context=context if self.disable_self_attn else None,
                 additional_tokens=additional_tokens,
                 n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
             )
             + x
         )
-        x = self.attn2(self.norm2(x), context=context, additional_tokens=additional_tokens) + x
-        x = self.ff(self.norm3(x)) + x
+        x = self.attn2(x, context=context, additional_tokens=additional_tokens) + x
+        x = self.ff(x) + x
         return x
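
With this hunk the per-block LayerNorms disappear from the call sites: the pre-norm now lives inside CrossAttention and FeedForward themselves, either as an explicit nn.LayerNorm or fused into Transformer Engine's LayerNormLinear / LayerNormMLP. A minimal PyTorch-only sketch of the before/after call pattern, with plain nn.Linear stand-ins rather than the LinearWrapper/GEGLU modules used in this file:

import torch
from torch import nn

dim = 320
x = torch.randn(2, 64, dim)

# old pattern: the transformer block owns the norm and applies it at the call site
norm3 = nn.LayerNorm(dim)
ff_old = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
out_old = ff_old(norm3(x)) + x   # x = self.ff(self.norm3(x)) + x

# new pattern: the norm is the first layer of the feed-forward module itself
ff_new = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
out_new = ff_new(x) + x          # x = self.ff(x) + x

assert out_old.shape == out_new.shape == x.shape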


@@ -431,6 +469,7 @@ def __init__(
         use_checkpoint=False,
         use_flash_attention=False,
         lora_network_alpha=None,
+        use_te=False,
     ):
         super().__init__()
         logging.info(
@@ -473,6 +512,7 @@ def __init__(
                     use_flash_attention=use_flash_attention,
                     disable_self_attn=disable_self_attn,
                     lora_network_alpha=lora_network_alpha,
+                    use_te=use_te,
                 )
                 for d in range(depth)
             ]
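
As a rough usage sketch of the Transformer Engine modules wired in above (not code from this commit): under te.fp8_autocast the LayerNormLinear / LayerNormMLP GEMMs run in FP8, and return_layernorm_output=True also hands back the normalized input so the self-attention path can reuse it for the k/v projections. The dimensions and recipe values below are illustrative assumptions, and running it requires a CUDA GPU with Transformer Engine installed and FP8-capable hardware:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

dim, heads, dim_head = 320, 8, 40
inner_dim = heads * dim_head

# fused LayerNorm + q projection, mirroring CrossAttention(use_te=True) in self-attention mode
norm_to_q = te.LayerNormLinear(dim, inner_dim, bias=False, return_layernorm_output=True).cuda()
# fused LayerNorm + MLP, mirroring FeedForward(use_te=True)
ff = te.LayerNormMLP(hidden_size=dim, ffn_hidden_size=4 * dim, activation='gelu').cuda()

fp8_recipe = DelayedScaling(margin=0, fp8_format=Format.HYBRID)  # illustrative recipe values
x = torch.randn(2, 64, dim, device='cuda')

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    q, ln_out = norm_to_q(x)  # q: query projection, ln_out: LayerNorm(x) reused for k/v
    y = ff(x) + x             # LayerNormMLP normalizes internally; residual added outside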