From 8d4218e0b3245986a9d285ecdad9e9f9d2c6d6e3 Mon Sep 17 00:00:00 2001 From: yaoyu-33 <54727607+yaoyu-33@users.noreply.github.com> Date: Wed, 10 Jan 2024 22:32:19 -0800 Subject: [PATCH] Add All Multimodal Source Code Part 2: Text to image, x to nerf (#7970) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update README.md: output_path --> output_manifest_filepath (#7442) Signed-off-by: Samuele Cornell * Updating FlashAttention API to match FlashAttentionV2 * Multiple fixes for mm * Fix CI inductor issue and update to torch compile * Remove suppress error * Fix when conversion config uses fp16 and it complains about precision plugin * Fixing FAv2 API usage * Initial release of content filtering model * Added synthetic dataloader for precached and online mode * Mingyuanm/dreambooth opt * Add llama2 support in neva training * Fix sampler length * Fix all precision issues in nemo multimodal * Add rope dynamic linear scaling (#7437) * Add dynamic linear scaling Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Cheng-Ping Hsieh --------- Signed-off-by: Cheng-Ping Hsieh Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Yang Zhang * Fix None dataloader issue in PTL2.0 (#7455) * Fix None dataloader issue in PTL2.0 Signed-off-by: KunalDhawan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * updating values of self._validation_dl and self._test_dl as well Signed-off-by: KunalDhawan * updating values of self._validation_dl and self._test_dl as well Signed-off-by: KunalDhawan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: KunalDhawan Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [ASR] Confidence measure -> method renames (#7434) * measure -> method Signed-off-by: Aleksandr Laptev * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Aleksandr Laptev Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Add steps for document of getting dataset 'SF Bilingual Speech' (#7378) * Add steps for document of getting dataset 'SF Bilingual Speech' Signed-off-by: Robin Dong * Update datasets.rst added a link from a tutorial demonstrating detailed data prep steps. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --------- Signed-off-by: Robin Dong Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * RNN-T confidence and alignment bugfix (#7381) * new frame_confidence and alignments lists are now always created after the while loop Signed-off-by: Aleksandr Laptev * tests added Signed-off-by: Aleksandr Laptev --------- Signed-off-by: Aleksandr Laptev * Fix resume from checkpoint in exp_manager (#7424) (#7426) Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Co-authored-by: Eric Harper * Fix checking of cuda/cpu device for inputs of Decoder (#7444) * Fix checking of cuda/cpu device for inputs of Decoder Signed-off-by: Robin Dong * Update tacotron2.py Signed-off-by: Jason --------- Signed-off-by: Robin Dong Signed-off-by: Jason Co-authored-by: Jason * Fix failure of ljspeech's get_data.py (#7430) * Fix failure of ljspeech's get_data.py Signed-off-by: Robin Dong * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Robin Dong Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [TTS] Fix audio codec type checks (#7373) * [TTS] Fix audio codec type checks Signed-off-by: Ryan * [TTS] Fix audio codec tests Signed-off-by: Ryan --------- Signed-off-by: Ryan * [TTS] Add dataset to path of logged artifacts (#7462) * [TTS] Add dataset to path of logged artifacts Signed-off-by: Ryan * [TTS] Revert axis name back to Audio Frames Signed-off-by: Ryan --------- Signed-off-by: Ryan * Fix sft dataset truncation (#7464) * Add fix Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Cheng-Ping Hsieh --------- Signed-off-by: Cheng-Ping Hsieh Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Automatic Lip Reading Recognition (ALR) - ASR/CV (Visual ASR) (#7330) * striding_conv1d_k5 and dw_striding_conv1d_k5 subsampling Signed-off-by: mburchi * transpose conv1d inputs Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: mburchi * Update subsampling.py change striding_conv1d_k5 to striding_conv1d Signed-off-by: Maxime Burchi <60737204+burchim@users.noreply.github.com> * cv branch Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * video manifest Signed-off-by: mburchi * add collection classes Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add test_step_outputs Signed-off-by: mburchi * correct manifest bug when having only audio or only videos Signed-off-by: mburchi * correct manifest bug when having only audio or only videos Signed-off-by: mburchi * clean references Signed-off-by: mburchi * freeze unfreeze transcribe cv models Signed-off-by: mburchi * correct manifest get_full_path bug Signed-off-by: mburchi * update for PR Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * guard torchvision Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update nemo/collections/cv/data/video_to_text_dataset.py Co-authored-by: Igor Gitman Signed-off-by: Maxime Burchi <60737204+burchim@users.noreply.github.com> * _video_speech_collate_fn in cv/data/video_to_text.py Signed-off-by: mburchi * add self.out = None to asr subsampling Signed-off-by: mburchi * Update nemo/collections/cv/data/video_to_text_dataset.py Co-authored-by: Igor Gitman Signed-off-by: Maxime Burchi <60737204+burchim@users.noreply.github.com> * cv -> multimodal/speech_cv branch Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: mburchi Signed-off-by: Maxime Burchi <60737204+burchim@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Igor Gitman * HF StarCoder to NeMo conversion script (#7421) * Script to convert HF StarCoder checkpoint to NeMo Signed-off-by: Jan Lasek * StarCoder conversion test Signed-off-by: Jan Lasek * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Jan Lasek * Fix test Signed-off-by: Jan Lasek * Catch up with save_to changes Signed-off-by: Jan Lasek * Don't abbreviate args for clarity Signed-off-by: Jan Lasek * Configurable precision: BF16 vs FP32 Signed-off-by: Jan Lasek * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jan Lasek Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix bug when loading dist ckpt in peft (#7452) Signed-off-by: Hongbin Liu Co-authored-by: Hongbin Liu * Fix adding positional embeddings in-place in transformer module (#7440) Signed-off-by: Tamerlan Tabolov Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> * Fix (#7478) Signed-off-by: Cheng-Ping Hsieh * add sleep (#7498) (#7499) * add sleep * add sleep onto config instead * add comment --------- Signed-off-by: Gerald Shen Co-authored-by: Gerald Shen <119401249+gshennvm@users.noreply.github.com> * Fix exp manager check for sleep (#7503) (#7504) Signed-off-by: smajumdar Co-authored-by: Somshubra Majumdar * bugfix: trainer.accelerator=auto from None. (#7492) (#7493) Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * [doc] fix broken link (#7481) Signed-off-by: Stas Bekman * [TTS] Read audio as int32 to avoid flac read errors (#7477) * [TTS] Read audio as int32 to avoid flac read errors Signed-off-by: Ryan * [TTS] Add comment about read failures Signed-off-by: Ryan --------- Signed-off-by: Ryan * Add dataset 'AISHELL-3' from OpenSLR for training mandarin TTS (#7409) * Add dataset 'AISHELL-3' from OpenSLR for training mandarin TTS * Train 'AISHELL-3' dataset with multi-speakers Signed-off-by: Robin Dong * Update get_data.py update copyright header Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * Update get_data.py added a disclaimer Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add new configuration file for AISHELL3 with multispeaker of fastpitch Signed-off-by: Robin Dong --------- Signed-off-by: Robin Dong Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * dllogger - log on rank 0 only (#7513) Signed-off-by: Stas Bekman * Fix TTS FastPitch tutorial (#7494) (#7516) * Fix --------- Signed-off-by: Cheng-Ping Hsieh Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> * Fix get_dist() tensor dimension (#7506) (#7515) Signed-off-by: Jocelyn Huang Co-authored-by: Jocelyn * bugfix: specify trainer.strategy=auto when devices=1 (#7509) (#7512) Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * fix (#7511) Signed-off-by: Abhinav Khattar * [TTS] Fix FastPitch data prep tutorial (#7524) Signed-off-by: Ryan * add italian tokenization (#7486) * add italian tokenization Signed-off-by: GiacomoLeoneMaria * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add more ipa lexicon it Signed-off-by: GiacomoLeoneMaria * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix error deletion Signed-off-by: GiacomoLeoneMaria * add test Signed-off-by: GiacomoLeoneMaria * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: GiacomoLeoneMaria Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Replace None strategy with auto in tutorial notebooks (#7521) (#7527) Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> * unpin setuptools (#7534) (#7535) Signed-off-by: fayejf <36722593+fayejf@users.noreply.github.com> Co-authored-by: fayejf <36722593+fayejf@users.noreply.github.com> * remove auto generated examples (#7510) * explicitly remove autogenerated examples for data parallel evaluation Signed-off-by: arendu * mark autogenrated and remove it for test Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: arendu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Add the `strategy` argument to `MegatronGPTModel.generate()` (#7264) It is passed as an explicit argument rather than through `**strategy_args` so as to ensure someone cannot accidentally pass other arguments that would end up being ignored. It is a keyword-only argument to ensure that if in the future we want to update the signature to `**strategy_args`, we can do it without breaking code. Signed-off-by: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> * Fix PTL2.0 related ASR bugs in r1.21.0: Val metrics logging, None dataloader issue (#7531) (#7533) * fix none dataloader issue ptl2 * ptl2.0 logging fixes for rnnt_models --------- Signed-off-by: KunalDhawan Co-authored-by: Kunal Dhawan Co-authored-by: Nithin Rao * gpus -> devices (#7542) (#7545) Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao * Update FFMPEG version to fix issue with torchaudio (#7551) (#7553) Signed-off-by: smajumdar Co-authored-by: Somshubra Majumdar * PEFT GPT & T5 Refactor (#7308) * initial implementation of add_adapters API * correct type hint * Add config in add_adapters for save and load (@author bobchen) * Remove AdapterConfig to avoid import error * Add AdaterConfig back and move adaptermixin to sft model * Add NLPSaveRestoreConnector as default in NLPModel.restore_from * Add restore_from_nemo_with_adapter and test script * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rename t5 file and classes to be consistent with GPT * add t5 sft dataset * add support for single-file format with T5SFTDataset * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Various small changes to make T5 SFT work like GPT SFT * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add adapter evaluation test script * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add MultiAdaterConfig for ia3 and fix builder issue * Make ptuning for T5SFTModel work using mixin * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add IA3_Adapter for AdapterName * Add adapter name for ptuning and attention adapter * Make test script GPT/T5 agnostic * Add layer selection feature * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Integrate adapter name and config * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update gpt peft tuning script to new API * add t5 peft tuning script with new API * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix IA3 layer selection issue * Override state_dict on SFT model instead of mixin * Add load adapter by adapter config * move peft config map away from example script * auto get config from nemo adapter * Move PEFTConfig to new file * fix ckpt save/load for t5 * name change: add_adapters -> add_adapter * variable name change * update t5 script * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix t5 issues * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add weight tying * update gpt tuning script * PEFT-API proposal * Fix according to comments * update tuning scripts * move merge_cfg_with to mixin class since it applies to both gpt and t5 and requires the model class for restore * Add mcore_gpt support for NLPAdapterMixin * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo * variable name change to distinguish "peft" and "adapter" * override `load_adapters` to support `add_adapter` name change * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update tuning and eval script for adapter save/load * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add Ptuning on first stage only * add lora tutorial for review * Fix layer selection for mcore * add landing page * fix resume training Signed-off-by: jasonwan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add mcore condition in sharded_state_dict to make sft work * Update lora_tutorial.md First edit of this file for PEFT documentation for NeMO Signed-off-by: hkelly33 <58792115+hkelly33@users.noreply.github.com> * rename Adapter to AttentionAdapter to avoid confusion in doc * Change load_adapters to load .nemo * add quick start guide * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add load_adapters with .ckpt * Remove setup_complete changes in load_adapters * update landing page * remove typo * Updated quick_start.md per Chen Cui Signed-off-by: hkelly33 <58792115+hkelly33@users.noreply.github.com> * Add inference config merger and tutorial * Add doc string for NLPAdapterModelMixin and deprecated warning on MegatronGPTPEFTModel * add supported_methods.md and update other documentations * Update supported_methods.md minor updates. Signed-off-by: Adi Renduchintala * Update landing_page.md minor update. Signed-off-by: Adi Renduchintala * Modify doc string for NLPAdapterModelMixin * Add doc string add_adapters in NLPAdapterModelMixin * rename canonical adapters * remove mcore hard dependency * [PATCH] move microbatch calculator to nemo from apex * remove apex dependency in gpt and t5 sft models * remove apex dependency in gpt model * render doc strings * fix * Add missing virtual_tokens on ptuning * fix docstrings * update gpt-style model coverage in docs * update docstring * Remove pdb * add lightning_fabric to make docstring rendering work * Add Ptuning missing key * try docstring rendering * Fix ptuning issue * update gpt t5 peft tuning and eval scripts * typos * update eval config * fix bug relating to apex dependency removal * typo * make predict step behave the same as test step * make lora tutorial work in notebook * cosmetics * update yaml scripts * mcore_gpt attribute optional * typo * update eval scripts and fix T5 eval bugs * add NLPDDPStrategyNotebook and trainer builder logic to use it * update lora notebook to use new trainer builder * fix microbatch calculator bug for inference after training * Convert markdown files to RST and incorporate with doc * typo * revise language * remove extra cell * remove unnecessary inheritance * remove old tests * move layer selection default so logging messages make sense * remove `save_adapters` as adapter weights are saved automatically during training * initialize weights from a checkpoint instead of randomly * multiple fields can form a context (#7147) * list of context fields and flexible prompt template Signed-off-by: arendu * list of fields for context Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug Signed-off-by: Cheng-Ping Hsieh * Fix bug Signed-off-by: Cheng-Ping Hsieh * Add multiple truncation fields and middle truncation Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Compatible to old ckpt Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix tokenize detokenize issue Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove detokenization, add truncation augmentation Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Resolve comments Signed-off-by: Cheng-Ping Hsieh * Remove unused import Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert eos Signed-off-by: Cheng-Ping Hsieh * Add tokenizer space_sensitive attribute Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix error Signed-off-by: Cheng-Ping Hsieh * Fix erorr and use re Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug Signed-off-by: Cheng-Ping Hsieh * Change assert logic Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Follow adi suggestion Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove merge function Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add example and comment Signed-off-by: Cheng-Ping Hsieh * Remove context_key and add comment Signed-off-by: Cheng-Ping Hsieh * Remove random truncation Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix template none Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug Signed-off-by: Cheng-Ping Hsieh --------- Signed-off-by: arendu Signed-off-by: Cheng-Ping Hsieh Signed-off-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Cheng-Ping Hsieh Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> * revert config changes * remove accidental breakpoint * support TP>1 loading * infer adapter type from checkpoint in during eval * breakup add adapter * enable interpolation of train_ds and validation_ds * update metric calc script to conform to single-file eval format * remove extraneous print * update lora notebook for updated merge_inference_cfg * Update nlp_adapter_mixins.py variable name change Signed-off-by: Chen Cui * turn off grad scaler for PP to match old scripts * remove PEFTSaveRestoreConnector since functionality all covered by the new mixin class * remove resume_from_checkpoint check since covered in #7335 * revert changes made in eval config interpolation * more interpolation * typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove dup line Signed-off-by: Chen Cui * code style warnings Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix config mistake Signed-off-by: Chen Cui * add copyright header Signed-off-by: Chen Cui * fix code check warnings Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert changes to remove apex dependency (mixed apex+nemo microbatch calculator broke some CI tests) Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add more deprecation notices Signed-off-by: Chen Cui * update deprecation notices Signed-off-by: Chen Cui * update deprecation notices Signed-off-by: Chen Cui * consolidate peft and sft scripts Signed-off-by: Chen Cui * update CI tests Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * notebook branch points to main to prepare for merge Signed-off-by: Chen Cui * fix gpt and t5 validation with any metric other than loss Signed-off-by: Chen Cui * support pre-extracted checkpoints Signed-off-by: Chen Cui --------- Signed-off-by: jasonwan Signed-off-by: hkelly33 <58792115+hkelly33@users.noreply.github.com> Signed-off-by: Adi Renduchintala Signed-off-by: arendu Signed-off-by: Cheng-Ping Hsieh Signed-off-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> Signed-off-by: Chen Cui Co-authored-by: Chen Cui Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Marc Romeyn Co-authored-by: jasonwan Co-authored-by: hkelly33 <58792115+hkelly33@users.noreply.github.com> Co-authored-by: Adi Renduchintala Co-authored-by: Yuanzhe Dong Co-authored-by: Cheng-Ping Hsieh Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> * fix a typo (#7496) Signed-off-by: BestJuly * [TTS] remove curly braces from ${BRANCH} in jupyer notebook cell. (#7554) (#7560) * remove curly braces. * remove installation of pynini. --------- Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * add youtube embed url (#7570) Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * Remap speakers to continuous range of speaker_id for dataset AISHELL3 (#7536) * Remap speakers to continuous range of speaker_id for dataset AISHELL3 * Add new key/value pair to record raw speaker for AISHELL3 dataset Signed-off-by: Robin Dong --------- Signed-off-by: Robin Dong Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix validation_step_outputs initialization for multi-dataloader (#7546) (#7572) * added correct validation_step_outputs initialization for mutli-dataloader * changed kernel for display * Update logic for validation and test step outputs * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert multidataloader changes in multilang ASR notebook --------- Signed-off-by: KunalDhawan Signed-off-by: smajumdar Co-authored-by: Kunal Dhawan Co-authored-by: Somshubra Majumdar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Append output of val step to self.validation_step_outputs (#7530) (#7532) Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> * [TTS] fixed trainer's accelerator and strategy. (#7569) (#7574) Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> * Append val/test output to instance variable in EncDecSpeakerLabelModel (#7562) (#7573) * Append val/test output to the instance variable in EncDecSpeakerLabelModel * Handle test case in evaluation_step * Replace type with isinstance --------- Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> * Fix CustomProgressBar for resume (#7427) (#7522) * Fix CustomProgress Bar for resume and multiple epochs * Edit num_training_batches * Use max_steps as total for progress bar for resume * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * fix typos in nfa and speech enhancement tutorials (#7580) (#7583) Signed-off-by: Elena Rastorgueva Co-authored-by: Elena Rastorgueva <80532067+erastorgueva-nv@users.noreply.github.com> * Add strategy as ddp_find_unused_parameters_true for glue_benchmark.py (#7454) (#7461) Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> * update strategy (#7577) (#7578) Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao * Fix typos (#7581) * Change hifigan finetune strategy to ddp_find_unused_parameters_true (#7579) (#7584) * Change strategy to auto --------- Signed-off-by: Cheng-Ping Hsieh Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> * [BugFix] Add missing quotes for auto strategy in tutorial notebooks (#7541) (#7548) * Add missing quotes for auto strategy * Revert trainer.gpus to trainer.devices in Self_Supervised_Pre_Training.ipynb --------- Signed-off-by: Abhishree Signed-off-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> * add build os key (#7596) (#7599) * add build os key * add tools * update to stable version --------- Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao * StarCoder SFT test + bump PyT NGC image to 23.09 (#7540) * Add SFT StarCoder test Signed-off-by: Jan Lasek * Remove _modify_config call as it is covered in load_from_nemo just below Signed-off-by: Jan Lasek * Test with pyt:23.09 container Signed-off-by: Jan Lasek --------- Signed-off-by: Jan Lasek * defaults changed (#7600) * defaults changed Signed-off-by: arendu * typo Signed-off-by: arendu * update Signed-off-by: arendu --------- Signed-off-by: arendu * add ItalianPhonemesTokenizer (#7587) * add ItalianPhonemesTokenizer Signed-off-by: GiacomoLeoneMaria * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Italian phonemes Signed-off-by: GiacomoLeoneMaria * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add test Signed-off-by: GiacomoLeoneMaria --------- Signed-off-by: GiacomoLeoneMaria Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * best ckpt fix (#7564) (#7588) Signed-off-by: dimapihtar Co-authored-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com> * Add files via upload (#7598) specifies the branch Signed-off-by: George <37293288+Jorjeous@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * Fix validation in G2PModel and ThutmoseTaggerModel (#7597) (#7606) Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> * Broadcast loss only when using pipeline parallelism and within the pipeline parallel domain (#7576) (#7586) * Broadcast loss only when using pipeline parallelism and within the pipeline parallel domain * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Sangkug Lym Co-authored-by: Sangkug Lym Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Safeguard nemo_text_processing installation on ARM (#7485) * safeguard nemo_text_processing installing Signed-off-by: Jason * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update check Signed-off-by: Jason --------- Signed-off-by: Jason Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Bound transformers version in requirements (#7620) Signed-off-by: Abhishree * fix llama2 70b lora tuning bug (#7622) * fix llama2 70b lora tuning bug Signed-off-by: Chen Cui * Update peft_config.py brackets Signed-off-by: Adi Renduchintala --------- Signed-off-by: Chen Cui Signed-off-by: Adi Renduchintala Co-authored-by: Adi Renduchintala * Fix import error no module name model_utils (#7629) Signed-off-by: Mehadi Hasan Menon * add fc large ls models (#7641) Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao Koluguri * bugfix: trainer.gpus, trainer.strategy, trainer.accelerator (#7621) (#7642) * [TTS] bugfix for Tacotron2 tutorial due to PTL 2.0 * trainer.gpus -> trainer.devices * fixed related tutorial bugs --------- Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * fix ssl models ptl monitor val through logging (#7608) (#7614) Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao Co-authored-by: Eric Harper Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * Fix metrics for SE tutorial (#7604) (#7612) Signed-off-by: Ante Jukić Co-authored-by: anteju <108555623+anteju@users.noreply.github.com> * Add ddp_find_unused_parameters=True and change accelerator to auto (#7623) (#7644) * Add ddp_find_unused_parameters=True and change acclerator to auto * Add ddp_find_unused_parameters True for normalization_as_tagging_train.py --------- Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> * Fix py3.11 dataclasses issue (#7616) * Fix py3.11 dataclasses issue (#7582) * Update ASR configs to support Python 3.11 Signed-off-by: smajumdar * Update TTS configs to support Python 3.11 Signed-off-by: smajumdar * Guard MeCab and Ipadic Signed-off-by: smajumdar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix remaining ASR dataclasses Signed-off-by: smajumdar * Fix remaining ASR dataclasses Signed-off-by: smajumdar * Fix scripts Signed-off-by: smajumdar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: smajumdar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Update name to ConfidenceMethodConfig Signed-off-by: smajumdar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Broadcast loss only when using pipeline parallelism and within the pipeline parallel domain (#7576) (#7586) * Broadcast loss only when using pipeline parallelism and within the pipeline parallel domain * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Sangkug Lym Co-authored-by: Sangkug Lym Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Safeguard nemo_text_processing installation on ARM (#7485) * safeguard nemo_text_processing installing Signed-off-by: Jason * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update check Signed-off-by: Jason --------- Signed-off-by: Jason Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Fix changes to confidence measure Signed-off-by: smajumdar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: smajumdar Signed-off-by: Sangkug Lym Signed-off-by: Jason Co-authored-by: Somshubra Majumdar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Sangkug Lym Co-authored-by: Jason * [Stable Diffusion/ControlNet] Enable O2 training for SD and Fix ControlNet CI failure * Mingyuanm/dreambooth fix * Fix NeMo CI Infer Issue * DreamFusion * Move neva export changes * Add Imagen Synthetic Dataloader * Add VITWrapper and export stuff to wrapper * Update neva with megatron-core support * Fix issues with Dockerfile (#7650) (#7652) Signed-off-by: smajumdar Co-authored-by: Somshubra Majumdar * [ASR] RNN-T greedy decoding max_frames fix for alignment and confidence (#7635) * decoding and test fix Signed-off-by: Aleksandr Laptev * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Aleksandr Laptev Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [ASR] Fix type error in jasper (#7636) (#7653) Signed-off-by: Ryan Co-authored-by: Ryan Langman * [TTS] Add STFT and SI-SDR loss to audio codec recipe (#7468) * [TTS] Add STFT and SI-SDR loss to audio codec recipe Signed-off-by: Ryan * [TTS] Fix STFT resolution Signed-off-by: Ryan * [TTS] Fix training metric logging Signed-off-by: Ryan * [TTS] Add docstring to mel and stft losses Signed-off-by: Ryan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Ryan Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Create per.py (#7538) * Move model precision copy (#7336) * move cfg precision set to megatron base model Signed-off-by: Maanu Grover * remove copy from other models Signed-off-by: Maanu Grover * modify attribute not arg Signed-off-by: Maanu Grover * fix gpt model test for ptl 2.0 Signed-off-by: Maanu Grover * rename function and add docstring Signed-off-by: Maanu Grover * replace precision to dtype conditionals with func call Signed-off-by: Maanu Grover * unnecessary function and cfg reset Signed-off-by: Maanu Grover * set default value Signed-off-by: Maanu Grover * fix precision lookup in a few more places Signed-off-by: Maanu Grover * rename mapping function Signed-off-by: Maanu Grover * ununsed import Signed-off-by: Maanu Grover * save torch datatype to model Signed-off-by: Maanu Grover * set weights precision wrt amp o2 Signed-off-by: Maanu Grover * Revert "set weights precision wrt amp o2" This reverts commit 313a4bfe5eb69d771a6d2433898c0685836aef5c. Signed-off-by: Maanu Grover * revert half precision at inference attempt Signed-off-by: Maanu Grover * move autocast dtype to base model Signed-off-by: Maanu Grover * move params dtype to base model, enable fp16 O2 inf Signed-off-by: Maanu Grover * unused imports Signed-off-by: Maanu Grover --------- Signed-off-by: Maanu Grover Signed-off-by: Sasha Meister * Fix PEFT checkpoint loading (#7388) * Fix PEFT checkpoint loading Signed-off-by: Jason Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jason Wang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Use distributed optimizer support for multiple dtypes (#7359) * Update distopt wrapper with multiple dtype support Remove manual handling of separate FP32 optimizer. Signed-off-by: Tim Moon * Use distopt support for contiguous buffers with multiple dtypes Signed-off-by: Tim Moon * Fix typo Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Separate distopt buckets for first GPT layer and non-overlapped params Signed-off-by: Tim Moon * Add distopt logic for int dtypes Signed-off-by: Tim Moon * Update Apex commit Signed-off-by: Tim Moon * Remove unused variables Signed-off-by: Tim Moon * Update Apex commit in README and Jenkensfile Signed-off-by: Tim Moon * Debug Dockerfile and Jenkinsfile Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Harper Signed-off-by: Sasha Meister * minor fix for llama ckpt conversion script (#7387) * minor fix for llama ckpt conversion script Signed-off-by: Jason Wang * Update Jenkinsfile Signed-off-by: Jason Wang * remove fast_swiglu configuration Signed-off-by: Jason Wang --------- Signed-off-by: Jason Wang Co-authored-by: Eric Harper Signed-off-by: Sasha Meister * Fix wrong calling of librosa.get_duration() in notebook (#7376) Signed-off-by: Robin Dong Co-authored-by: Somshubra Majumdar Signed-off-by: Sasha Meister * [PATCH] PEFT import mcore (#7393) * [PATCH] PEFT import mcore Signed-off-by: Jason Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jason Wang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Create per.py Script for calculation Punctuation Error Rate and related rates (correct rate, deletions rate, etc.) Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: Sasha Meister * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Sasha Meister * [TTS] Added a callback for logging initial data (#7384) Signed-off-by: Ante Jukić Signed-off-by: Sasha Meister * Update Core Commit (#7402) * Update Core Commit Signed-off-by: Abhinav Khattar * update commit Signed-off-by: Abhinav Khattar --------- Signed-off-by: Abhinav Khattar Signed-off-by: Sasha Meister * Use cfg attribute in bert (#7394) * use cfg attribute instead of arg Signed-off-by: Maanu Grover * use torch_dtype in place of cfg.precision Signed-off-by: Maanu Grover * move precision copy before super constructor Signed-off-by: Maanu Grover * use trainer arg Signed-off-by: Maanu Grover --------- Signed-off-by: Maanu Grover Signed-off-by: Sasha Meister * Add support for bias conversion in Swiglu models (#7386) * Add support for bias conversion in Swiglu models Signed-off-by: smajumdar * Add support for auto extracting tokenizer model Signed-off-by: smajumdar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support for auto extracting tokenizer model Signed-off-by: smajumdar * Fix issue with missing tokenizer Signed-off-by: smajumdar * Refactor Signed-off-by: smajumdar * Refactor Signed-off-by: smajumdar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: smajumdar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Update save_to and restore_from for dist checkpointing (#7343) * add dist ckpt to save to, in progress Signed-off-by: eharper * move dist ckpt Signed-off-by: eharper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * clean up Signed-off-by: eharper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update restore from, need to figure out how to initialize distributed Signed-off-by: eharper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * launch distrib if needed when restoring dist ckpt Signed-off-by: eharper * when using mcore we can change tp pp on the fly Signed-off-by: eharper * add load_from_checkpoint support for dist ckpt Signed-off-by: eharper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update llama convert script to save dist .nemo Signed-off-by: eharper * fix load dist ckpt Signed-off-by: jasonwan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * setup TE TP groups if needed Signed-off-by: eharper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * setup te tp groups if needed Signed-off-by: eharper * remove import Signed-off-by: eharper --------- Signed-off-by: eharper Signed-off-by: jasonwan Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: jasonwan Signed-off-by: Sasha Meister * fix forward for with mcore=false (#7403) Signed-off-by: Jimmy Zhang Co-authored-by: Jimmy Zhang Signed-off-by: Sasha Meister * Fix logging to remove 's/it' from progress bar in Megatron models and add train_step_timing (#7374) * Add CustomProgressBar class to exp_manager and trainer callbacks Signed-off-by: Abhishree * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the progress bar to reflect total microbatch cnt Signed-off-by: Abhishree * Modify CustomProgressBar class 1) Modify CustomProgressBar class to update progress bar per global_step instead of per microbatch 2) Add the callback to other megatron training/finetuning files that are not using MegatronTrainerBuilder Signed-off-by: Abhishree * Add CustomProgressBar callback to tuning files Signed-off-by: Abhishree * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Abhishree Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Set Activation Checkpointing Defaults (#7404) * Set Activation Checkpointing Defaults Signed-off-by: Abhinav Khattar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * check for None Signed-off-by: Abhinav Khattar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Abhinav Khattar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * make loss mask default to false (#7407) Signed-off-by: eharper Signed-off-by: Sasha Meister * Add dummy userbuffer config files (#7408) Signed-off-by: Sangkug Lym Signed-off-by: Sasha Meister * add missing ubconf files (#7412) Signed-off-by: Abhinav Khattar Signed-off-by: Sasha Meister * New tutorial on Speech Data Explorer (#7405) * Added Google Colab based tutorial on Speech Data Explorer Signed-off-by: George Zelenfroynd Signed-off-by: Sasha Meister * Update ptl training ckpt conversion script to work with dist ckpt (#7416) * update ptl convert script Signed-off-by: eharper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * don't break legacy Signed-off-by: eharper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: eharper Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Allow disabling sanity checking when num_sanity_val_steps=0 (#7413) * Allow disabling sanity checking when num_sanity_val_steps=0 Signed-off-by: Abhishree * Update num_sanity_val_steps to be a multiple of num_microbatches Signed-off-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Abhishree Signed-off-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Add comprehensive error messages (#7261) Signed-off-by: Anton Peganov Signed-off-by: Sasha Meister * check NEMO_PATH (#7418) Signed-off-by: Nikolay Karpov Signed-off-by: Sasha Meister * layer selection for ia3 (#7417) * layer selection for ia3 Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: arendu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Fix missing pip package 'einops' (#7397) Signed-off-by: Robin Dong Signed-off-by: Sasha Meister * Fix failure of pyaudio in Google Colab (#7396) Signed-off-by: Robin Dong Signed-off-by: Sasha Meister * Update README.md: output_path --> output_manifest_filepath (#7442) Signed-off-by: Samuele Cornell Signed-off-by: Sasha Meister * Add rope dynamic linear scaling (#7437) * Add dynamic linear scaling Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Cheng-Ping Hsieh --------- Signed-off-by: Cheng-Ping Hsieh Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Yang Zhang Signed-off-by: Sasha Meister * Fix None dataloader issue in PTL2.0 (#7455) * Fix None dataloader issue in PTL2.0 Signed-off-by: KunalDhawan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * updating values of self._validation_dl and self._test_dl as well Signed-off-by: KunalDhawan * updating values of self._validation_dl and self._test_dl as well Signed-off-by: KunalDhawan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: KunalDhawan Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * [ASR] Confidence measure -> method renames (#7434) * measure -> method Signed-off-by: Aleksandr Laptev * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Aleksandr Laptev Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Add steps for document of getting dataset 'SF Bilingual Speech' (#7378) * Add steps for document of getting dataset 'SF Bilingual Speech' Signed-off-by: Robin Dong * Update datasets.rst added a link from a tutorial demonstrating detailed data prep steps. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --------- Signed-off-by: Robin Dong Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Sasha Meister * RNN-T confidence and alignment bugfix (#7381) * new frame_confidence and alignments lists are now always created after the while loop Signed-off-by: Aleksandr Laptev * tests added Signed-off-by: Aleksandr Laptev --------- Signed-off-by: Aleksandr Laptev Signed-off-by: Sasha Meister * Fix resume from checkpoint in exp_manager (#7424) (#7426) Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Co-authored-by: Eric Harper Signed-off-by: Sasha Meister * Fix checking of cuda/cpu device for inputs of Decoder (#7444) * Fix checking of cuda/cpu device for inputs of Decoder Signed-off-by: Robin Dong * Update tacotron2.py Signed-off-by: Jason --------- Signed-off-by: Robin Dong Signed-off-by: Jason Co-authored-by: Jason Signed-off-by: Sasha Meister * Fix failure of ljspeech's get_data.py (#7430) * Fix failure of ljspeech's get_data.py Signed-off-by: Robin Dong * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Robin Dong Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * [TTS] Fix audio codec type checks (#7373) * [TTS] Fix audio codec type checks Signed-off-by: Ryan * [TTS] Fix audio codec tests Signed-off-by: Ryan --------- Signed-off-by: Ryan Signed-off-by: Sasha Meister * [TTS] Add dataset to path of logged artifacts (#7462) * [TTS] Add dataset to path of logged artifacts Signed-off-by: Ryan * [TTS] Revert axis name back to Audio Frames Signed-off-by: Ryan --------- Signed-off-by: Ryan Signed-off-by: Sasha Meister * Fix sft dataset truncation (#7464) * Add fix Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Cheng-Ping Hsieh --------- Signed-off-by: Cheng-Ping Hsieh Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Automatic Lip Reading Recognition (ALR) - ASR/CV (Visual ASR) (#7330) * striding_conv1d_k5 and dw_striding_conv1d_k5 subsampling Signed-off-by: mburchi * transpose conv1d inputs Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: mburchi * Update subsampling.py change striding_conv1d_k5 to striding_conv1d Signed-off-by: Maxime Burchi <60737204+burchim@users.noreply.github.com> * cv branch Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * video manifest Signed-off-by: mburchi * add collection classes Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add test_step_outputs Signed-off-by: mburchi * correct manifest bug when having only audio or only videos Signed-off-by: mburchi * correct manifest bug when having only audio or only videos Signed-off-by: mburchi * clean references Signed-off-by: mburchi * freeze unfreeze transcribe cv models Signed-off-by: mburchi * correct manifest get_full_path bug Signed-off-by: mburchi * update for PR Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * guard torchvision Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update nemo/collections/cv/data/video_to_text_dataset.py Co-authored-by: Igor Gitman Signed-off-by: Maxime Burchi <60737204+burchim@users.noreply.github.com> * _video_speech_collate_fn in cv/data/video_to_text.py Signed-off-by: mburchi * add self.out = None to asr subsampling Signed-off-by: mburchi * Update nemo/collections/cv/data/video_to_text_dataset.py Co-authored-by: Igor Gitman Signed-off-by: Maxime Burchi <60737204+burchim@users.noreply.github.com> * cv -> multimodal/speech_cv branch Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: mburchi Signed-off-by: Maxime Burchi <60737204+burchim@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Igor Gitman Signed-off-by: Sasha Meister * HF StarCoder to NeMo conversion script (#7421) * Script to convert HF StarCoder checkpoint to NeMo Signed-off-by: Jan Lasek * StarCoder conversion test Signed-off-by: Jan Lasek * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Jan Lasek * Fix test Signed-off-by: Jan Lasek * Catch up with save_to changes Signed-off-by: Jan Lasek * Don't abbreviate args for clarity Signed-off-by: Jan Lasek * Configurable precision: BF16 vs FP32 Signed-off-by: Jan Lasek * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jan Lasek Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * fix bug when loading dist ckpt in peft (#7452) Signed-off-by: Hongbin Liu Co-authored-by: Hongbin Liu Signed-off-by: Sasha Meister * Fix adding positional embeddings in-place in transformer module (#7440) Signed-off-by: Tamerlan Tabolov Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> Signed-off-by: Sasha Meister * Fix (#7478) Signed-off-by: Cheng-Ping Hsieh Signed-off-by: Sasha Meister * add sleep (#7498) (#7499) * add sleep * add sleep onto config instead * add comment --------- Signed-off-by: Gerald Shen Co-authored-by: Gerald Shen <119401249+gshennvm@users.noreply.github.com> Signed-off-by: Sasha Meister * Fix exp manager check for sleep (#7503) (#7504) Signed-off-by: smajumdar Co-authored-by: Somshubra Majumdar Signed-off-by: Sasha Meister * bugfix: trainer.accelerator=auto from None. (#7492) (#7493) Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Signed-off-by: Sasha Meister * [doc] fix broken link (#7481) Signed-off-by: Stas Bekman Signed-off-by: Sasha Meister * [TTS] Read audio as int32 to avoid flac read errors (#7477) * [TTS] Read audio as int32 to avoid flac read errors Signed-off-by: Ryan * [TTS] Add comment about read failures Signed-off-by: Ryan --------- Signed-off-by: Ryan Signed-off-by: Sasha Meister * Add dataset 'AISHELL-3' from OpenSLR for training mandarin TTS (#7409) * Add dataset 'AISHELL-3' from OpenSLR for training mandarin TTS * Train 'AISHELL-3' dataset with multi-speakers Signed-off-by: Robin Dong * Update get_data.py update copyright header Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * Update get_data.py added a disclaimer Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add new configuration file for AISHELL3 with multispeaker of fastpitch Signed-off-by: Robin Dong --------- Signed-off-by: Robin Dong Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Sasha Meister * dllogger - log on rank 0 only (#7513) Signed-off-by: Stas Bekman Signed-off-by: Sasha Meister * Fix TTS FastPitch tutorial (#7494) (#7516) * Fix --------- Signed-off-by: Cheng-Ping Hsieh Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> Signed-off-by: Sasha Meister * Fix get_dist() tensor dimension (#7506) (#7515) Signed-off-by: Jocelyn Huang Co-authored-by: Jocelyn Signed-off-by: Sasha Meister * bugfix: specify trainer.strategy=auto when devices=1 (#7509) (#7512) Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Signed-off-by: Sasha Meister * fix (#7511) Signed-off-by: Abhinav Khattar Signed-off-by: Sasha Meister * [TTS] Fix FastPitch data prep tutorial (#7524) Signed-off-by: Ryan Signed-off-by: Sasha Meister * add italian tokenization (#7486) * add italian tokenization Signed-off-by: GiacomoLeoneMaria * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add more ipa lexicon it Signed-off-by: GiacomoLeoneMaria * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix error deletion Signed-off-by: GiacomoLeoneMaria * add test Signed-off-by: GiacomoLeoneMaria * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: GiacomoLeoneMaria Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Replace None strategy with auto in tutorial notebooks (#7521) (#7527) Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Signed-off-by: Sasha Meister * unpin setuptools (#7534) (#7535) Signed-off-by: fayejf <36722593+fayejf@users.noreply.github.com> Co-authored-by: fayejf <36722593+fayejf@users.noreply.github.com> Signed-off-by: Sasha Meister * Update per.py - if __name__ == "__main__" removed (now metric can be imported); - removed excessive classes (like "Sample" and "Statistics"); - transition from pandas df to dict of dicts; - removed unnecessary "return"; - notation fixing; - reduced calculation time Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: Sasha Meister * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Sasha Meister * Create punctuation_rates.py Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: Sasha Meister * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Sasha Meister * Format fixing Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: Sasha Meister * added nemo.logging, header, docstrings, how to use Signed-off-by: Sasha Meister * Added asserions to rate_punctuation.py Signed-off-by: Sasha Meister * fix typo Signed-off-by: Sasha Meister * added function for import and call, docstrings Signed-off-by: Sasha Meister * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Sasha Meister * remove auto generated examples (#7510) * explicitly remove autogenerated examples for data parallel evaluation Signed-off-by: arendu * mark autogenrated and remove it for test Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: arendu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Add the `strategy` argument to `MegatronGPTModel.generate()` (#7264) It is passed as an explicit argument rather than through `**strategy_args` so as to ensure someone cannot accidentally pass other arguments that would end up being ignored. It is a keyword-only argument to ensure that if in the future we want to update the signature to `**strategy_args`, we can do it without breaking code. Signed-off-by: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Signed-off-by: Sasha Meister * Fix PTL2.0 related ASR bugs in r1.21.0: Val metrics logging, None dataloader issue (#7531) (#7533) * fix none dataloader issue ptl2 * ptl2.0 logging fixes for rnnt_models --------- Signed-off-by: KunalDhawan Co-authored-by: Kunal Dhawan Co-authored-by: Nithin Rao Signed-off-by: Sasha Meister * gpus -> devices (#7542) (#7545) Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao Signed-off-by: Sasha Meister * Update FFMPEG version to fix issue with torchaudio (#7551) (#7553) Signed-off-by: smajumdar Co-authored-by: Somshubra Majumdar Signed-off-by: Sasha Meister * PEFT GPT & T5 Refactor (#7308) * initial implementation of add_adapters API * correct type hint * Add config in add_adapters for save and load (@author bobchen) * Remove AdapterConfig to avoid import error * Add AdaterConfig back and move adaptermixin to sft model * Add NLPSaveRestoreConnector as default in NLPModel.restore_from * Add restore_from_nemo_with_adapter and test script * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rename t5 file and classes to be consistent with GPT * add t5 sft dataset * add support for single-file format with T5SFTDataset * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Various small changes to make T5 SFT work like GPT SFT * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add adapter evaluation test script * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add MultiAdaterConfig for ia3 and fix builder issue * Make ptuning for T5SFTModel work using mixin * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add IA3_Adapter for AdapterName * Add adapter name for ptuning and attention adapter * Make test script GPT/T5 agnostic * Add layer selection feature * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Integrate adapter name and config * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update gpt peft tuning script to new API * add t5 peft tuning script with new API * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix IA3 layer selection issue * Override state_dict on SFT model instead of mixin * Add load adapter by adapter config * move peft config map away from example script * auto get config from nemo adapter * Move PEFTConfig to new file * fix ckpt save/load for t5 * name change: add_adapters -> add_adapter * variable name change * update t5 script * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix t5 issues * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add weight tying * update gpt tuning script * PEFT-API proposal * Fix according to comments * update tuning scripts * move merge_cfg_with to mixin class since it applies to both gpt and t5 and requires the model class for restore * Add mcore_gpt support for NLPAdapterMixin * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo * variable name change to distinguish "peft" and "adapter" * override `load_adapters` to support `add_adapter` name change * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update tuning and eval script for adapter save/load * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add Ptuning on first stage only * add lora tutorial for review * Fix layer selection for mcore * add landing page * fix resume training Signed-off-by: jasonwan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add mcore condition in sharded_state_dict to make sft work * Update lora_tutorial.md First edit of this file for PEFT documentation for NeMO Signed-off-by: hkelly33 <58792115+hkelly33@users.noreply.github.com> * rename Adapter to AttentionAdapter to avoid confusion in doc * Change load_adapters to load .nemo * add quick start guide * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add load_adapters with .ckpt * Remove setup_complete changes in load_adapters * update landing page * remove typo * Updated quick_start.md per Chen Cui Signed-off-by: hkelly33 <58792115+hkelly33@users.noreply.github.com> * Add inference config merger and tutorial * Add doc string for NLPAdapterModelMixin and deprecated warning on MegatronGPTPEFTModel * add supported_methods.md and update other documentations * Update supported_methods.md minor updates. Signed-off-by: Adi Renduchintala * Update landing_page.md minor update. Signed-off-by: Adi Renduchintala * Modify doc string for NLPAdapterModelMixin * Add doc string add_adapters in NLPAdapterModelMixin * rename canonical adapters * remove mcore hard dependency * [PATCH] move microbatch calculator to nemo from apex * remove apex dependency in gpt and t5 sft models * remove apex dependency in gpt model * render doc strings * fix * Add missing virtual_tokens on ptuning * fix docstrings * update gpt-style model coverage in docs * update docstring * Remove pdb * add lightning_fabric to make docstring rendering work * Add Ptuning missing key * try docstring rendering * Fix ptuning issue * update gpt t5 peft tuning and eval scripts * typos * update eval config * fix bug relating to apex dependency removal * typo * make predict step behave the same as test step * make lora tutorial work in notebook * cosmetics * update yaml scripts * mcore_gpt attribute optional * typo * update eval scripts and fix T5 eval bugs * add NLPDDPStrategyNotebook and trainer builder logic to use it * update lora notebook to use new trainer builder * fix microbatch calculator bug for inference after training * Convert markdown files to RST and incorporate with doc * typo * revise language * remove extra cell * remove unnecessary inheritance * remove old tests * move layer selection default so logging messages make sense * remove `save_adapters` as adapter weights are saved automatically during training * initialize weights from a checkpoint instead of randomly * multiple fields can form a context (#7147) * list of context fields and flexible prompt template Signed-off-by: arendu * list of fields for context Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug Signed-off-by: Cheng-Ping Hsieh * Fix bug Signed-off-by: Cheng-Ping Hsieh * Add multiple truncation fields and middle truncation Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Compatible to old ckpt Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix tokenize detokenize issue Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove detokenization, add truncation augmentation Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Resolve comments Signed-off-by: Cheng-Ping Hsieh * Remove unused import Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert eos Signed-off-by: Cheng-Ping Hsieh * Add tokenizer space_sensitive attribute Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix error Signed-off-by: Cheng-Ping Hsieh * Fix erorr and use re Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug Signed-off-by: Cheng-Ping Hsieh * Change assert logic Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Follow adi suggestion Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove merge function Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add example and comment Signed-off-by: Cheng-Ping Hsieh * Remove context_key and add comment Signed-off-by: Cheng-Ping Hsieh * Remove random truncation Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix template none Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug Signed-off-by: Cheng-Ping Hsieh --------- Signed-off-by: arendu Signed-off-by: Cheng-Ping Hsieh Signed-off-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Cheng-Ping Hsieh Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> * revert config changes * remove accidental breakpoint * support TP>1 loading * infer adapter type from checkpoint in during eval * breakup add adapter * enable interpolation of train_ds and validation_ds * update metric calc script to conform to single-file eval format * remove extraneous print * update lora notebook for updated merge_inference_cfg * Update nlp_adapter_mixins.py variable name change Signed-off-by: Chen Cui * turn off grad scaler for PP to match old scripts * remove PEFTSaveRestoreConnector since functionality all covered by the new mixin class * remove resume_from_checkpoint check since covered in #7335 * revert changes made in eval config interpolation * more interpolation * typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove dup line Signed-off-by: Chen Cui * code style warnings Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix config mistake Signed-off-by: Chen Cui * add copyright header Signed-off-by: Chen Cui * fix code check warnings Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert changes to remove apex dependency (mixed apex+nemo microbatch calculator broke some CI tests) Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add more deprecation notices Signed-off-by: Chen Cui * update deprecation notices Signed-off-by: Chen Cui * update deprecation notices Signed-off-by: Chen Cui * consolidate peft and sft scripts Signed-off-by: Chen Cui * update CI tests Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * notebook branch points to main to prepare for merge Signed-off-by: Chen Cui * fix gpt and t5 validation with any metric other than loss Signed-off-by: Chen Cui * support pre-extracted checkpoints Signed-off-by: Chen Cui --------- Signed-off-by: jasonwan Signed-off-by: hkelly33 <58792115+hkelly33@users.noreply.github.com> Signed-off-by: Adi Renduchintala Signed-off-by: arendu Signed-off-by: Cheng-Ping Hsieh Signed-off-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> Signed-off-by: Chen Cui Co-authored-by: Chen Cui Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Marc Romeyn Co-authored-by: jasonwan Co-authored-by: hkelly33 <58792115+hkelly33@users.noreply.github.com> Co-authored-by: Adi Renduchintala Co-authored-by: Yuanzhe Dong Co-authored-by: Cheng-Ping Hsieh Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> Signed-off-by: Sasha Meister * fix a typo (#7496) Signed-off-by: BestJuly Signed-off-by: Sasha Meister * [TTS] remove curly braces from ${BRANCH} in jupyer notebook cell. (#7554) (#7560) * remove curly braces. * remove installation of pynini. --------- Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Signed-off-by: Sasha Meister * add youtube embed url (#7570) Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Signed-off-by: Sasha Meister * Remap speakers to continuous range of speaker_id for dataset AISHELL3 (#7536) * Remap speakers to continuous range of speaker_id for dataset AISHELL3 * Add new key/value pair to record raw speaker for AISHELL3 dataset Signed-off-by: Robin Dong --------- Signed-off-by: Robin Dong Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * fix validation_step_outputs initialization for multi-dataloader (#7546) (#7572) * added correct validation_step_outputs initialization for mutli-dataloader * changed kernel for display * Update logic for validation and test step outputs * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert multidataloader changes in multilang ASR notebook --------- Signed-off-by: KunalDhawan Signed-off-by: smajumdar Co-authored-by: Kunal Dhawan Co-authored-by: Somshubra Majumdar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Append output of val step to self.validation_step_outputs (#7530) (#7532) Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Signed-off-by: Sasha Meister * [TTS] fixed trainer's accelerator and strategy. (#7569) (#7574) Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Signed-off-by: Sasha Meister * Append val/test output to instance variable in EncDecSpeakerLabelModel (#7562) (#7573) * Append val/test output to the instance variable in EncDecSpeakerLabelModel * Handle test case in evaluation_step * Replace type with isinstance --------- Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Signed-off-by: Sasha Meister * Fix CustomProgressBar for resume (#7427) (#7522) * Fix CustomProgress Bar for resume and multiple epochs * Edit num_training_batches * Use max_steps as total for progress bar for resume * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * fix typos in nfa and speech enhancement tutorials (#7580) (#7583) Signed-off-by: Elena Rastorgueva Co-authored-by: Elena Rastorgueva <80532067+erastorgueva-nv@users.noreply.github.com> Signed-off-by: Sasha Meister * Add strategy as ddp_find_unused_parameters_true for glue_benchmark.py (#7454) (#7461) Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Signed-off-by: Sasha Meister * update strategy (#7577) (#7578) Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao Signed-off-by: Sasha Meister * Fix typos (#7581) Signed-off-by: Sasha Meister * Change hifigan finetune strategy to ddp_find_unused_parameters_true (#7579) (#7584) * Change strategy to auto --------- Signed-off-by: Cheng-Ping Hsieh Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> Signed-off-by: Sasha Meister * [BugFix] Add missing quotes for auto strategy in tutorial notebooks (#7541) (#7548) * Add missing quotes for auto strategy * Revert trainer.gpus to trainer.devices in Self_Supervised_Pre_Training.ipynb --------- Signed-off-by: Abhishree Signed-off-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Signed-off-by: Sasha Meister * added per tests Signed-off-by: Sasha Meister * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Sasha Meister * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Sasha Meister * [PATCH] PEFT import mcore (#7393) * [PATCH] PEFT import mcore Signed-off-by: Jason Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jason Wang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * add build os key (#7596) (#7599) * add build os key * add tools * update to stable version --------- Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao Signed-off-by: Sasha Meister * StarCoder SFT test + bump PyT NGC image to 23.09 (#7540) * Add SFT StarCoder test Signed-off-by: Jan Lasek * Remove _modify_config call as it is covered in load_from_nemo just below Signed-off-by: Jan Lasek * Test with pyt:23.09 container Signed-off-by: Jan Lasek --------- Signed-off-by: Jan Lasek Signed-off-by: Sasha Meister * defaults changed (#7600) * defaults changed Signed-off-by: arendu * typo Signed-off-by: arendu * update Signed-off-by: arendu --------- Signed-off-by: arendu Signed-off-by: Sasha Meister * add ItalianPhonemesTokenizer (#7587) * add ItalianPhonemesTokenizer Signed-off-by: GiacomoLeoneMaria * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Italian phonemes Signed-off-by: GiacomoLeoneMaria * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add test Signed-off-by: GiacomoLeoneMaria --------- Signed-off-by: GiacomoLeoneMaria Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Sasha Meister * best ckpt fix (#7564) (#7588) Signed-off-by: dimapihtar Co-authored-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com> Signed-off-by: Sasha Meister * rate_punctuation.py Fixed output manifest saving Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: Sasha Meister * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Sasha Meister * Fix tests Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: Sasha Meister * Add files via upload (#7598) specifies the branch Signed-off-by: George <37293288+Jorjeous@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Sasha Meister * Fix validation in G2PModel and ThutmoseTaggerModel (#7597) (#7606) Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Signed-off-by: Sasha Meister * Broadcast loss only when using pipeline parallelism and within the pipeline parallel domain (#7576) (#7586) * Broadcast loss only when using pipeline parallelism and within the pipeline parallel domain * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Sangkug Lym Co-authored-by: Sangkug Lym Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Safeguard nemo_text_processing installation on ARM (#7485) * safeguard nemo_text_processing installing Signed-off-by: Jason * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update check Signed-off-by: Jason --------- Signed-off-by: Jason Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Function name fixing Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: Sasha Meister * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Sasha Meister * Moving PER to speech_to_text_eval.py Added: - "use_per": PER metric computing; - "scores_per_sample": metrics computation sample by sample for wer/cer/punctuation rates; - "output_with_scores_filename": saving manifest with metrics Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: Sasha Meister * Update test_metrics.py Updated "punctuation_error_rate" function name Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: Sasha Meister * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Sasha Meister * Added use_per description Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: Sasha Meister * guard extra dependencies Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: Sasha Meister * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Sasha Meister * Write metrics to "output_filename" if "scores_per_sample=True" Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: Sasha Meister * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Sasha Meister * scores_per_sample description Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: Sasha Meister * Fix import guards Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: Sasha Meister * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Sasha Meister * Stats printing when HAVE_TABLUATE_AND_PANDAS=False Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: Sasha Meister * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Sasha Meister * Bound transformers version in requirements (#7620) Signed-off-by: Abhishree Signed-off-by: Sasha Meister * fix llama2 70b lora tuning bug (#7622) * fix llama2 70b lora tuning bug Signed-off-by: Chen Cui * Update peft_config.py brackets Signed-off-by: Adi Renduchintala --------- Signed-off-by: Chen Cui Signed-off-by: Adi Renduchintala Co-authored-by: Adi Renduchintala Signed-off-by: Sasha Meister * Fix import error no module name model_utils (#7629) Signed-off-by: Mehadi Hasan Menon Signed-off-by: Sasha Meister * Delete examples/asr/rate_punctuation.py Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: Sasha Meister * Added use_per description Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: Sasha Meister * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Sasha Meister * metric and variables name fixing Signed-off-by: Sasha Meister * Add else samples = None Signed-off-by: Sasha Meister * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Sasha Meister * add fc large ls models (#7641) Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao Koluguri Signed-off-by: Sasha Meister * bugfix: trainer.gpus, trainer.strategy, trainer.accelerator (#7621) (#7642) * [TTS] bugfix for Tacotron2 tutorial due to PTL 2.0 * trainer.gpus -> trainer.devices * fixed related tutorial bugs --------- Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Sasha Meister * fix ssl models ptl monitor val through logging (#7608) (#7614) Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao Co-authored-by: Eric Harper Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Sasha Meister * Fix metrics for SE tutorial (#7604) (#7612) Signed-off-by: Ante Jukić Co-authored-by: anteju <108555623+anteju@users.noreply.github.com> Signed-off-by: Sasha Meister * Add ddp_find_unused_parameters=True and change accelerator to auto (#7623) (#7644) * Add ddp_find_unused_parameters=True and change acclerator to auto * Add ddp_find_unused_parameters True for normalization_as_tagging_train.py --------- Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Signed-off-by: Sasha Meister * Fix py3.11 dataclasses issue (#7616) * Fix py3.11 dataclasses issue (#7582) * Update ASR configs to support Python 3.11 Signe… * conversion issue fix (#7648) (#7668) Signed-off-by: dimapihtar Co-authored-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com> * layernorm1p fix (#7523) (#7567) * layernorm1p fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add layernorm1p to if statement * config changes * gpt config changes * remove layernorm_zero_centered_gamma from gpt config * change line --------- Signed-off-by: dimapihtar Co-authored-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * generalized chat sft prompt (#7655) * fix dataset issues Signed-off-by: Yi Dong * working version Signed-off-by: Yi Dong * all passed Signed-off-by: Yi Dong * refactor tests Signed-off-by: Yi Dong * all pass Signed-off-by: Yi Dong * working version Signed-off-by: Yi Dong * use end name signal for labels Signed-off-by: Yi Dong * all fixed Signed-off-by: Yi Dong * update doc Signed-off-by: Yi Dong * style fix Signed-off-by: Yi Dong * remove unused imports Signed-off-by: Yi Dong * make sure nccl not timing out Signed-off-by: Yi Dong * style fix Signed-off-by: Yi Dong * generate example template Signed-off-by: Yi Dong * generic end of name token Signed-off-by: Yi Dong * style fix Signed-off-by: Yi Dong * add the chat prompt format into the config Signed-off-by: Yi Dong * make sure sft working Signed-off-by: Yi Dong * address reviewer comment Signed-off-by: Yi Dong * fix non Signed-off-by: Yi Dong * try openAI prompt Signed-off-by: Yi Dong * remove unused imports Signed-off-by: Yi Dong * remove human labels from the data Signed-off-by: Yi Dong * use hf dataset to clean Signed-off-by: Yi Dong * reviewer comments Signed-off-by: Yi Dong --------- Signed-off-by: Yi Dong * Fix vad & speech command tutorial - onnx (#7671) (#7672) * fix vad onnx * fix mbn onnx --------- Signed-off-by: fayejf Co-authored-by: fayejf <36722593+fayejf@users.noreply.github.com> * Fix in the confidence ensemble test (#7682) * Fix in the confidence ensemble test Signed-off-by: Igor Gitman * Correct parameter names Signed-off-by: Igor Gitman --------- Signed-off-by: Igor Gitman * PEFT eval fix (#7626) (#7638) * fix issue where peft weights are not loaded for distributed checkpoints * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Chen Cui Co-authored-by: Chen Cui Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * propagate mp config (#7637) (#7639) Signed-off-by: eharper Co-authored-by: Eric Harper * Add find_unused_parameters_true for text_classiftn and punctuation_capitalization (#7649) (#7657) Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> * Hotfix (#7501) (#7568) Signed-off-by: Jan Baczek Co-authored-by: jbaczek <45043825+jbaczek@users.noreply.github.com> * Avoid duplicated checkpoint save (#7555) (#7566) Signed-off-by: Mikołaj Błaż Co-authored-by: mikolajblaz * Cache FP8 weight and transpose only at the first micro-batch in each validation and test routine (#7470) (#7483) * Cache weight and transpose only in the first batch in all training, val, and test runs * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Sangkug Lym Co-authored-by: Sangkug Lym Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Add an option to disable manual GC in validation (#7467) (#7476) Signed-off-by: Sangkug Lym Co-authored-by: Sangkug Lym * Remove PUBLICATIONS.md, point to github.io NeMo page instead (#7694) (#7695) * update publications section to point to blog website page * add hyphen * use double backquotes for code formatting --------- Signed-off-by: Elena Rastorgueva Signed-off-by: Elena Rastorgueva <80532067+erastorgueva-nv@users.noreply.github.com> Co-authored-by: Elena Rastorgueva <80532067+erastorgueva-nv@users.noreply.github.com> * Fix multi rank finetune for ASR (#7684) (#7699) * Fix multi rank finetune for ASR * Actually add time * Actually add time --------- Signed-off-by: smajumdar Co-authored-by: Somshubra Majumdar * Update docs: readme, getting started, ASR intro (#7679) * [TTS] Add dataset to path of logged artifacts (#7462) * [TTS] Add dataset to path of logged artifacts Signed-off-by: Ryan * [TTS] Revert axis name back to Audio Frames Signed-off-by: Ryan --------- Signed-off-by: Ryan Signed-off-by: Elena Rastorgueva * move install info to INSTALLATION.md Signed-off-by: Elena Rastorgueva * tidy up links Signed-off-by: Elena Rastorgueva * Fix sft dataset truncation (#7464) * Add fix Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Cheng-Ping Hsieh --------- Signed-off-by: Cheng-Ping Hsieh Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * Automatic Lip Reading Recognition (ALR) - ASR/CV (Visual ASR) (#7330) * striding_conv1d_k5 and dw_striding_conv1d_k5 subsampling Signed-off-by: mburchi * transpose conv1d inputs Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: mburchi * Update subsampling.py change striding_conv1d_k5 to striding_conv1d Signed-off-by: Maxime Burchi <60737204+burchim@users.noreply.github.com> * cv branch Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * video manifest Signed-off-by: mburchi * add collection classes Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add test_step_outputs Signed-off-by: mburchi * correct manifest bug when having only audio or only videos Signed-off-by: mburchi * correct manifest bug when having only audio or only videos Signed-off-by: mburchi * clean references Signed-off-by: mburchi * freeze unfreeze transcribe cv models Signed-off-by: mburchi * correct manifest get_full_path bug Signed-off-by: mburchi * update for PR Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * guard torchvision Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update nemo/collections/cv/data/video_to_text_dataset.py Co-authored-by: Igor Gitman Signed-off-by: Maxime Burchi <60737204+burchim@users.noreply.github.com> * _video_speech_collate_fn in cv/data/video_to_text.py Signed-off-by: mburchi * add self.out = None to asr subsampling Signed-off-by: mburchi * Update nemo/collections/cv/data/video_to_text_dataset.py Co-authored-by: Igor Gitman Signed-off-by: Maxime Burchi <60737204+burchim@users.noreply.github.com> * cv -> multimodal/speech_cv branch Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: mburchi Signed-off-by: Maxime Burchi <60737204+burchim@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Igor Gitman Signed-off-by: Elena Rastorgueva * HF StarCoder to NeMo conversion script (#7421) * Script to convert HF StarCoder checkpoint to NeMo Signed-off-by: Jan Lasek * StarCoder conversion test Signed-off-by: Jan Lasek * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Jan Lasek * Fix test Signed-off-by: Jan Lasek * Catch up with save_to changes Signed-off-by: Jan Lasek * Don't abbreviate args for clarity Signed-off-by: Jan Lasek * Configurable precision: BF16 vs FP32 Signed-off-by: Jan Lasek * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jan Lasek Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * fix bug when loading dist ckpt in peft (#7452) Signed-off-by: Hongbin Liu Co-authored-by: Hongbin Liu Signed-off-by: Elena Rastorgueva * Fix adding positional embeddings in-place in transformer module (#7440) Signed-off-by: Tamerlan Tabolov Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * Fix (#7478) Signed-off-by: Cheng-Ping Hsieh Signed-off-by: Elena Rastorgueva * add sleep (#7498) (#7499) * add sleep * add sleep onto config instead * add comment --------- Signed-off-by: Gerald Shen Co-authored-by: Gerald Shen <119401249+gshennvm@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * Fix exp manager check for sleep (#7503) (#7504) Signed-off-by: smajumdar Co-authored-by: Somshubra Majumdar Signed-off-by: Elena Rastorgueva * bugfix: trainer.accelerator=auto from None. (#7492) (#7493) Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Signed-off-by: Elena Rastorgueva * [doc] fix broken link (#7481) Signed-off-by: Stas Bekman Signed-off-by: Elena Rastorgueva * [TTS] Read audio as int32 to avoid flac read errors (#7477) * [TTS] Read audio as int32 to avoid flac read errors Signed-off-by: Ryan * [TTS] Add comment about read failures Signed-off-by: Ryan --------- Signed-off-by: Ryan Signed-off-by: Elena Rastorgueva * Add dataset 'AISHELL-3' from OpenSLR for training mandarin TTS (#7409) * Add dataset 'AISHELL-3' from OpenSLR for training mandarin TTS * Train 'AISHELL-3' dataset with multi-speakers Signed-off-by: Robin Dong * Update get_data.py update copyright header Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * Update get_data.py added a disclaimer Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add new configuration file for AISHELL3 with multispeaker of fastpitch Signed-off-by: Robin Dong --------- Signed-off-by: Robin Dong Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * dllogger - log on rank 0 only (#7513) Signed-off-by: Stas Bekman Signed-off-by: Elena Rastorgueva * Fix TTS FastPitch tutorial (#7494) (#7516) * Fix --------- Signed-off-by: Cheng-Ping Hsieh Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * Fix get_dist() tensor dimension (#7506) (#7515) Signed-off-by: Jocelyn Huang Co-authored-by: Jocelyn Signed-off-by: Elena Rastorgueva * bugfix: specify trainer.strategy=auto when devices=1 (#7509) (#7512) Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Signed-off-by: Elena Rastorgueva * fix (#7511) Signed-off-by: Abhinav Khattar Signed-off-by: Elena Rastorgueva * [TTS] Fix FastPitch data prep tutorial (#7524) Signed-off-by: Ryan Signed-off-by: Elena Rastorgueva * add italian tokenization (#7486) * add italian tokenization Signed-off-by: GiacomoLeoneMaria * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add more ipa lexicon it Signed-off-by: GiacomoLeoneMaria * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix error deletion Signed-off-by: GiacomoLeoneMaria * add test Signed-off-by: GiacomoLeoneMaria * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: GiacomoLeoneMaria Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * Replace None strategy with auto in tutorial notebooks (#7521) (#7527) Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * unpin setuptools (#7534) (#7535) Signed-off-by: fayejf <36722593+fayejf@users.noreply.github.com> Co-authored-by: fayejf <36722593+fayejf@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * remove auto generated examples (#7510) * explicitly remove autogenerated examples for data parallel evaluation Signed-off-by: arendu * mark autogenrated and remove it for test Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: arendu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * Add the `strategy` argument to `MegatronGPTModel.generate()` (#7264) It is passed as an explicit argument rather than through `**strategy_args` so as to ensure someone cannot accidentally pass other arguments that would end up being ignored. It is a keyword-only argument to ensure that if in the future we want to update the signature to `**strategy_args`, we can do it without breaking code. Signed-off-by: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * Fix PTL2.0 related ASR bugs in r1.21.0: Val metrics logging, None dataloader issue (#7531) (#7533) * fix none dataloader issue ptl2 * ptl2.0 logging fixes for rnnt_models --------- Signed-off-by: KunalDhawan Co-authored-by: Kunal Dhawan Co-authored-by: Nithin Rao Signed-off-by: Elena Rastorgueva * gpus -> devices (#7542) (#7545) Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao Signed-off-by: Elena Rastorgueva * Update FFMPEG version to fix issue with torchaudio (#7551) (#7553) Signed-off-by: smajumdar Co-authored-by: Somshubra Majumdar Signed-off-by: Elena Rastorgueva * PEFT GPT & T5 Refactor (#7308) * initial implementation of add_adapters API * correct type hint * Add config in add_adapters for save and load (@author bobchen) * Remove AdapterConfig to avoid import error * Add AdaterConfig back and move adaptermixin to sft model * Add NLPSaveRestoreConnector as default in NLPModel.restore_from * Add restore_from_nemo_with_adapter and test script * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rename t5 file and classes to be consistent with GPT * add t5 sft dataset * add support for single-file format with T5SFTDataset * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Various small changes to make T5 SFT work like GPT SFT * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add adapter evaluation test script * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add MultiAdaterConfig for ia3 and fix builder issue * Make ptuning for T5SFTModel work using mixin * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add IA3_Adapter for AdapterName * Add adapter name for ptuning and attention adapter * Make test script GPT/T5 agnostic * Add layer selection feature * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Integrate adapter name and config * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update gpt peft tuning script to new API * add t5 peft tuning script with new API * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix IA3 layer selection issue * Override state_dict on SFT model instead of mixin * Add load adapter by adapter config * move peft config map away from example script * auto get config from nemo adapter * Move PEFTConfig to new file * fix ckpt save/load for t5 * name change: add_adapters -> add_adapter * variable name change * update t5 script * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix t5 issues * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add weight tying * update gpt tuning script * PEFT-API proposal * Fix according to comments * update tuning scripts * move merge_cfg_with to mixin class since it applies to both gpt and t5 and requires the model class for restore * Add mcore_gpt support for NLPAdapterMixin * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo * variable name change to distinguish "peft" and "adapter" * override `load_adapters` to support `add_adapter` name change * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update tuning and eval script for adapter save/load * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add Ptuning on first stage only * add lora tutorial for review * Fix layer selection for mcore * add landing page * fix resume training Signed-off-by: jasonwan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add mcore condition in sharded_state_dict to make sft work * Update lora_tutorial.md First edit of this file for PEFT documentation for NeMO Signed-off-by: hkelly33 <58792115+hkelly33@users.noreply.github.com> * rename Adapter to AttentionAdapter to avoid confusion in doc * Change load_adapters to load .nemo * add quick start guide * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add load_adapters with .ckpt * Remove setup_complete changes in load_adapters * update landing page * remove typo * Updated quick_start.md per Chen Cui Signed-off-by: hkelly33 <58792115+hkelly33@users.noreply.github.com> * Add inference config merger and tutorial * Add doc string for NLPAdapterModelMixin and deprecated warning on MegatronGPTPEFTModel * add supported_methods.md and update other documentations * Update supported_methods.md minor updates. Signed-off-by: Adi Renduchintala * Update landing_page.md minor update. Signed-off-by: Adi Renduchintala * Modify doc string for NLPAdapterModelMixin * Add doc string add_adapters in NLPAdapterModelMixin * rename canonical adapters * remove mcore hard dependency * [PATCH] move microbatch calculator to nemo from apex * remove apex dependency in gpt and t5 sft models * remove apex dependency in gpt model * render doc strings * fix * Add missing virtual_tokens on ptuning * fix docstrings * update gpt-style model coverage in docs * update docstring * Remove pdb * add lightning_fabric to make docstring rendering work * Add Ptuning missing key * try docstring rendering * Fix ptuning issue * update gpt t5 peft tuning and eval scripts * typos * update eval config * fix bug relating to apex dependency removal * typo * make predict step behave the same as test step * make lora tutorial work in notebook * cosmetics * update yaml scripts * mcore_gpt attribute optional * typo * update eval scripts and fix T5 eval bugs * add NLPDDPStrategyNotebook and trainer builder logic to use it * update lora notebook to use new trainer builder * fix microbatch calculator bug for inference after training * Convert markdown files to RST and incorporate with doc * typo * revise language * remove extra cell * remove unnecessary inheritance * remove old tests * move layer selection default so logging messages make sense * remove `save_adapters` as adapter weights are saved automatically during training * initialize weights from a checkpoint instead of randomly * multiple fields can form a context (#7147) * list of context fields and flexible prompt template Signed-off-by: arendu * list of fields for context Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug Signed-off-by: Cheng-Ping Hsieh * Fix bug Signed-off-by: Cheng-Ping Hsieh * Add multiple truncation fields and middle truncation Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Compatible to old ckpt Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix tokenize detokenize issue Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove detokenization, add truncation augmentation Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Resolve comments Signed-off-by: Cheng-Ping Hsieh * Remove unused import Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert eos Signed-off-by: Cheng-Ping Hsieh * Add tokenizer space_sensitive attribute Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix error Signed-off-by: Cheng-Ping Hsieh * Fix erorr and use re Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug Signed-off-by: Cheng-Ping Hsieh * Change assert logic Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Follow adi suggestion Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove merge function Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add example and comment Signed-off-by: Cheng-Ping Hsieh * Remove context_key and add comment Signed-off-by: Cheng-Ping Hsieh * Remove random truncation Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix template none Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug Signed-off-by: Cheng-Ping Hsieh --------- Signed-off-by: arendu Signed-off-by: Cheng-Ping Hsieh Signed-off-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Cheng-Ping Hsieh Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> * revert config changes * remove accidental breakpoint * support TP>1 loading * infer adapter type from checkpoint in during eval * breakup add adapter * enable interpolation of train_ds and validation_ds * update metric calc script to conform to single-file eval format * remove extraneous print * update lora notebook for updated merge_inference_cfg * Update nlp_adapter_mixins.py variable name change Signed-off-by: Chen Cui * turn off grad scaler for PP to match old scripts * remove PEFTSaveRestoreConnector since functionality all covered by the new mixin class * remove resume_from_checkpoint check since covered in #7335 * revert changes made in eval config interpolation * more interpolation * typo * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove dup line Signed-off-by: Chen Cui * code style warnings Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix config mistake Signed-off-by: Chen Cui * add copyright header Signed-off-by: Chen Cui * fix code check warnings Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert changes to remove apex dependency (mixed apex+nemo microbatch calculator broke some CI tests) Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add more deprecation notices Signed-off-by: Chen Cui * update deprecation notices Signed-off-by: Chen Cui * update deprecation notices Signed-off-by: Chen Cui * consolidate peft and sft scripts Signed-off-by: Chen Cui * update CI tests Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * notebook branch points to main to prepare for merge Signed-off-by: Chen Cui * fix gpt and t5 validation with any metric other than loss Signed-off-by: Chen Cui * support pre-extracted checkpoints Signed-off-by: Chen Cui --------- Signed-off-by: jasonwan Signed-off-by: hkelly33 <58792115+hkelly33@users.noreply.github.com> Signed-off-by: Adi Renduchintala Signed-off-by: arendu Signed-off-by: Cheng-Ping Hsieh Signed-off-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> Signed-off-by: Chen Cui Co-authored-by: Chen Cui Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Marc Romeyn Co-authored-by: jasonwan Co-authored-by: hkelly33 <58792115+hkelly33@users.noreply.github.com> Co-authored-by: Adi Renduchintala Co-authored-by: Yuanzhe Dong Co-authored-by: Cheng-Ping Hsieh Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * fix a typo (#7496) Signed-off-by: BestJuly Signed-off-by: Elena Rastorgueva * [TTS] remove curly braces from ${BRANCH} in jupyer notebook cell. (#7554) (#7560) * remove curly braces. * remove installation of pynini. --------- Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Signed-off-by: Elena Rastorgueva * add youtube embed url (#7570) Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Signed-off-by: Elena Rastorgueva * Remap speakers to continuous range of speaker_id for dataset AISHELL3 (#7536) * Remap speakers to continuous range of speaker_id for dataset AISHELL3 * Add new key/value pair to record raw speaker for AISHELL3 dataset Signed-off-by: Robin Dong --------- Signed-off-by: Robin Dong Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * fix validation_step_outputs initialization for multi-dataloader (#7546) (#7572) * added correct validation_step_outputs initialization for mutli-dataloader * changed kernel for display * Update logic for validation and test step outputs * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert multidataloader changes in multilang ASR notebook --------- Signed-off-by: KunalDhawan Signed-off-by: smajumdar Co-authored-by: Kunal Dhawan Co-authored-by: Somshubra Majumdar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * Append output of val step to self.validation_step_outputs (#7530) (#7532) Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * [TTS] fixed trainer's accelerator and strategy. (#7569) (#7574) Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Signed-off-by: Elena Rastorgueva * Append val/test output to instance variable in EncDecSpeakerLabelModel (#7562) (#7573) * Append val/test output to the instance variable in EncDecSpeakerLabelModel * Handle test case in evaluation_step * Replace type with isinstance --------- Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * Fix CustomProgressBar for resume (#7427) (#7522) * Fix CustomProgress Bar for resume and multiple epochs * Edit num_training_batches * Use max_steps as total for progress bar for resume * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * fix typos in nfa and speech enhancement tutorials (#7580) (#7583) Signed-off-by: Elena Rastorgueva Co-authored-by: Elena Rastorgueva <80532067+erastorgueva-nv@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * Add strategy as ddp_find_unused_parameters_true for glue_benchmark.py (#7454) (#7461) Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * update strategy (#7577) (#7578) Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao Signed-off-by: Elena Rastorgueva * Fix typos (#7581) Signed-off-by: Elena Rastorgueva * Change hifigan finetune strategy to ddp_find_unused_parameters_true (#7579) (#7584) * Change strategy to auto --------- Signed-off-by: Cheng-Ping Hsieh Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * [BugFix] Add missing quotes for auto strategy in tutorial notebooks (#7541) (#7548) * Add missing quotes for auto strategy * Revert trainer.gpus to trainer.devices in Self_Supervised_Pre_Training.ipynb --------- Signed-off-by: Abhishree Signed-off-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * add build os key (#7596) (#7599) * add build os key * add tools * update to stable version --------- Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao Signed-off-by: Elena Rastorgueva * StarCoder SFT test + bump PyT NGC image to 23.09 (#7540) * Add SFT StarCoder test Signed-off-by: Jan Lasek * Remove _modify_config call as it is covered in load_from_nemo just below Signed-off-by: Jan Lasek * Test with pyt:23.09 container Signed-off-by: Jan Lasek --------- Signed-off-by: Jan Lasek Signed-off-by: Elena Rastorgueva * defaults changed (#7600) * defaults changed Signed-off-by: arendu * typo Signed-off-by: arendu * update Signed-off-by: arendu --------- Signed-off-by: arendu Signed-off-by: Elena Rastorgueva * add ItalianPhonemesTokenizer (#7587) * add ItalianPhonemesTokenizer Signed-off-by: GiacomoLeoneMaria * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Italian phonemes Signed-off-by: GiacomoLeoneMaria * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add test Signed-off-by: GiacomoLeoneMaria --------- Signed-off-by: GiacomoLeoneMaria Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * best ckpt fix (#7564) (#7588) Signed-off-by: dimapihtar Co-authored-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * Add files via upload (#7598) specifies the branch Signed-off-by: George <37293288+Jorjeous@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * Fix validation in G2PModel and ThutmoseTaggerModel (#7597) (#7606) Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * Broadcast loss only when using pipeline parallelism and within the pipeline parallel domain (#7576) (#7586) * Broadcast loss only when using pipeline parallelism and within the pipeline parallel domain * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Sangkug Lym Co-authored-by: Sangkug Lym Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * Safeguard nemo_text_processing installation on ARM (#7485) * safeguard nemo_text_processing installing Signed-off-by: Jason * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update check Signed-off-by: Jason --------- Signed-off-by: Jason Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * Bound transformers version in requirements (#7620) Signed-off-by: Abhishree Signed-off-by: Elena Rastorgueva * fix llama2 70b lora tuning bug (#7622) * fix llama2 70b lora tuning bug Signed-off-by: Chen Cui * Update peft_config.py brackets Signed-off-by: Adi Renduchintala --------- Signed-off-by: Chen Cui Signed-off-by: Adi Renduchintala Co-authored-by: Adi Renduchintala Signed-off-by: Elena Rastorgueva * Fix import error no module name model_utils (#7629) Signed-off-by: Mehadi Hasan Menon Signed-off-by: Elena Rastorgueva * add fc large ls models (#7641) Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao Koluguri Signed-off-by: Elena Rastorgueva * bugfix: trainer.gpus, trainer.strategy, trainer.accelerator (#7621) (#7642) * [TTS] bugfix for Tacotron2 tutorial due to PTL 2.0 * trainer.gpus -> trainer.devices * fixed related tutorial bugs --------- Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * fix ssl models ptl monitor val through logging (#7608) (#7614) Signed-off-by: Nithin Rao Koluguri Co-authored-by: Nithin Rao Co-authored-by: Eric Harper Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * Fix metrics for SE tutorial (#7604) (#7612) Signed-off-by: Ante Jukić Co-authored-by: anteju <108555623+anteju@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * Add ddp_find_unused_parameters=True and change accelerator to auto (#7623) (#7644) * Add ddp_find_unused_parameters=True and change acclerator to auto * Add ddp_find_unused_parameters True for normalization_as_tagging_train.py --------- Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * Fix py3.11 dataclasses issue (#7616) * Fix py3.11 dataclasses issue (#7582) * Update ASR configs to support Python 3.11 Signed-off-by: smajumdar * Update TTS configs to support Python 3.11 Signed-off-by: smajumdar * Guard MeCab and Ipadic Signed-off-by: smajumdar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix remaining ASR dataclasses Signed-off-by: smajumdar * Fix remaining ASR dataclasses Signed-off-by: smajumdar * Fix scripts Signed-off-by: smajumdar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: smajumdar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Update name to ConfidenceMethodConfig Signed-off-by: smajumdar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Broadcast loss only when using pipeline parallelism and within the pipeline parallel domain (#7576) (#7586) * Broadcast loss only when using pipeline parallelism and within the pipeline parallel domain * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Sangkug Lym Co-authored-by: Sangkug Lym Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Safeguard nemo_text_processing installation on ARM (#7485) * safeguard nemo_text_processing installing Signed-off-by: Jason * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update check Signed-off-by: Jason --------- Signed-off-by: Jason Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Fix changes to confidence measure Signed-off-by: smajumdar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: smajumdar Signed-off-by: Sangkug Lym Signed-off-by: Jason Co-authored-by: Somshubra Majumdar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Sangkug Lym Co-authored-by: Jason Signed-off-by: Elena Rastorgueva * Fix issues with Dockerfile (#7650) (#7652) Signed-off-by: smajumdar Co-authored-by: Somshubra Majumdar Signed-off-by: Elena Rastorgueva * [ASR] RNN-T greedy decoding max_frames fix for alignment and confidence (#7635) * decoding and test fix Signed-off-by: Aleksandr Laptev * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Aleksandr Laptev Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * [ASR] Fix type error in jasper (#7636) (#7653) Signed-off-by: Ryan Co-authored-by: Ryan Langman Signed-off-by: Elena Rastorgueva * [TTS] Add STFT and SI-SDR loss to audio codec recipe (#7468) * [TTS] Add STFT and SI-SDR loss to audio codec recipe Signed-off-by: Ryan * [TTS] Fix STFT resolution Signed-off-by: Ryan * [TTS] Fix training metric logging Signed-off-by: Ryan * [TTS] Add docstring to mel and stft losses Signed-off-by: Ryan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Ryan Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Elena Rastorgueva * add outline of asr quickstart info to asr/intro.rst Signed-off-by: Elena Rastorgueva * add CLI, LM and real-time transcription sections Signed-off-by: Elena Rastorgueva * Create per.py (#7538) * Move model precision copy (#7336) * move cfg precision set to megatron base model Signed-off-by: Maanu Grover * remove copy from other models Signed-off-by: Maanu Grover * modify attribute not arg Signed-off-by: Maanu Grover * fix gpt model test for ptl 2.0 Signed-off-by: Maanu Grover * rename function and add docstring Signed-off-by: Maanu Grover * replace precision to dtype conditionals with func call Signed-off-by: Maanu Grover * unnecessary function and cfg reset Signed-off-by: Maanu Grover * set default value Signed-off-by: Maanu Grover * fix precision lookup in a few more places Signed-off-by: Maanu Grover * rename mapping function Signed-off-by: Maanu Grover * ununsed import Signed-off-by: Maanu Grover * save torch datatype to model Signed-off-by: Maanu Grover * set weights precision wrt amp o2 Signed-off-by: Maanu Grover * Revert "set weights precision wrt amp o2" This reverts commit 313a4bfe5eb69d771a6d2433898c0685836aef5c. Signed-off-by: Maanu Grover * revert half precision at inference attempt Signed-off-by: Maanu Grover * move autocast dtype to base model Signed-off-by: Maanu Grover * move params dtype to base model, enable fp16 O2 inf Signed-off-by: Maanu Grover * unused imports Signed-off-by: Maanu Grover --------- Signed-off-by: Maanu Grover Signed-off-by: Sasha Meister * Fix PEFT checkpoint loading (#7388) * Fix PEFT checkpoint loading Signed-off-by: Jason Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jason Wang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Use distributed optimizer support for multiple dtypes (#7359) * Update distopt wrapper with multiple dtype support Remove manual handling of separate FP32 optimizer. Signed-off-by: Tim Moon * Use distopt support for contiguous buffers with multiple dtypes Signed-off-by: Tim Moon * Fix typo Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Separate distopt buckets for first GPT layer and non-overlapped params Signed-off-by: Tim Moon * Add distopt logic for int dtypes Signed-off-by: Tim Moon * Update Apex commit Signed-off-by: Tim Moon * Remove unused variables Signed-off-by: Tim Moon * Update Apex commit in README and Jenkensfile Signed-off-by: Tim Moon * Debug Dockerfile and Jenkinsfile Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Harper Signed-off-by: Sasha Meister * minor fix for llama ckpt conversion script (#7387) * minor fix for llama ckpt conversion script Signed-off-by: Jason Wang * Update Jenkinsfile Signed-off-by: Jason Wang * remove fast_swiglu configuration Signed-off-by: Jason Wang --------- Signed-off-by: Jason Wang Co-authored-by: Eric Harper Signed-off-by: Sasha Meister * Fix wrong calling of librosa.get_duration() in notebook (#7376) Signed-off-by: Robin Dong Co-authored-by: Somshubra Majumdar Signed-off-by: Sasha Meister * [PATCH] PEFT import mcore (#7393) * [PATCH] PEFT import mcore Signed-off-by: Jason Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jason Wang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Create per.py Script for calculation Punctuation Error Rate and related rates (correct rate, deletions rate, etc.) Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: Sasha Meister * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Sasha Meister * [TTS] Added a callback for logging initial data (#7384) Signed-off-by: Ante Jukić Signed-off-by: Sasha Meister * Update Core Commit (#7402) * Update Core Commit Signed-off-by: Abhinav Khattar * update commit Signed-off-by: Abhinav Khattar --------- Signed-off-by: Abhinav Khattar Signed-off-by: Sasha Meister * Use cfg attribute in bert (#7394) * use cfg attribute instead of arg Signed-off-by: Maanu Grover * use torch_dtype in place of cfg.precision Signed-off-by: Maanu Grover * move precision copy before super constructor Signed-off-by: Maanu Grover * use trainer arg Signed-off-by: Maanu Grover --------- Signed-off-by: Maanu Grover Signed-off-by: Sasha Meister * Add support for bias conversion in Swiglu models (#7386) * Add support for bias conversion in Swiglu models Signed-off-by: smajumdar * Add support for auto extracting tokenizer model Signed-off-by: smajumdar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support for auto extracting tokenizer model Signed-off-by: smajumdar * Fix issue with missing tokenizer Signed-off-by: smajumdar * Refactor Signed-off-by: smajumdar * Refactor Signed-off-by: smajumdar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: smajumdar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Update save_to and restore_from for dist checkpointing (#7343) * add dist ckpt to save to, in progress Signed-off-by: eharper * move dist ckpt Signed-off-by: eharper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * clean up Signed-off-by: eharper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update restore from, need to figure out how to initialize distributed Signed-off-by: eharper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * launch distrib if needed when restoring dist ckpt Signed-off-by: eharper * when using mcore we can change tp pp on the fly Signed-off-by: eharper * add load_from_checkpoint support for dist ckpt Signed-off-by: eharper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update llama convert script to save dist .nemo Signed-off-by: eharper * fix load dist ckpt Signed-off-by: jasonwan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * setup TE TP groups if needed Signed-off-by: eharper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * setup te tp groups if needed Signed-off-by: eharper * remove import Signed-off-by: eharper --------- Signed-off-by: eharper Signed-off-by: jasonwan Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: jasonwan Signed-off-by: Sasha Meister * fix forward for with mcore=false (#7403) Signed-off-by: Jimmy Zhang Co-authored-by: Jimmy Zhang Signed-off-by: Sasha Meister * Fix logging to remove 's/it' from progress bar in Megatron models and add train_step_timing (#7374) * Add CustomProgressBar class to exp_manager and trainer callbacks Signed-off-by: Abhishree * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the progress bar to reflect total microbatch cnt Signed-off-by: Abhishree * Modify CustomProgressBar class 1) Modify CustomProgressBar class to update progress bar per global_step instead of per microbatch 2) Add the callback to other megatron training/finetuning files that are not using MegatronTrainerBuilder Signed-off-by: Abhishree * Add CustomProgressBar callback to tuning files Signed-off-by: Abhishree * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Abhishree Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Set Activation Checkpointing Defaults (#7404) * Set Activation Checkpointing Defaults Signed-off-by: Abhinav Khattar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * check for None Signed-off-by: Abhinav Khattar * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Abhinav Khattar Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * make loss mask default to false (#7407) Signed-off-by: eharper Signed-off-by: Sasha Meister * Add dummy userbuffer config files (#7408) Signed-off-by: Sangkug Lym Signed-off-by: Sasha Meister * add missing ubconf files (#7412) Signed-off-by: Abhinav Khattar Signed-off-by: Sasha Meister * New tutorial on Speech Data Explorer (#7405) * Added Google Colab based tutorial on Speech Data Explorer Signed-off-by: George Zelenfroynd Signed-off-by: Sasha Meister * Update ptl training ckpt conversion script to work with dist ckpt (#7416) * update ptl convert script Signed-off-by: eharper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * don't break legacy Signed-off-by: eharper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: eharper Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Allow disabling sanity checking when num_sanity_val_steps=0 (#7413) * Allow disabling sanity checking when num_sanity_val_steps=0 Signed-off-by: Abhishree * Update num_sanity_val_steps to be a multiple of num_microbatches Signed-off-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Abhishree Signed-off-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Add comprehensive error messages (#7261) Signed-off-by: Anton Peganov Signed-off-by: Sasha Meister * check NEMO_PATH (#7418) Signed-off-by: Nikolay Karpov Signed-off-by: Sasha Meister * layer selection for ia3 (#7417) * layer selection for ia3 Signed-off-by: arendu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: arendu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Fix missing pip package 'einops' (#7397) Signed-off-by: Robin Dong Signed-off-by: Sasha Meister * Fix failure of pyaudio in Google Colab (#7396) Signed-off-by: Robin Dong Signed-off-by: Sasha Meister * Update README.md: output_path --> output_manifest_filepath (#7442) Signed-off-by: Samuele Cornell Signed-off-by: Sasha Meister * Add rope dynamic linear scaling (#7437) * Add dynamic linear scaling Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Cheng-Ping Hsieh --------- Signed-off-by: Cheng-Ping Hsieh Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Yang Zhang Signed-off-by: Sasha Meister * Fix None dataloader issue in PTL2.0 (#7455) * Fix None dataloader issue in PTL2.0 Signed-off-by: KunalDhawan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * updating values of self._validation_dl and self._test_dl as well Signed-off-by: KunalDhawan * updating values of self._validation_dl and self._test_dl as well Signed-off-by: KunalDhawan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: KunalDhawan Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * [ASR] Confidence measure -> method renames (#7434) * measure -> method Signed-off-by: Aleksandr Laptev * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Aleksandr Laptev Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Add steps for document of getting dataset 'SF Bilingual Speech' (#7378) * Add steps for document of getting dataset 'SF Bilingual Speech' Signed-off-by: Robin Dong * Update datasets.rst added a link from a tutorial demonstrating detailed data prep steps. Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> --------- Signed-off-by: Robin Dong Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Sasha Meister * RNN-T confidence and alignment bugfix (#7381) * new frame_confidence and alignments lists are now always created after the while loop Signed-off-by: Aleksandr Laptev * tests added Signed-off-by: Aleksandr Laptev --------- Signed-off-by: Aleksandr Laptev Signed-off-by: Sasha Meister * Fix resume from checkpoint in exp_manager (#7424) (#7426) Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Co-authored-by: Eric Harper Signed-off-by: Sasha Meister * Fix checking of cuda/cpu device for inputs of Decoder (#7444) * Fix checking of cuda/cpu device for inputs of Decoder Signed-off-by: Robin Dong * Update tacotron2.py Signed-off-by: Jason --------- Signed-off-by: Robin Dong Signed-off-by: Jason Co-authored-by: Jason Signed-off-by: Sasha Meister * Fix failure of ljspeech's get_data.py (#7430) * Fix failure of ljspeech's get_data.py Signed-off-by: Robin Dong * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Robin Dong Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * [TTS] Fix audio codec type checks (#7373) * [TTS] Fix audio codec type checks Signed-off-by: Ryan * [TTS] Fix audio codec tests Signed-off-by: Ryan --------- Signed-off-by: Ryan Signed-off-by: Sasha Meister * [TTS] Add dataset to path of logged artifacts (#7462) * [TTS] Add dataset to path of logged artifacts Signed-off-by: Ryan * [TTS] Revert axis name back to Audio Frames Signed-off-by: Ryan --------- Signed-off-by: Ryan Signed-off-by: Sasha Meister * Fix sft dataset truncation (#7464) * Add fix Signed-off-by: Cheng-Ping Hsieh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Cheng-Ping Hsieh --------- Signed-off-by: Cheng-Ping Hsieh Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Sasha Meister * Automatic Lip Reading Recognition (ALR) - ASR/CV (Visual ASR) (#7330) * striding_conv1d_k5 and dw_striding_conv1d_k5 subsampling Signed-off-by: mburchi * transpose conv1d inputs Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: mburchi * Update subsampling.py change striding_conv1d_k5 to striding_conv1d Signed-off-by: Maxime Burchi <60737204+burchim@users.noreply.github.com> * cv branch Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * video manifest Signed-off-by: mburchi * add collection classes Signed-off-by: mburchi * [pre-commit.ci] auto fixes from pre-… * fix onnx (#7703) (#7704) Signed-off-by: fayejf Co-authored-by: fayejf <36722593+fayejf@users.noreply.github.com> * move core install to /workspace (#7706) Signed-off-by: Abhinav Khattar * Fix typo in audio codec config, encoder target (#7697) Signed-off-by: Ante Jukić * Replace strategy='dp'/None with 'auto' (#7681) (#7696) * Add strategy=auto for None and dp * Change strategy from None to auto --------- Signed-off-by: Abhishree Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> * [ASR] Multichannel mask estimator with flex number of channels (#7317) * Adding a mask estimator which can process an arbitrary number of channels Signed-off-by: Ante Jukić * Bypass failing tests + mark as pleasefixme Signed-off-by: Ante Jukić --------- Signed-off-by: Ante Jukić * fix ptl_bugs in slu_models.py (#7689) (#7712) * fix ptl_bugs in slu_models.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change strategy to ddp_find_unused_parameters_true in slu example yaml --------- Signed-off-by: Seonghun Noh Signed-off-by: Seonghun Co-authored-by: Seonghun Noh Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com> * fix code block typo (#7717) Signed-off-by: Elena Rastorgueva * Update key mapping logic * Few merge fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix diff for non-mm models * Fix diff for non-mm models * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove deployment and export scripts * Improve the unet ckpt loading logic. * Improve the unet ckpt loading logic. * Add checkpoint_averaging script * Hide multimodal code changes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Eric's comments * Revert "Hide multimodal code changes" This reverts commit d6900f9bc1922d086e2e388dcec6e3bd2b0f59dc. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix configs * Fix neva model * Fix neva casting * Fix neva LoRA non MCore version * Fix neva LoRA MCore * [SD] group norm fixes * Fix neva cfg merge * remove groupnorm dependency * Fix copyright headers * LLaVA 1_5 and LORA update * Fix logs * Fix neva mcore infernece * Fix ema * Fix ema * Address Somshubra comments * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix NeVA * Remove llama tricks since we are padding the embedding weights directly now * Update Dockerfile and mm requirements * Multimodal unit and jenkins tests * Add Multimodal Docs * update default conv_template * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix neva evaluation * Update Dockerfile * Fix evaluation loading * Fix evaluation API * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Change quick-gelu to approx-gelu * hide multimodal * Revert "hide multimodal" This reverts commit e2ccc8850cd0f939f48f10120580697f23e89ca1. * REstructure * REstructure again * Update neva evalution code * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix neva model after merging * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Restructure * Restructure, rename * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Restructure * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove package requirement * hide docs and artifacts * Rename Nerf * Hide Nerf and text to image * Update examples/multimodal/multimodal_llm/neva/convert_hf_llava_to_neva.py Co-authored-by: Eric Harper Signed-off-by: yaoyu-33 <54727607+yaoyu-33@users.noreply.github.com> * Update examples/multimodal/multimodal_llm/neva/convert_hf_llava_to_neva.py Co-authored-by: Eric Harper Signed-off-by: yaoyu-33 <54727607+yaoyu-33@users.noreply.github.com> * Fix PR comments, clean comments, move to torch_dtype_from_precision * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update to torch_dtype_from_precision * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix PR comments * Fix copyright and docstrings * Update docstrings * Optimize imports * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert "Hide Nerf and text to image" This reverts commit 782316f3ac3e4e8704d417dd9aa4f1068ba20bbe. * Add copyright information * Optimize imports * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Optimize imports * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Address comments * Bug fix due to restructure * remove color map detector * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * copyright * address docstring formatting and removed unused parts for SD * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Copyright * torchvision guard * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * code scan clean --------- Signed-off-by: Samuele Cornell Signed-off-by: Cheng-Ping Hsieh Signed-off-by: KunalDhawan Signed-off-by: Aleksandr Laptev Signed-off-by: Robin Dong Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Signed-off-by: Abhishree Signed-off-by: Jason Signed-off-by: Ryan Signed-off-by: mburchi Signed-off-by: Maxime Burchi <60737204+burchim@users.noreply.github.com> Signed-off-by: Jan Lasek Signed-off-by: Hongbin Liu Signed-off-by: Tamerlan Tabolov Signed-off-by: Gerald Shen Signed-off-by: smajumdar Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Signed-off-by: Stas Bekman Signed-off-by: Jocelyn Huang Signed-off-by: Abhinav Khattar Signed-off-by: GiacomoLeoneMaria Signed-off-by: fayejf <36722593+fayejf@users.noreply.github.com> Signed-off-by: arendu Signed-off-by: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Signed-off-by: Nithin Rao Koluguri Signed-off-by: jasonwan Signed-off-by: hkelly33 <58792115+hkelly33@users.noreply.github.com> Signed-off-by: Adi Renduchintala Signed-off-by: arendu Signed-off-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> Signed-off-by: Chen Cui Signed-off-by: BestJuly Signed-off-by: Elena Rastorgueva Signed-off-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Signed-off-by: dimapihtar Signed-off-by: George <37293288+Jorjeous@users.noreply.github.com> Signed-off-by: Sangkug Lym Signed-off-by: Mehadi Hasan Menon Signed-off-by: Ante Jukić Signed-off-by: Maanu Grover Signed-off-by: Sasha Meister Signed-off-by: Jason Wang Signed-off-by: Tim Moon Signed-off-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Signed-off-by: eharper Signed-off-by: Jimmy Zhang Signed-off-by: George Zelenfroynd Signed-off-by: Anton Peganov Signed-off-by: Nikolay Karpov Signed-off-by: Yi Dong Signed-off-by: fayejf Signed-off-by: Igor Gitman Signed-off-by: Jan Baczek Signed-off-by: Mikołaj Błaż Signed-off-by: Elena Rastorgueva <80532067+erastorgueva-nv@users.noreply.github.com> Signed-off-by: Seonghun Noh Signed-off-by: Seonghun Signed-off-by: Eric Harper Signed-off-by: yaoyu-33 <54727607+yaoyu-33@users.noreply.github.com> Co-authored-by: Samuele Cornell Co-authored-by: Parth Mannan Co-authored-by: Mingyuan Ma Co-authored-by: Lukasz Pierscieniewski Co-authored-by: Yu Yao Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Yang Zhang Co-authored-by: Kunal Dhawan Co-authored-by: Aleksandr Laptev Co-authored-by: Robin Dong Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com> Co-authored-by: Eric Harper Co-authored-by: Jason Co-authored-by: Ryan Langman Co-authored-by: Maxime Burchi <60737204+burchim@users.noreply.github.com> Co-authored-by: Igor Gitman Co-authored-by: Jan Lasek Co-authored-by: Kelvin Liu Co-authored-by: Hongbin Liu Co-authored-by: Tamerlan Tabolov Co-authored-by: Gerald Shen <119401249+gshennvm@users.noreply.github.com> Co-authored-by: Somshubra Majumdar Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com> Co-authored-by: Stas Bekman Co-authored-by: Jocelyn Co-authored-by: Abhinav Khattar Co-authored-by: Giacomo Leone Maria Cavallini <72698188+GiacomoLeoneMaria@users.noreply.github.com> Co-authored-by: fayejf <36722593+fayejf@users.noreply.github.com> Co-authored-by: Adi Renduchintala Co-authored-by: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Co-authored-by: Nithin Rao Co-authored-by: meatybobby Co-authored-by: Chen Cui Co-authored-by: Marc Romeyn Co-authored-by: jasonwan Co-authored-by: hkelly33 <58792115+hkelly33@users.noreply.github.com> Co-authored-by: Yuanzhe Dong Co-authored-by: Cheng-Ping Hsieh Co-authored-by: Li Tao Co-authored-by: Elena Rastorgueva <80532067+erastorgueva-nv@users.noreply.github.com> Co-authored-by: Igor Gitman Co-authored-by: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com> Co-authored-by: George <37293288+Jorjeous@users.noreply.github.com> Co-authored-by: Sangkug Lym Co-authored-by: Mehadi Hasan Menon Co-authored-by: anteju <108555623+anteju@users.noreply.github.com> Co-authored-by: Ao Tang Co-authored-by: Ahmad Kiswani Co-authored-by: Bobby Chen Co-authored-by: Sasha Meister <117230141+ssh-meister@users.noreply.github.com> Co-authored-by: Maanu Grover <109391026+maanug-nv@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: JimmyZhang12 <67203904+JimmyZhang12@users.noreply.github.com> Co-authored-by: Jimmy Zhang Co-authored-by: PeganovAnton Co-authored-by: Nikolay Karpov Co-authored-by: Evelina <10428420+ekmb@users.noreply.github.com> Co-authored-by: Yi Dong <43824965+yidong72@users.noreply.github.com> Co-authored-by: jbaczek <45043825+jbaczek@users.noreply.github.com> Co-authored-by: mikolajblaz Co-authored-by: Seonghun Noh Co-authored-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com> Co-authored-by: Szymon Mikler --- examples/multimodal/convert_ckpt_to_nemo.py | 2 +- .../controlnet/conf/controlnet_infer.yaml | 36 + .../controlnet/conf/controlnet_v1-5.yaml | 222 ++ .../controlnet/controlnet_infer.py | 251 ++ .../controlnet/controlnet_train.py | 50 + .../text_to_image/convert_hf_ckpt_to_nemo.py | 226 ++ .../dreambooth/conf/dreambooth.yaml | 224 ++ .../dreambooth/conf/dreambooth_infer.yaml | 32 + .../text_to_image/dreambooth/dreambooth.py | 108 + .../dreambooth/dreambooth_infer.py | 44 + .../multimodal/text_to_image/imagen/README.md | 104 + .../text_to_image/imagen/conf/base64-2b.yaml | 142 ++ .../imagen/conf/base64-500m-edm.yaml | 136 ++ .../imagen/conf/base64-500m.yaml | 144 ++ .../conf/base64-500m_online_encoding.yaml | 137 ++ .../imagen/conf/fid_inference.yaml | 26 + .../imagen/conf/imagen_fid_images.yaml | 57 + .../imagen/conf/inference_pipeline.yaml | 42 + .../imagen/conf/sr1024-600m.yaml | 145 ++ .../imagen/conf/sr256-400m-edm.yaml | 222 ++ .../text_to_image/imagen/conf/sr256-400m.yaml | 150 ++ .../imagen/conf/sr256-450m-edm.yaml | 222 ++ .../imagen/conf/sr256-600m-edm-noise.yaml | 142 ++ .../imagen/conf/sr256-600m-edm.yaml | 219 ++ .../text_to_image/imagen/conf/sr256-600m.yaml | 146 ++ .../imagen/generate_fid_images.py | 116 + .../imagen/imagen_generate_images.py | 79 + .../text_to_image/imagen/imagen_infer.py | 50 + .../text_to_image/imagen/imagen_training.py | 63 + .../instruct_pix2pix/conf/sd_edit.yaml | 23 + .../instruct_pix2pix/conf/sd_finetune.yaml | 168 ++ .../instruct_pix2pix/sd_edit_cli.py | 168 ++ .../instruct_pix2pix/sd_finetune.py | 43 + .../stable_diffusion/conf/sd2_train.yaml | 192 ++ .../stable_diffusion/conf/sd_fid_images.yaml | 45 + .../stable_diffusion/conf/sd_infer.yaml | 31 + .../stable_diffusion/conf/sd_train.yaml | 208 ++ .../stable_diffusion/generate_fid_images.py | 97 + .../stable_diffusion/sd_infer.py | 44 + .../stable_diffusion/sd_train.py | 85 + .../clip/convert_external_clip_to_nemo.py | 2 +- .../clip/megatron_clip_imagenet_zeroshot.py | 2 +- .../clip/megatron_clip_infer.py | 2 +- .../clip/megatron_clip_pretrain.py | 2 +- .../x_to_nerf/benchmark_callback.py | 96 + .../multimodal/x_to_nerf/config/config.yaml | 52 + .../config/model/background/random.yaml | 3 + .../config/model/background/static.yaml | 2 + .../config/model/background/tcnn.yaml | 19 + .../config/model/background/torchngp.yaml | 11 + .../x_to_nerf/config/model/data/data.yaml | 41 + .../config/model/dreamfusion-dmtet.yaml | 40 + .../x_to_nerf/config/model/dreamfusion.yaml | 40 + .../config/model/guidance/sd_huggingface.yaml | 4 + .../config/model/guidance/sd_nemo.yaml | 4 + .../config/model/guidance/sd_trt.yaml | 5 + .../x_to_nerf/config/model/loss/dmtet.yaml | 8 + .../config/model/loss/dreamfusion.yaml | 8 + .../config/model/material/basic_shading.yaml | 1 + .../x_to_nerf/config/model/nerf/tcnn.yaml | 32 + .../x_to_nerf/config/model/nerf/torchngp.yaml | 26 + .../x_to_nerf/config/model/optim/adan.yaml | 6 + .../config/model/renderer/nerfacc.yaml | 8 + .../config/model/renderer/nvdiffrast.yaml | 6 + .../model/renderer/torchngp_raymarching.yaml | 7 + examples/multimodal/x_to_nerf/data.py | 86 + examples/multimodal/x_to_nerf/main.py | 70 + .../data/clip/augmentations/augmentations.py | 28 +- .../multimodal/data/controlnet/__init__.py | 13 + .../data/controlnet/controlnet_dataset.py | 145 ++ .../multimodal/data/dreambooth/__init__.py | 13 + .../data/dreambooth/dreambooth_dataset.py | 164 ++ .../multimodal/data/imagen/__init__.py | 13 + .../data/imagen/augmentations/__init__.py | 13 + .../imagen/augmentations/augmentations.py | 76 + .../data/imagen/augmentations/corruption.py | 39 + .../multimodal/data/imagen/imagen_dataset.py | 156 ++ .../data/instruct_pix2pix/__init__.py | 13 + .../data/instruct_pix2pix/edit_dataset.py | 137 ++ .../multimodal/data/nerf/__init__.py | 13 + .../multimodal/data/nerf/cameras.py | 192 ++ .../multimodal/data/nerf/circle_poses.py | 228 ++ .../multimodal/data/nerf/random_poses.py | 450 ++++ .../collections/multimodal/data/nerf/utils.py | 217 ++ .../multimodal/data/neva/conversation.py | 9 - .../data/stable_diffusion/__init__.py | 13 + .../stable_diffusion/augmentation/__init__.py | 13 + .../augmentation/augmentations.py | 75 + .../stable_diffusion_dataset.py | 185 ++ .../multimodal/models/nerf/__init__.py | 13 + .../multimodal/models/nerf/base.py | 36 + .../multimodal/models/nerf/dreamfusion.py | 325 +++ .../multimodal/models/nerf/txt2nerf_base.py | 93 + .../models/text_to_image/__init__.py | 13 + .../text_to_image/controlnet/__init__.py | 13 + .../text_to_image/controlnet/controlnet.py | 1023 ++++++++ .../models/text_to_image/controlnet/util.py | 102 + .../text_to_image/dreambooth/__init__.py | 13 + .../text_to_image/dreambooth/dreambooth.py | 639 +++++ .../models/text_to_image/dreambooth/util.py | 167 ++ .../models/text_to_image/imagen/__init__.py | 13 + .../models/text_to_image/imagen/imagen.py | 598 +++++ .../text_to_image/imagen/imagen_pipeline.py | 356 +++ .../models/text_to_image/imagen/precond.py | 174 ++ .../instruct_pix2pix/__init__.py | 13 + .../instruct_pix2pix/ldm/__init__.py | 13 + .../instruct_pix2pix/ldm/ddpm_edit.py | 262 ++ .../stable_diffusion/__init__.py | 13 + .../stable_diffusion/diffusion_model.py | 80 + .../stable_diffusion/ldm/__init__.py | 13 + .../stable_diffusion/ldm/autoencoder.py | 614 +++++ .../stable_diffusion/ldm/ddpm.py | 2163 +++++++++++++++++ .../stable_diffusion/ldm_config.py | 144 ++ .../stable_diffusion/samplers/__init__.py | 16 + .../stable_diffusion/samplers/base_sampler.py | 339 +++ .../stable_diffusion/samplers/ddim.py | 119 + .../stable_diffusion/samplers/dpmsolver.py | 493 ++++ .../stable_diffusion/samplers/k_diffusion.py | 838 +++++++ .../stable_diffusion/samplers/para_ddim.py | 231 ++ .../stable_diffusion/samplers/plms.py | 105 + .../stable_diffusion/samplers/sampler_dpm.py | 76 + .../multimodal/modules/imagen/__init__.py | 24 + .../imagen/diffusionmodules/__init__.py | 24 + .../imagen/diffusionmodules/attention.py | 317 +++ .../imagen/diffusionmodules/attention_alt.py | 321 +++ .../modules/imagen/diffusionmodules/blocks.py | 905 +++++++ .../modules/imagen/diffusionmodules/embs.py | 69 + .../modules/imagen/diffusionmodules/layers.py | 240 ++ .../modules/imagen/diffusionmodules/nets.py | 698 ++++++ .../modules/imagen/encoder/__init__.py | 24 + .../modules/imagen/encoder/t5encoder.json | 51 + .../modules/imagen/encoder/t5encoder.py | 68 + .../modules/imagen/sampler/__init__.py | 24 + .../modules/imagen/sampler/batch_ops.py | 57 + .../modules/imagen/sampler/continuous_ddpm.py | 168 ++ .../modules/imagen/sampler/sampler.py | 250 ++ .../multimodal/modules/nerf/__init__.py | 13 + .../modules/nerf/background/__init__.py | 13 + .../nerf/background/nerf_background_base.py | 35 + .../nerf/background/random_background.py | 32 + .../nerf/background/static_background.py | 27 + .../nerf/background/tcnn_background.py | 45 + .../nerf/background/torchngp_background.py | 44 + .../modules/nerf/geometry/__init__.py | 13 + .../multimodal/modules/nerf/geometry/dmtet.py | 163 ++ .../modules/nerf/geometry/layers.py | 142 ++ .../modules/nerf/geometry/nerf_base.py | 373 +++ .../modules/nerf/geometry/tcnn_nerf.py | 121 + .../modules/nerf/geometry/torchngp_nerf.py | 127 + .../modules/nerf/guidance/__init__.py | 13 + .../stablediffusion_huggingface_pipeline.py | 155 ++ .../guidance/stablediffusion_nemo_pipeline.py | 141 ++ .../guidance/stablediffusion_trt_pipeline.py | 234 ++ .../nerf/guidance/txt2img_guidance_base.py | 19 + .../multimodal/modules/nerf/loss/__init__.py | 13 + .../nerf/loss/laplacian_smooth_loss.py | 51 + .../nerf/loss/normal_consistency_loss.py | 69 + .../modules/nerf/materials/__init__.py | 13 + .../modules/nerf/materials/basic_shading.py | 79 + .../modules/nerf/materials/materials_base.py | 41 + .../modules/nerf/renderers/__init__.py | 13 + .../modules/nerf/renderers/base_renderer.py | 31 + .../nerf/renderers/base_sdf_renderer.py | 33 + .../nerf/renderers/base_volume_renderer.py | 19 + .../nerf/renderers/nerfacc_volume_renderer.py | 376 +++ .../nerf/renderers/nvdiffrast_renderer.py | 235 ++ .../renderers/torchngp_volume_renderer.py | 288 +++ .../multimodal/modules/nerf/utils/__init__.py | 13 + .../modules/nerf/utils/activation.py | 33 + .../modules/nerf/utils/torch_ngp/__init__.py | 13 + .../modules/nerf/utils/torch_ngp/encoding.py | 149 ++ .../nerf/utils/torch_ngp/freqencoder.py | 84 + .../nerf/utils/torch_ngp/gridencoder.py | 299 +++ .../nerf/utils/torch_ngp/raymarching.py | 561 +++++ .../modules/nerf/utils/torch_ngp/shencoder.py | 93 + .../modules/nerf/utils/trt_engine.py | 170 ++ .../modules/stable_diffusion/__init__.py | 13 + .../modules/stable_diffusion/attention.py | 408 ++++ .../diffusionmodules/__init__.py | 13 + .../diffusionmodules/model.py | 881 +++++++ .../diffusionmodules/openaimodel.py | 1175 +++++++++ .../stable_diffusion/diffusionmodules/util.py | 319 +++ .../distributions/__init__.py | 13 + .../distributions/distributions.py | 98 + .../stable_diffusion/encoders/__init__.py | 13 + .../stable_diffusion/encoders/modules.py | 470 ++++ .../encoders/x_transformer.py | 630 +++++ .../multimodal/parts/imagen/__init__.py | 13 + .../multimodal/parts/imagen/utils.py | 29 + .../parts/stable_diffusion/__init__.py | 13 + .../parts/stable_diffusion/pipeline.py | 202 ++ .../parts/stable_diffusion/utils.py | 208 ++ .../tts/models/spectrogram_enhancer.py | 9 +- .../vision/data/megatron/image_folder.py | 235 +- .../vision/data/megatron/vit_dataset.py | 9 +- 195 files changed, 28782 insertions(+), 134 deletions(-) create mode 100644 examples/multimodal/text_to_image/controlnet/conf/controlnet_infer.yaml create mode 100644 examples/multimodal/text_to_image/controlnet/conf/controlnet_v1-5.yaml create mode 100644 examples/multimodal/text_to_image/controlnet/controlnet_infer.py create mode 100644 examples/multimodal/text_to_image/controlnet/controlnet_train.py create mode 100644 examples/multimodal/text_to_image/convert_hf_ckpt_to_nemo.py create mode 100644 examples/multimodal/text_to_image/dreambooth/conf/dreambooth.yaml create mode 100644 examples/multimodal/text_to_image/dreambooth/conf/dreambooth_infer.yaml create mode 100644 examples/multimodal/text_to_image/dreambooth/dreambooth.py create mode 100644 examples/multimodal/text_to_image/dreambooth/dreambooth_infer.py create mode 100644 examples/multimodal/text_to_image/imagen/README.md create mode 100644 examples/multimodal/text_to_image/imagen/conf/base64-2b.yaml create mode 100644 examples/multimodal/text_to_image/imagen/conf/base64-500m-edm.yaml create mode 100644 examples/multimodal/text_to_image/imagen/conf/base64-500m.yaml create mode 100644 examples/multimodal/text_to_image/imagen/conf/base64-500m_online_encoding.yaml create mode 100644 examples/multimodal/text_to_image/imagen/conf/fid_inference.yaml create mode 100644 examples/multimodal/text_to_image/imagen/conf/imagen_fid_images.yaml create mode 100644 examples/multimodal/text_to_image/imagen/conf/inference_pipeline.yaml create mode 100644 examples/multimodal/text_to_image/imagen/conf/sr1024-600m.yaml create mode 100644 examples/multimodal/text_to_image/imagen/conf/sr256-400m-edm.yaml create mode 100644 examples/multimodal/text_to_image/imagen/conf/sr256-400m.yaml create mode 100644 examples/multimodal/text_to_image/imagen/conf/sr256-450m-edm.yaml create mode 100644 examples/multimodal/text_to_image/imagen/conf/sr256-600m-edm-noise.yaml create mode 100644 examples/multimodal/text_to_image/imagen/conf/sr256-600m-edm.yaml create mode 100644 examples/multimodal/text_to_image/imagen/conf/sr256-600m.yaml create mode 100644 examples/multimodal/text_to_image/imagen/generate_fid_images.py create mode 100644 examples/multimodal/text_to_image/imagen/imagen_generate_images.py create mode 100644 examples/multimodal/text_to_image/imagen/imagen_infer.py create mode 100644 examples/multimodal/text_to_image/imagen/imagen_training.py create mode 100644 examples/multimodal/text_to_image/instruct_pix2pix/conf/sd_edit.yaml create mode 100644 examples/multimodal/text_to_image/instruct_pix2pix/conf/sd_finetune.yaml create mode 100644 examples/multimodal/text_to_image/instruct_pix2pix/sd_edit_cli.py create mode 100644 examples/multimodal/text_to_image/instruct_pix2pix/sd_finetune.py create mode 100644 examples/multimodal/text_to_image/stable_diffusion/conf/sd2_train.yaml create mode 100644 examples/multimodal/text_to_image/stable_diffusion/conf/sd_fid_images.yaml create mode 100644 examples/multimodal/text_to_image/stable_diffusion/conf/sd_infer.yaml create mode 100644 examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml create mode 100644 examples/multimodal/text_to_image/stable_diffusion/generate_fid_images.py create mode 100644 examples/multimodal/text_to_image/stable_diffusion/sd_infer.py create mode 100644 examples/multimodal/text_to_image/stable_diffusion/sd_train.py create mode 100644 examples/multimodal/x_to_nerf/benchmark_callback.py create mode 100644 examples/multimodal/x_to_nerf/config/config.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/background/random.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/background/static.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/background/tcnn.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/background/torchngp.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/data/data.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/dreamfusion-dmtet.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/dreamfusion.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/guidance/sd_huggingface.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/guidance/sd_nemo.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/guidance/sd_trt.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/loss/dmtet.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/loss/dreamfusion.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/material/basic_shading.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/nerf/tcnn.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/nerf/torchngp.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/optim/adan.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/renderer/nerfacc.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/renderer/nvdiffrast.yaml create mode 100644 examples/multimodal/x_to_nerf/config/model/renderer/torchngp_raymarching.yaml create mode 100644 examples/multimodal/x_to_nerf/data.py create mode 100644 examples/multimodal/x_to_nerf/main.py create mode 100644 nemo/collections/multimodal/data/controlnet/__init__.py create mode 100644 nemo/collections/multimodal/data/controlnet/controlnet_dataset.py create mode 100644 nemo/collections/multimodal/data/dreambooth/__init__.py create mode 100644 nemo/collections/multimodal/data/dreambooth/dreambooth_dataset.py create mode 100644 nemo/collections/multimodal/data/imagen/__init__.py create mode 100644 nemo/collections/multimodal/data/imagen/augmentations/__init__.py create mode 100644 nemo/collections/multimodal/data/imagen/augmentations/augmentations.py create mode 100644 nemo/collections/multimodal/data/imagen/augmentations/corruption.py create mode 100644 nemo/collections/multimodal/data/imagen/imagen_dataset.py create mode 100644 nemo/collections/multimodal/data/instruct_pix2pix/__init__.py create mode 100644 nemo/collections/multimodal/data/instruct_pix2pix/edit_dataset.py create mode 100644 nemo/collections/multimodal/data/nerf/__init__.py create mode 100644 nemo/collections/multimodal/data/nerf/cameras.py create mode 100644 nemo/collections/multimodal/data/nerf/circle_poses.py create mode 100644 nemo/collections/multimodal/data/nerf/random_poses.py create mode 100644 nemo/collections/multimodal/data/nerf/utils.py create mode 100644 nemo/collections/multimodal/data/stable_diffusion/__init__.py create mode 100644 nemo/collections/multimodal/data/stable_diffusion/augmentation/__init__.py create mode 100644 nemo/collections/multimodal/data/stable_diffusion/augmentation/augmentations.py create mode 100644 nemo/collections/multimodal/data/stable_diffusion/stable_diffusion_dataset.py create mode 100644 nemo/collections/multimodal/models/nerf/__init__.py create mode 100644 nemo/collections/multimodal/models/nerf/base.py create mode 100644 nemo/collections/multimodal/models/nerf/dreamfusion.py create mode 100644 nemo/collections/multimodal/models/nerf/txt2nerf_base.py create mode 100644 nemo/collections/multimodal/models/text_to_image/__init__.py create mode 100644 nemo/collections/multimodal/models/text_to_image/controlnet/__init__.py create mode 100644 nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py create mode 100644 nemo/collections/multimodal/models/text_to_image/controlnet/util.py create mode 100644 nemo/collections/multimodal/models/text_to_image/dreambooth/__init__.py create mode 100644 nemo/collections/multimodal/models/text_to_image/dreambooth/dreambooth.py create mode 100644 nemo/collections/multimodal/models/text_to_image/dreambooth/util.py create mode 100644 nemo/collections/multimodal/models/text_to_image/imagen/__init__.py create mode 100644 nemo/collections/multimodal/models/text_to_image/imagen/imagen.py create mode 100644 nemo/collections/multimodal/models/text_to_image/imagen/imagen_pipeline.py create mode 100644 nemo/collections/multimodal/models/text_to_image/imagen/precond.py create mode 100644 nemo/collections/multimodal/models/text_to_image/instruct_pix2pix/__init__.py create mode 100644 nemo/collections/multimodal/models/text_to_image/instruct_pix2pix/ldm/__init__.py create mode 100644 nemo/collections/multimodal/models/text_to_image/instruct_pix2pix/ldm/ddpm_edit.py create mode 100644 nemo/collections/multimodal/models/text_to_image/stable_diffusion/__init__.py create mode 100644 nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_model.py create mode 100644 nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/__init__.py create mode 100644 nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py create mode 100644 nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py create mode 100644 nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm_config.py create mode 100644 nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/__init__.py create mode 100644 nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/base_sampler.py create mode 100644 nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/ddim.py create mode 100644 nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/dpmsolver.py create mode 100644 nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/k_diffusion.py create mode 100644 nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/para_ddim.py create mode 100644 nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/plms.py create mode 100644 nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/sampler_dpm.py create mode 100644 nemo/collections/multimodal/modules/imagen/__init__.py create mode 100644 nemo/collections/multimodal/modules/imagen/diffusionmodules/__init__.py create mode 100644 nemo/collections/multimodal/modules/imagen/diffusionmodules/attention.py create mode 100644 nemo/collections/multimodal/modules/imagen/diffusionmodules/attention_alt.py create mode 100644 nemo/collections/multimodal/modules/imagen/diffusionmodules/blocks.py create mode 100644 nemo/collections/multimodal/modules/imagen/diffusionmodules/embs.py create mode 100644 nemo/collections/multimodal/modules/imagen/diffusionmodules/layers.py create mode 100644 nemo/collections/multimodal/modules/imagen/diffusionmodules/nets.py create mode 100644 nemo/collections/multimodal/modules/imagen/encoder/__init__.py create mode 100644 nemo/collections/multimodal/modules/imagen/encoder/t5encoder.json create mode 100644 nemo/collections/multimodal/modules/imagen/encoder/t5encoder.py create mode 100644 nemo/collections/multimodal/modules/imagen/sampler/__init__.py create mode 100644 nemo/collections/multimodal/modules/imagen/sampler/batch_ops.py create mode 100644 nemo/collections/multimodal/modules/imagen/sampler/continuous_ddpm.py create mode 100644 nemo/collections/multimodal/modules/imagen/sampler/sampler.py create mode 100644 nemo/collections/multimodal/modules/nerf/__init__.py create mode 100644 nemo/collections/multimodal/modules/nerf/background/__init__.py create mode 100644 nemo/collections/multimodal/modules/nerf/background/nerf_background_base.py create mode 100644 nemo/collections/multimodal/modules/nerf/background/random_background.py create mode 100644 nemo/collections/multimodal/modules/nerf/background/static_background.py create mode 100644 nemo/collections/multimodal/modules/nerf/background/tcnn_background.py create mode 100644 nemo/collections/multimodal/modules/nerf/background/torchngp_background.py create mode 100644 nemo/collections/multimodal/modules/nerf/geometry/__init__.py create mode 100644 nemo/collections/multimodal/modules/nerf/geometry/dmtet.py create mode 100644 nemo/collections/multimodal/modules/nerf/geometry/layers.py create mode 100644 nemo/collections/multimodal/modules/nerf/geometry/nerf_base.py create mode 100644 nemo/collections/multimodal/modules/nerf/geometry/tcnn_nerf.py create mode 100644 nemo/collections/multimodal/modules/nerf/geometry/torchngp_nerf.py create mode 100644 nemo/collections/multimodal/modules/nerf/guidance/__init__.py create mode 100644 nemo/collections/multimodal/modules/nerf/guidance/stablediffusion_huggingface_pipeline.py create mode 100644 nemo/collections/multimodal/modules/nerf/guidance/stablediffusion_nemo_pipeline.py create mode 100644 nemo/collections/multimodal/modules/nerf/guidance/stablediffusion_trt_pipeline.py create mode 100644 nemo/collections/multimodal/modules/nerf/guidance/txt2img_guidance_base.py create mode 100644 nemo/collections/multimodal/modules/nerf/loss/__init__.py create mode 100644 nemo/collections/multimodal/modules/nerf/loss/laplacian_smooth_loss.py create mode 100644 nemo/collections/multimodal/modules/nerf/loss/normal_consistency_loss.py create mode 100644 nemo/collections/multimodal/modules/nerf/materials/__init__.py create mode 100644 nemo/collections/multimodal/modules/nerf/materials/basic_shading.py create mode 100644 nemo/collections/multimodal/modules/nerf/materials/materials_base.py create mode 100644 nemo/collections/multimodal/modules/nerf/renderers/__init__.py create mode 100644 nemo/collections/multimodal/modules/nerf/renderers/base_renderer.py create mode 100644 nemo/collections/multimodal/modules/nerf/renderers/base_sdf_renderer.py create mode 100644 nemo/collections/multimodal/modules/nerf/renderers/base_volume_renderer.py create mode 100644 nemo/collections/multimodal/modules/nerf/renderers/nerfacc_volume_renderer.py create mode 100644 nemo/collections/multimodal/modules/nerf/renderers/nvdiffrast_renderer.py create mode 100644 nemo/collections/multimodal/modules/nerf/renderers/torchngp_volume_renderer.py create mode 100644 nemo/collections/multimodal/modules/nerf/utils/__init__.py create mode 100644 nemo/collections/multimodal/modules/nerf/utils/activation.py create mode 100644 nemo/collections/multimodal/modules/nerf/utils/torch_ngp/__init__.py create mode 100644 nemo/collections/multimodal/modules/nerf/utils/torch_ngp/encoding.py create mode 100644 nemo/collections/multimodal/modules/nerf/utils/torch_ngp/freqencoder.py create mode 100644 nemo/collections/multimodal/modules/nerf/utils/torch_ngp/gridencoder.py create mode 100644 nemo/collections/multimodal/modules/nerf/utils/torch_ngp/raymarching.py create mode 100644 nemo/collections/multimodal/modules/nerf/utils/torch_ngp/shencoder.py create mode 100644 nemo/collections/multimodal/modules/nerf/utils/trt_engine.py create mode 100644 nemo/collections/multimodal/modules/stable_diffusion/__init__.py create mode 100644 nemo/collections/multimodal/modules/stable_diffusion/attention.py create mode 100644 nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/__init__.py create mode 100644 nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/model.py create mode 100644 nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py create mode 100644 nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py create mode 100644 nemo/collections/multimodal/modules/stable_diffusion/distributions/__init__.py create mode 100644 nemo/collections/multimodal/modules/stable_diffusion/distributions/distributions.py create mode 100644 nemo/collections/multimodal/modules/stable_diffusion/encoders/__init__.py create mode 100644 nemo/collections/multimodal/modules/stable_diffusion/encoders/modules.py create mode 100644 nemo/collections/multimodal/modules/stable_diffusion/encoders/x_transformer.py create mode 100644 nemo/collections/multimodal/parts/imagen/__init__.py create mode 100644 nemo/collections/multimodal/parts/imagen/utils.py create mode 100644 nemo/collections/multimodal/parts/stable_diffusion/__init__.py create mode 100644 nemo/collections/multimodal/parts/stable_diffusion/pipeline.py create mode 100644 nemo/collections/multimodal/parts/stable_diffusion/utils.py diff --git a/examples/multimodal/convert_ckpt_to_nemo.py b/examples/multimodal/convert_ckpt_to_nemo.py index b1d6f3e04204..2c36f434a075 100644 --- a/examples/multimodal/convert_ckpt_to_nemo.py +++ b/examples/multimodal/convert_ckpt_to_nemo.py @@ -36,7 +36,7 @@ from nemo.collections.multimodal.models.text_to_image.imagen import MegatronImagen from nemo.collections.multimodal.models.text_to_image.instruct_pix2pix.ldm.ddpm_edit import MegatronLatentDiffusionEdit from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import MegatronLatentDiffusion -from nemo.collections.multimodal.models.vision_language_foundation.clip import MegatronCLIPModel +from nemo.collections.multimodal.models.vision_language_foundation.clip.megatron_clip_models import MegatronCLIPModel from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector from nemo.utils import AppState, logging diff --git a/examples/multimodal/text_to_image/controlnet/conf/controlnet_infer.yaml b/examples/multimodal/text_to_image/controlnet/conf/controlnet_infer.yaml new file mode 100644 index 000000000000..bcf56d599cc2 --- /dev/null +++ b/examples/multimodal/text_to_image/controlnet/conf/controlnet_infer.yaml @@ -0,0 +1,36 @@ +name: stable-diffusion-train + +infer: + unconditional_guidance_scale: 3 + num_images_per_prompt: 4 + hint_image_size: 512 + height: 512 + width: 512 + down_factor: 8 + inference_steps: 50 + sampler_type: 'DDIM' + eta: 0 + output_type: 'pil' + save_to_file: True + out_path: 'controlnet' + seed: 355 + prompts: + - high quality picture of a house in oil painting style + control: + - /datasets/coco-stuff/house.png #images/val2017/000000001584.jpg + # Depending on the input control, if the input control is already the conditioning image, null should be passed here + # If a reconstruction target is used as control, then preprocessing function that turns it into a conditioning image needs to be specified + control_image_preprocess: + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: 16 + logger: False # logger provided by exp_manager + +model: + restore_from_path: /ckpts/controlnet/30k.nemo + precision: ${trainer.precision} + strength: 2.0 + guess_mode: False \ No newline at end of file diff --git a/examples/multimodal/text_to_image/controlnet/conf/controlnet_v1-5.yaml b/examples/multimodal/text_to_image/controlnet/conf/controlnet_v1-5.yaml new file mode 100644 index 000000000000..13ca53e835f2 --- /dev/null +++ b/examples/multimodal/text_to_image/controlnet/conf/controlnet_v1-5.yaml @@ -0,0 +1,222 @@ +trainer: + devices: 2 + num_nodes: 1 + accelerator: gpu + precision: 16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: True + max_epochs: 3 # PTL default. In practice, max_steps will be reached first. + max_steps: -1 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + limit_val_batches: 0 + + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: controlnet + create_wandb_logger: False + wandb_logger_kwargs: + project: stable-diffusion + group: controlnet + name: controlnet-v1.5 + resume: True + create_checkpoint_callback: True + create_tensorboard_logger: True + checkpoint_callback_params: + save_top_k: -1 + every_n_train_steps: 5000 + every_n_epochs: 0 + monitor: reduced_train_loss + filename: 'controlnet--{reduced_train_loss:.2f}-{step}-{consumed_samples}' + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + ema: + enable: False + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + + + + +model: + precision: ${trainer.precision} + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 4 # limited by GPU memory + global_batch_size: 8 + + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: images + cond_stage_key: captions + control_key: hint + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + scale_by_std: False + ckpt_path: + ignore_keys: [ ] + parameterization: eps + clip_denoised: True + load_only_unet: False + cosine_s: 8e-3 + given_betas: + original_elbo_weight: 0 + v_posterior: 0 + l_simple_weight: 1 + use_positional_encodings: False + learn_logvar: False + logvar_init: 0 + beta_schedule: linear + loss_type: l2 + learning_rate: 1.0e-04 + concat_mode: True + cond_stage_forward: + text_embedding_dropout_rate: 0.0 + fused_opt: True + inductor: False + inductor_cudagraphs: False + capture_cudagraph_iters: -1 # -1 to disable + channels_last: True + only_mid_control: False + sd_locked: True + + control_stage_config: + _target_: nemo.collections.multimodal.models.controlnet.controlnet.ControlNet + params: + from_pretrained_unet: /ckpts/v1-5-pruned.ckpt + from_NeMo: True + image_size: 32 # unused + in_channels: 4 + hint_channels: 3 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + use_linear_in_transformer: False + transformer_depth: 1 + context_dim: 768 + use_checkpoint: False + legacy: False + use_flash_attention: False + + unet_config: + _target_: nemo.collections.multimodal.models.controlnet.controlnet.ControlledUnetModel + from_pretrained: /ckpts/v1-5-pruned.ckpt + from_NeMo: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + - 4 + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: False + legacy: False + use_flash_attention: False + + first_stage_config: + _target_: nemo.collections.multimodal.models.stable_diffusion.ldm.autoencoder.AutoencoderKL + from_pretrained: /ckpts/vae.bin + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder + version: openai/clip-vit-large-patch14 + device: cuda + max_length: 77 + + data: + num_workers: 16 + synthetic_data: False # dataset_path and local_root_path can be empty when using synthetic data + synthetic_data_length: 10000 + train: + dataset_path: + #- /datasets/tarfiles/fill50k.pkl + - /datasets/coco-stuff/coco-stuff-tarfiles/wdinfo-coco-stuff.pkl + augmentations: + resize_smallest_side: 512 + center_crop_h_w: 512, 512 + horizontal_flip: False + filterings: + + webdataset: + infinite_sampler: False + local_root_path: /datasets/coco-stuff/coco-stuff-tarfiles + + optim: + name: fused_adam + lr: 2e-5 + weight_decay: 0. + betas: + - 0.9 + - 0.999 + sched: + name: WarmupHoldPolicy + warmup_steps: 0 + hold_steps: 10000000000000 # Incredibly large value to hold the lr as constant + + # Nsys profiling options + nsys_profile: + enabled: False + start_step: 10 # Global batch to start profiling + end_step: 10 # Global batch to end profiling + ranks: [ 0 ] # Global rank IDs to profile + gen_shape: False # Generate model and kernel details including input shapes + + image_logger: + batch_frequency: 1000 + max_images: 4 + + #miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) diff --git a/examples/multimodal/text_to_image/controlnet/controlnet_infer.py b/examples/multimodal/text_to_image/controlnet/controlnet_infer.py new file mode 100644 index 000000000000..4cdf922f8211 --- /dev/null +++ b/examples/multimodal/text_to_image/controlnet/controlnet_infer.py @@ -0,0 +1,251 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time + +import cv2 +import einops +import torch +from PIL import Image + +from nemo.collections.multimodal.models.text_to_image.controlnet import get_preprocessing_function +from nemo.collections.multimodal.models.text_to_image.controlnet.controlnet import MegatronControlNet +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.samplers.ddim import DDIMSampler +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.samplers.plms import PLMSSampler +from nemo.collections.multimodal.parts.utils import setup_trainer_and_model_for_inference +from nemo.core.config import hydra_runner + + +def get_control_input(image_path, batch_size, hint_image_size, control_image_preprocess=None): + image = cv2.imread(image_path) + if control_image_preprocess: + # More applications will be supported here + process = get_preprocessing_function(control_image_preprocess) + image = process(image) + image = cv2.resize(image, (hint_image_size, hint_image_size)) + control = torch.from_numpy(image).float() / 255.0 + control = torch.stack([control for _ in range(batch_size)], dim=0) + control = einops.rearrange(control, 'b h w c -> b c h w') + return control + + +def encode_prompt(cond_stage_model, prompt, unconditional_guidance_scale, batch_size): + c = cond_stage_model.encode(batch_size * [prompt]) + if unconditional_guidance_scale != 1.0: + uc = cond_stage_model.encode(batch_size * [""]) + else: + uc = None + return c, uc + + +def initialize_sampler(model, sampler_type): + if sampler_type == 'DDIM': + sampler = DDIMSampler(model) + elif sampler_type == 'PLMS': + sampler = PLMSSampler(model) + else: + raise ValueError(f'Sampler {sampler_type} is not supported for {cls.__name__}') + return sampler + + +def decode_images(model, samples): + images = model.decode_first_stage(samples) + + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + + return images + + +def torch_to_numpy(images): + numpy_images = [x.float().cpu().permute(0, 2, 3, 1).numpy() for x in images] + return numpy_images + + +def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + +def pipeline(model, cfg, rng=None, verbose=True): + # setup default values for inference configs + unconditional_guidance_scale = cfg.infer.get("unconditional_guidance_scale", 7.5) + batch_size = cfg.infer.get('num_images_per_prompt', 1) + prompts = cfg.infer.get('prompts', []) + control = cfg.infer.get('control', []) + height = cfg.infer.get('height', 512) + width = cfg.infer.get('width', 512) + downsampling_factor = cfg.infer.get('down_factor', 8) + sampler_type = cfg.infer.get('sampler_type', 'DDIM') + inference_steps = cfg.infer.get('inference_steps', 50) + output_type = cfg.infer.get('output_type', 'pil') + save_to_file = cfg.infer.get('save_to_file', True) + out_path = cfg.infer.get('out_path', '') + eta = cfg.infer.get('eta', 0) + guess_mode = cfg.model.get('guess_mode', False) + hint_image_size = cfg.infer.get('hint_image_size', 512) + control_image_preprocess = cfg.infer.get('control_image_preprocess', None) + + # get autocast_dtype + if cfg.trainer.precision in ['bf16', 'bf16-mixed']: + autocast_dtype = torch.bfloat16 + elif cfg.trainer.precision in [32, '32', '32-true']: + autocast_dtype = torch.float + elif cfg.trainer.precision in [16, '16', '16-mixed']: + autocast_dtype = torch.half + else: + raise ValueError('precision must be in [32, 16, "bf16"]') + + with torch.no_grad(), torch.cuda.amp.autocast( + enabled=autocast_dtype in (torch.half, torch.bfloat16), dtype=autocast_dtype, + ): + + in_channels = model.model.diffusion_model.in_channels + + sampler = initialize_sampler(model, sampler_type.upper()) + + output = [] + throughput = [] + + if isinstance(prompts, str): + prompts = [prompts] + + assert len(prompts) == len(control) + + for control, prompt in zip(control, prompts): + tic = time.perf_counter() + tic_total = tic + txt_cond, txt_u_cond = encode_prompt( + model.cond_stage_model, prompt, unconditional_guidance_scale, batch_size + ) + + control = get_control_input(control, batch_size, hint_image_size, control_image_preprocess).to( + torch.cuda.current_device(), dtype=autocast_dtype + ) + + cond = {"c_concat": control, "c_crossattn": txt_cond} + u_cond = {"c_concat": None if guess_mode else control, "c_crossattn": txt_u_cond} + + toc = time.perf_counter() + conditioning_time = toc - tic + + latent_shape = [batch_size, height // downsampling_factor, width // downsampling_factor] + latents = torch.randn( + [batch_size, in_channels, height // downsampling_factor, width // downsampling_factor], generator=rng + ).to(torch.cuda.current_device()) + + tic = time.perf_counter() + samples, intermediates = sampler.sample( + S=inference_steps, + conditioning=cond, + batch_size=batch_size, + shape=latent_shape, + verbose=False, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=u_cond, + eta=eta, + x_T=latents, + ) + toc = time.perf_counter() + sampling_time = toc - tic + + tic = time.perf_counter() + images = decode_images(model, samples) + toc = time.perf_counter() + decode_time = toc - tic + + toc_total = time.perf_counter() + total_time = toc_total - tic_total + output.append(images) + + throughput.append( + { + 'text-conditioning-time': conditioning_time, + 'sampling-time': sampling_time, + 'decode-time': decode_time, + 'total-time': total_time, + 'sampling-steps': inference_steps, + } + ) + + # Convert output type and save to disk + if output_type == 'torch': + output = torch.cat(output, dim=0) + else: + output = torch_to_numpy(output) + if output_type == 'pil': + output = [numpy_to_pil(x) for x in output] + + if save_to_file: + os.makedirs(out_path, exist_ok=True) + # Saving control map + control_image = control[0].float().cpu().permute(1, 2, 0).numpy() + control_image = Image.fromarray((control_image * 255).round().astype("uint8")) + control_image.save(os.path.join(out_path, f'{prompt[:50]}_control.png')) + if output_type == 'pil': + for text_prompt, pils in zip(prompts, output): + for idx, image in enumerate(pils): + image.save(os.path.join(out_path, f'{text_prompt[:50]}_{idx}.png')) + else: + with open(os.path.join(out_path, 'output.pkl'), 'wb') as f: + pickle.dump(output, f) + else: + return output + + ave_metrics = {} + for key in throughput[0].keys(): + ave_metrics[f'avg-{key}'] = sum([dicts[key] for dicts in throughput]) / len(throughput) + if verbose: + print(ave_metrics) + + +@hydra_runner(config_path='conf', config_name='controlnet_infer') +def main(cfg): + def model_cfg_modifier(model_cfg): + model_cfg.precision = cfg.trainer.precision + model_cfg.ckpt_path = None + model_cfg.inductor = False + model_cfg.unet_config.from_pretrained = None + model_cfg.first_stage_config.from_pretrained = None + model_cfg.control_stage_config.from_pretrained_unet = None + model_cfg.channels_last = True + model_cfg.capture_cudagraph_iters = -1 + + torch.backends.cuda.matmul.allow_tf32 = True + trainer, megatron_diffusion_model = setup_trainer_and_model_for_inference( + model_provider=MegatronControlNet, cfg=cfg, model_cfg_modifier=model_cfg_modifier + ) + model = megatron_diffusion_model.model + model.cuda().eval() + + guess_mode = cfg.model.guess_mode + model.contol_scales = ( + [cfg.model.strength * (0.825 ** float(12 - i)) for i in range(13)] + if guess_mode + else ([cfg.model.strength] * 13) + ) + + rng = torch.Generator().manual_seed(cfg.infer.seed) + pipeline(model, cfg, rng=rng) + + +if __name__ == "__main__": + main() diff --git a/examples/multimodal/text_to_image/controlnet/controlnet_train.py b/examples/multimodal/text_to_image/controlnet/controlnet_train.py new file mode 100644 index 000000000000..239409f616f1 --- /dev/null +++ b/examples/multimodal/text_to_image/controlnet/controlnet_train.py @@ -0,0 +1,50 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pytorch_lightning import Trainer + +from nemo.collections.multimodal.models.text_to_image.controlnet import ImageLogger +from nemo.collections.multimodal.models.text_to_image.controlnet.controlnet import MegatronControlNet +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.core.config import hydra_runner +from nemo.utils.exp_manager import exp_manager + + +class MegatronControlNetTrainerBuilder(MegatronTrainerBuilder): + """Builder for T5 model Trainer with overrides.""" + + def create_trainer(self, callbacks=[]) -> Trainer: + strategy = self._training_strategy() + plugins = self._plugins() + return Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=callbacks) + + +@hydra_runner(config_path='conf', config_name='controlnet_v1-5.yaml') +def main(cfg): + callbacks = [] + + if cfg.model.get('image_logger', None): + callbacks.append(ImageLogger(**cfg.model.image_logger)) + + trainer = MegatronControlNetTrainerBuilder(cfg).create_trainer(callbacks=callbacks) + + exp_manager(trainer, cfg.get("exp_manager", None)) + + model = MegatronControlNet(cfg.model, trainer) + + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/examples/multimodal/text_to_image/convert_hf_ckpt_to_nemo.py b/examples/multimodal/text_to_image/convert_hf_ckpt_to_nemo.py new file mode 100644 index 000000000000..31cddbf20dde --- /dev/null +++ b/examples/multimodal/text_to_image/convert_hf_ckpt_to_nemo.py @@ -0,0 +1,226 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage example: + python /opt/NeMo/examples/multimodal/generative/stable_diffusion/convert_hf_ckpt_to_nemo.py + --ckpt_path=path/to/hf.ckpt + --hparams_file=path/to/saved.yaml + --nemo_file_path=hf2sd.nemo + +Additionally, provide a NeMo hparams file with the correct model architecture arguments. Refer to examples/multimodal/foundation/clip/conf/megatron_clip_config.yaml. +""" + +import os +import tempfile +from argparse import ArgumentParser + +import torch +from lightning_fabric.utilities.cloud_io import _load as pl_load +from omegaconf import OmegaConf +from pytorch_lightning.plugins.environments import TorchElasticEnvironment +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.multimodal.models.text_to_image.controlnet.controlnet import MegatronControlNet +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import MegatronLatentDiffusion +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.utils import AppState, logging +from nemo.utils.distributed import initialize_distributed + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + +def get_args(): + parser = ArgumentParser() + parser.add_argument("--ckpt_path", type=str, default=None, required=True, help="Path to checkpoint.") + + parser.add_argument( + "--hparams_file", + type=str, + default=None, + required=False, + help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml", + ) + parser.add_argument("--nemo_file_path", type=str, default=None, required=True, help="Path to output .nemo file.") + parser.add_argument("--gpus_per_node", type=int, required=False, default=1) + parser.add_argument("--tensor_model_parallel_size", type=int, required=False, default=1) + parser.add_argument("--pipeline_model_parallel_size", type=int, required=False, default=1) + parser.add_argument( + "--pipeline_model_parallel_split_rank", + type=int, + required=False, + default=None, + help="If pipeline parallel size > 1, this is the rank at which the encoder ends and the decoder begins.", + ) + parser.add_argument("--local_rank", type=int, required=False, default=os.getenv('LOCAL_RANK', -1)) + parser.add_argument("--bcp", action="store_true", help="Whether on BCP platform") + parser.add_argument("--model_type", type=str, required=False, default="stable_diffusion") + parser.add_argument("--nemo_clip_path", type=str, required=False, help="Path to clip ckpt file in .nemo format") + + args = parser.parse_args() + return args + + +def load_config_and_state_from_nemo(nemo_path): + if torch.cuda.is_available(): + map_location = torch.device('cuda') + else: + map_location = torch.device('cpu') + save_restore_connector = NLPSaveRestoreConnector() + cwd = os.getcwd() + + with tempfile.TemporaryDirectory() as tmpdir: + try: + save_restore_connector._unpack_nemo_file(path2file=nemo_path, out_folder=tmpdir) + + # Change current working directory to + os.chdir(tmpdir) + config_yaml = os.path.join(tmpdir, save_restore_connector.model_config_yaml) + cfg = OmegaConf.load(config_yaml) + + model_weights = os.path.join(tmpdir, save_restore_connector.model_weights_ckpt) + state_dict = save_restore_connector._load_state_dict_from_disk(model_weights, map_location=map_location) + finally: + os.chdir(cwd) + + return cfg, state_dict + + +def mapping_hf_state_dict(hf_state_dict, model, clip_dict=None): + nemo_state = model.state_dict() + new_state_dict = {} + for k, v in hf_state_dict.items(): + k = 'model.' + k + # This is not necessary when you turn off model.inductor in config file + # if 'diffusion_model' in k: + # k = k.replace('diffusion_model', 'diffusion_model._orig_mod') + if 'in_layers' in k or 'out_layers' in k: + s = k.split('.') + idx = int(s[-2]) + if idx != 0: + k = ".".join(s[:-2] + [str(int(idx - 1))] + [s[-1]]) + if k in nemo_state: + new_state_dict[k] = v + if clip_dict: + for k, v in clip_dict.items(): + k = k.replace("model.text_encoder", "model.cond_stage_model.model") + if k in nemo_state: + new_state_dict[k] = v + for k in [ + 'betas', + 'alphas_cumprod', + 'alphas_cumprod_prev', + 'sqrt_alphas_cumprod', + 'sqrt_one_minus_alphas_cumprod', + 'log_one_minus_alphas_cumprod', + 'sqrt_recip_alphas_cumprod', + 'sqrt_recipm1_alphas_cumprod', + 'posterior_variance', + 'posterior_log_variance_clipped', + 'posterior_mean_coef1', + 'posterior_mean_coef2', + ]: + new_state_dict['model.' + k] = nemo_state['model.' + k] + + return new_state_dict + + +def convert(local_rank, rank, world_size, args): + app_state = AppState() + app_state.data_parallel_rank = 0 + num_nodes = world_size // args.gpus_per_node + if args.bcp: + trainer = Trainer( + devices=args.gpus_per_node, num_nodes=num_nodes, accelerator='gpu', plugins=[TorchElasticEnvironment()] + ) + else: + trainer = Trainer(devices=args.gpus_per_node, num_nodes=num_nodes, accelerator='gpu') + + app_state.pipeline_model_parallel_size = args.pipeline_model_parallel_size + app_state.tensor_model_parallel_size = args.tensor_model_parallel_size + + # no use atm, use to split ranks in encoder/decoder models. + if args.pipeline_model_parallel_size > 1 and args.model_type in []: + if args.pipeline_model_parallel_split_rank is not None: + app_state.pipeline_model_parallel_split_rank = args.pipeline_model_parallel_split_rank + else: + if args.pipeline_model_parallel_size % 2 != 0: + raise ValueError( + f"Pipeline model parallel size {args.pipeline_model_parallel_size} must be even if split rank is not specified." + ) + else: + # If split rank is not set, then we set it to be pipeline_model_parallel_size // 2 - this is because in most cases we have the same number of enc/dec layers. + app_state.pipeline_model_parallel_split_rank = args.pipeline_model_parallel_size // 2 + else: + app_state.pipeline_model_parallel_split_rank = None + + app_state.model_parallel_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size + + parallel_state.initialize_model_parallel( + tensor_model_parallel_size=app_state.tensor_model_parallel_size, + pipeline_model_parallel_size=app_state.pipeline_model_parallel_size, + pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank, + ) + + app_state.pipeline_model_parallel_rank = parallel_state.get_pipeline_model_parallel_rank() + app_state.tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank() + + if args.ckpt_path.endswith('safetensors'): + from safetensors.torch import load_file as load_safetensors + + checkpoint = load_safetensors(args.ckpt_path) + else: + checkpoint = pl_load(args.ckpt_path, map_location='cpu') + if 'state_dict' in checkpoint.keys(): + checkpoint = checkpoint['state_dict'] + cfg = OmegaConf.load(args.hparams_file) + cfg.model.inductor = False + if args.model_type == 'stable_diffusion': + model = MegatronLatentDiffusion(cfg.model, trainer) + elif args.model_type == 'controlnet': + model = MegatronControlNet(cfg.model, trainer) + + if 'nemo' in model.cfg.cond_stage_config._target_: + assert ( + args.nemo_clip_path is not None + ), "To align with current hparams file, you need to provide .nemo checkpoint of clip model for stable diffusion. If you want to convert HF clip checkpoint to .nemo checkpoint first, please refer to /opt/NeMo/examples/multimodal/foundation/clip/convert_external_clip_to_nemo.py" + _, clip_dict = load_config_and_state_from_nemo(args.nemo_clip_path) + else: + clip_dict = None + + state_dict = mapping_hf_state_dict(checkpoint, model, clip_dict=clip_dict) + + model._save_restore_connector = NLPSaveRestoreConnector() + + model.load_state_dict(state_dict) + + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + model.save_to(args.nemo_file_path) + + logging.info(f'NeMo model saved to: {args.nemo_file_path}') + + +if __name__ == '__main__': + args = get_args() + local_rank, rank, world_size = initialize_distributed(args) + convert(local_rank, rank, world_size, args) diff --git a/examples/multimodal/text_to_image/dreambooth/conf/dreambooth.yaml b/examples/multimodal/text_to_image/dreambooth/conf/dreambooth.yaml new file mode 100644 index 000000000000..37e9b284e219 --- /dev/null +++ b/examples/multimodal/text_to_image/dreambooth/conf/dreambooth.yaml @@ -0,0 +1,224 @@ +name: Dreambooth + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: bf16-mixed + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice, max_steps will be reached first. + max_steps: 400 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + limit_val_batches: 0 + +exp_manager: + exp_dir: null + name: ${name} + create_checkpoint_callback: True + create_tensorboard_logger: True + checkpoint_callback_params: + every_n_train_steps: 200 + every_n_epochs: 0 + monitor: reduced_train_loss + save_on_train_epoch_end: False + filename: '${name}-{step}' + save_top_k: -1 + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + ema: + enable: False + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + + + +model: + precision: ${trainer.precision} + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 2 # limited by GPU memory + global_batch_size: 2 # will use more micro batches to reach global batch size + + with_prior_preservation: False + use_cached_latents: True + prior_loss_weight: 0.5 + train_text_encoder: False + restore_from_path: /ckpts/nemo-v1-5-188000-ema.nemo #This ckpt is only used to generate regularization images, thus .nemo ckpt is needed + + + + + linear_start: 0.00085 + linear_end: 0.012 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: images + cond_stage_key: captions + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn # check + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + scale_by_std: False + ckpt_path: + ignore_keys: [ ] + parameterization: eps + clip_denoised: True + load_only_unet: False + cosine_s: 8e-3 + given_betas: + original_elbo_weight: 0 + v_posterior: 0 + l_simple_weight: 1 + use_positional_encodings: False + learn_logvar: False + logvar_init: 0 + beta_schedule: linear + loss_type: l2 + + concat_mode: True + cond_stage_forward: + text_embedding_dropout_rate: 0.1 + fused_opt: True + inductor: False + inductor_cudagraphs: False + channels_last: False + + unet_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel.UNetModel + from_pretrained: /ckpts/unet.bin #load unet weights for finetuning, can use .ckpt ckpts from various sources + from_NeMo: False #Must be specified when from pretrained is not None, False means loading unet from HF ckpt + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + - 4 + num_heads: 8 + use_spatial_transformer: true + transformer_depth: 1 + context_dim: 768 + use_checkpoint: False + legacy: False + use_flash_attention: False + + first_stage_config: + _target_: nemo.collections.multimodal.models.stable_diffusion.ldm.autoencoder.AutoencoderKL + from_pretrained: /ckpts/vae.bin + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 #Never used + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenMegatronCLIPEmbedder + restore_from_path: /ckpts/openai.nemo + device: cuda + freeze: True + layer: "last" + # For compatibility of history version that uses HF clip model + # _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder + # version: openai/clip-vit-large-patch14 + # device: cuda + # max_length: 77 + + noise_scheduler: + _target_: nemo.collections.multimodal.models.dreambooth.util.sd_noise_scheduler + parameterization: eps + v_posterior: 0 + given_betas: + beta_schedule: linear + timesteps: 1000 + linear_start: 0.00085 + linear_end: 0.012 + cosine_s: 8e-3 + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + optim: + name: fused_adam + lr: 1e-6 + weight_decay: 0. + betas: + - 0.9 + - 0.999 + sched: + name: WarmupHoldPolicy + warmup_steps: 1 + hold_steps: 10000000000000 # Incredibly large value to hold the lr as constant + + # Nsys profiling options + nsys_profile: + enabled: False + start_step: 10 # Global batch to start profiling + end_step: 10 # Global batch to end profiling + ranks: [ 0 ] # Global rank IDs to profile + gen_shape: False # Generate model and kernel details including input shapes + + data: + name: pbss + num_workers: 4 + instance_dir: /datasets/instance_dir + instance_prompt: a photo of a sks dog + regularization_dir: /datasets/nemo_dogs + regularization_prompt: a photo of a dog + num_reg_images: 10 + num_images_per_prompt: 4 + resolution: 512 + center_crop: True + cached_instance_dir: #/datasets/instance_dir_cached + cached_reg_dir: #/datasets/nemo_dogs_cached + +##The below infer config is to use inference script generating regularization images +infer: + unconditional_guidance_scale: 7.5 + num_images_per_prompt: ${model.data.num_images_per_prompt} + height: 512 + width: 512 + down_factor: 8 + inference_steps: 50 + sampler_type: 'PLMS' + eta: 0 + output_type: 'pil' + save_to_file: False + out_path: ${model.data.regularization_dir} + prompts: ${model.data.regularization_prompt} \ No newline at end of file diff --git a/examples/multimodal/text_to_image/dreambooth/conf/dreambooth_infer.yaml b/examples/multimodal/text_to_image/dreambooth/conf/dreambooth_infer.yaml new file mode 100644 index 000000000000..fc8d35443767 --- /dev/null +++ b/examples/multimodal/text_to_image/dreambooth/conf/dreambooth_infer.yaml @@ -0,0 +1,32 @@ +name: stable-diffusion-train + +infer: + unconditional_guidance_scale: 7.5 + num_images_per_prompt: 4 + height: 512 + width: 512 + down_factor: 8 + inference_steps: 100 + sampler_type: 'DDIM' + eta: 0 + output_type: 'pil' + save_to_file: True + out_path: 'dreambooth' + seed: 123 + prompts: + - 'a photo of a sks dog' + - 'a photo of a sks dog in the Acropolis' + - 'a photo of a sks dog in front of eiffel tower' + - 'a photo of sks dog sleeping' + - 'a photo of a sks dog riding a bike' + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: 16 + logger: False # logger provided by exp_manager + +model: + restore_from_path: null + precision: ${trainer.precision} \ No newline at end of file diff --git a/examples/multimodal/text_to_image/dreambooth/dreambooth.py b/examples/multimodal/text_to_image/dreambooth/dreambooth.py new file mode 100644 index 000000000000..d968d301389c --- /dev/null +++ b/examples/multimodal/text_to_image/dreambooth/dreambooth.py @@ -0,0 +1,108 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import torch +from omegaconf import OmegaConf + +from nemo.collections.multimodal.models.text_to_image.dreambooth import MegatronDreamBooth +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import MegatronLatentDiffusion +from nemo.collections.multimodal.parts.stable_diffusion.pipeline import pipeline +from nemo.collections.multimodal.parts.utils import setup_trainer_and_model_for_inference +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + +def prepare_reg_data(cfg): + reg_dir = cfg.model.data.regularization_dir + num_reg_images = cfg.model.data.num_reg_images + num_images_per_prompt = cfg.model.data.num_images_per_prompt + reg_prompt = cfg.model.data.regularization_prompt + os.makedirs(reg_dir, exist_ok=True) + NUM_REG_IMAGES = len(os.listdir(reg_dir)) + if NUM_REG_IMAGES < num_reg_images: + + def model_cfg_modifier(model_cfg): + model_cfg.precision = cfg.trainer.precision + model_cfg.ckpt_path = None + model_cfg.inductor = False + model_cfg.unet_config.use_flash_attention = False + model_cfg.micro_batch_size = cfg.model.micro_batch_size + model_cfg.global_batch_size = cfg.model.global_batch_size + model_cfg.unet_config.from_pretrained = None + model_cfg.first_stage_config.from_pretrained = None + model_cfg.target = 'nemo.collections.multimodal.models.stable_diffusion.ldm.ddpm.MegatronLatentDiffusion' + + trainer, megatron_diffusion_model = setup_trainer_and_model_for_inference( + model_provider=MegatronLatentDiffusion, cfg=cfg, model_cfg_modifier=model_cfg_modifier + ) + model = megatron_diffusion_model.model + rng = torch.Generator() + rng.manual_seed(trainer.global_rank * 100 + cfg.model.seed) + images_to_generate = cfg.model.data.num_reg_images - NUM_REG_IMAGES + images_to_generate = images_to_generate // trainer.world_size + + logging.info( + f"No enough images in regularization folder, generating {images_to_generate} from provided ckpt on each device" + ) + + for i in range(images_to_generate // num_images_per_prompt + 1): + output = pipeline(model, cfg, verbose=False, rng=rng) + for text_prompt, pils in zip(reg_prompt, output): + for idx, image in enumerate(pils): + image.save( + os.path.join( + cfg.infer.out_path, + f'{reg_prompt}_{trainer.global_rank}_{NUM_REG_IMAGES + i * num_images_per_prompt + idx}.png', + ) + ) + del model + del trainer + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +@hydra_runner(config_path='conf', config_name='dreambooth.yaml') +def main(cfg): + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + torch.backends.cuda.matmul.allow_tf32 = True + + if cfg.model.with_prior_preservation: + prepare_reg_data(cfg) + parallel_state.destroy_model_parallel() + + trainer = MegatronTrainerBuilder(cfg).create_trainer() + + exp_manager(trainer, cfg.exp_manager) + + model = MegatronDreamBooth(cfg.model, trainer) + + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/examples/multimodal/text_to_image/dreambooth/dreambooth_infer.py b/examples/multimodal/text_to_image/dreambooth/dreambooth_infer.py new file mode 100644 index 000000000000..672431d7b3fa --- /dev/null +++ b/examples/multimodal/text_to_image/dreambooth/dreambooth_infer.py @@ -0,0 +1,44 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import MegatronLatentDiffusion +from nemo.collections.multimodal.parts.stable_diffusion.pipeline import pipeline +from nemo.collections.multimodal.parts.utils import setup_trainer_and_model_for_inference +from nemo.core.config import hydra_runner + + +@hydra_runner(config_path='conf', config_name='dreambooth_infer') +def main(cfg): + def model_cfg_modifier(model_cfg): + model_cfg.precision = cfg.trainer.precision + model_cfg.ckpt_path = None + model_cfg.inductor = False + model_cfg.unet_config.use_flash_attention = False + model_cfg.unet_config.from_pretrained = None + model_cfg.first_stage_config.from_pretrained = None + model_cfg.target = 'nemo.collections.multimodal.models.stable_diffusion.ldm.ddpm.MegatronLatentDiffusion' + + trainer, megatron_diffusion_model = setup_trainer_and_model_for_inference( + model_provider=MegatronLatentDiffusion, cfg=cfg, model_cfg_modifier=model_cfg_modifier + ) + model = megatron_diffusion_model.model + model.cuda().eval() + + rng = torch.Generator().manual_seed(cfg.infer.seed) + pipeline(model, cfg, rng=rng) + + +if __name__ == "__main__": + main() diff --git a/examples/multimodal/text_to_image/imagen/README.md b/examples/multimodal/text_to_image/imagen/README.md new file mode 100644 index 000000000000..ba33b649cb35 --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/README.md @@ -0,0 +1,104 @@ +# Imagen +## A. Overview + +Imagen is a multi-stage text-to-image diffusion model with an unprecedented degree of photorealism and a deep level of language understanding. Given a text prompt, Imagen first generates an image at a 64x64 resolution and then upsamples the generated image to 256x256 and 1024x1024 resolutions, all using diffusion models. + +**Table of Contents:** +- [Imagen](#imagen) + - [A. Overview](#a-overview) + - [B. Imagen Pipeline](#b-imagen-pipeline) + - [C. Files in this folder](#c-files-in-this-folder) + - [D. Imagen Training](#d-imagen-training) + - [D.1 Training Dataset](#d1-training-dataset) + - [D.2 Training configs](#d2-training-configs) + - [E. Imagen Inference](#e-imagen-inference) + - [E.1 Inference Settings](#e1-inference-settings) + - [E.2 Running the sample inference code](#e2-running-the-sample-inference-code) + - [E.3 Inference GPU Memory Usage](#e3-inference-gpu-memory-usage) + - [E.3.1 FP16 Inference](#e31-fp16-inference) + - [E.3.2 FP32 Inference](#e32-fp32-inference) + - [E.3.3 AMP Inference (Autocast Enabled)](#e33-amp-inference-autocast-enabled) + - [F. UNet Architecture](#f-unet-architecture) + - [F.1 U-Net (used for base model)](#f1-u-net-used-for-base-model) + - [F.2 Efficient U-Net (used for SR models)](#f2-efficient-u-net-used-for-sr-models) + +## B. Imagen Pipeline + +Imagen comprises a frozen text encoder (e.g. T5-XXL) to map input text into a sequence of embeddings, and a 64x64 image diffusion model, followed by two super-resolution diffusion models for generating 256x256 and 1024x1024 images. All diffusion models are conditioned on the text embedding sequence and use classifier-free guidance. + +## C. Files in this folder + +- [imagen_training.py](imagen_training.py): Script for running inference +- [imagen_generate_images.py](imagen_generate_images.py): Script for generating images for FID-CLIP analysis +- [imagen_infer.py](imagen_infer.py): Script for running inference + +## D. Imagen Training + +All three diffusion models (64x64, 256x256, 1024x1024) can be trained independently. + +### D.1 Training Dataset + +### D.2 Training configs +| configs | Description | +|---|---| +| base64-2b.yaml | 2b-parameter base 64x64 model as described in Imagen paper | +| base64-500m.yaml | 500m-parameter base 64x64 model with decreased number of embedding channels| +|sr256-400m.yaml| 400m-parameter sr 256x256 model as described in Imagen paper | +|sr1024-400m.yaml| 400m-parameter sr 1024x1024 model as described in Imagen paper | + +## E. Imagen Inference + +### E.1 Inference Settings + +[inference_pipeline.yaml](conf/inference_pipeline.yaml) specifies every config for running the sample inference code. Specifically: +- num_images_per_promt: The number of images you want to generate for each text prompt +- model_name: Different pre-defined configs (not used for now) +- run_ema_model: Either run reg/ema model for pretrained models +- customized_model: Instead of loading pre-defined models, load specified checkpoint. .ckpt checkpoint (generated during in-the-middle of training) and .nemo checkpoint (generated once training completed) are both acceptable +- target_resolution: should be one of [64, 256, 1024] +- inference_precision: Running inference in one of [16, 32, AMP] mode +- dynamic_thresholding: Whether to use dynamic thresholding when generating images +- texts: List of text prompts that are used to generate images +- output_path: The path to save generate images +- encoder_path: If not set (null), it will download text encoder first time running the inference code (and will be saved to HF_HOME), you can also load it offline by setting it to the prepared folder +- samplers: List of sampler settings that are used for each model. `step` (the number of iterations to denoise the image, ideally the larger the better, but also consume more time) and `cfg` for classifier free guidance value. You can tweak these values for better visual quality. + +### E.2 Running the sample inference code +``` +(inside NeMo root folder) +python examples/multimodal/generative/imagen/imagen_infer.py +``` + +### E.3 Inference GPU Memory Usage + +#### E.3.1 FP16 Inference +| Output\Batch size | 1 | 8 | +|-------------------|-------|-------| +| 64x64 | 11.7G | 11.9G | +| 256x256 | 12.5G | 13.0G | +| 1024x1024 | 14.1G | 21.6G | + +#### E.3.2 FP32 Inference +| Output\Batch size | 1 | 8 | +|-------------------|-------|-------| +| 64x64 | 21.7G | 22.6G | +| 256x256 | 23.4G | 24.5G | +| 1024x1024 | 26.6G | 40.6G | + +#### E.3.3 AMP Inference (Autocast Enabled) +| Output\Batch size | 1 | 8 | +|-------------------|-------|-------| +| 64x64 | 22.4G | 23.4G | +| 256x256 | 24.0G | 25.1G | +| 1024x1024 | 26.4G | 33.7G | + +## F. UNet Architecture + +We have prepared two types of UNet for Imagen according to the paper. Base model (64x64) and SR models (256x256, 1024x1024) are using different UNet models. + +### F.1 U-Net (used for base model) + + + +### F.2 Efficient U-Net (used for SR models) + diff --git a/examples/multimodal/text_to_image/imagen/conf/base64-2b.yaml b/examples/multimodal/text_to_image/imagen/conf/base64-2b.yaml new file mode 100644 index 000000000000..4c02c97c9e4e --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/conf/base64-2b.yaml @@ -0,0 +1,142 @@ +name: imagen-nemo # The name of your model +allow_tf32: True + +trainer: + devices: 1 # number of GPUs (0 for CPU), or list of the GPUs to use e.g. [0, 1] + num_nodes: 1 + max_epochs: -1 + max_steps: 2500000 # precedence over max_epochs + logger: False # Provided by exp_manager + precision: bf16 # Should be set to 16 for O1 and O2 to enable the AMP. + accelerator: gpu + log_every_n_steps: 5 # Interval of logging. + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + num_sanity_val_steps: 10 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + enable_checkpointing: False # Provided by exp_manager + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + + +exp_manager: + exp_dir: /train/imagen-base64 # exp_dir for your experiment, if None, defaults to "./nemo_experiments" + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: # Whether you want exp_manger to create a Wandb logger + name: imagen-base64-nf512 + project: imagen + group: nemo-imagen + resume: True + create_tensorboard_logger: True # Whether you want exp_manger to create a tb logger + create_checkpoint_callback: True # Whether you want exp_manager to create a modelcheckpoint callback + checkpoint_callback_params: + monitor: reduced_train_loss + save_top_k: 5 + every_n_epochs: 0 # Save checkpoint frequency. + every_n_train_steps: 1000 # Mutually exclusive with every_n_epochs. It is recommended to set this if training on large-scale dataset. + filename: '${name}--{reduced_train_loss:.2f}-{step}-{consumed_samples}' + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + ema: + enable: True + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + + +model: + precision: ${trainer.precision} + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 32 # limited by GPU memory + global_batch_size: 32 # will use more micro batches to reach global batch size + inductor: True + inductor_cudagraphs: False + unet_type: base + channels_last: True + + unet: + embed_dim: 512 + image_size: 64 + channels: 3 + num_res_blocks: 3 + channel_mult: [ 1, 2, 3, 4 ] + num_attn_heads: 4 + per_head_channels: 64 + cond_dim: 2048 + attention_type: fused + feature_pooling_type: attention + learned_sinu_pos_emb_dim: 0 + attention_resolutions: [ 8, 16, 32 ] + dropout: False + use_null_token: False + init_conv_kernel_size: 3 + gradient_checkpointing: False + scale_shift_norm: True + stable_attention: False + flash_attention: True + resblock_updown: False + resample_with_conv: True + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + ddp_overlap: True # True for using PyTorch default DDP overlap. False for using Megatron's default configuration for async grad allreduce + + preconditioning_type: EDM + preconditioning: + loss_type: l2 + sigma_data: 0.5 + p_mean: -1.2 + p_std: 1.2 + # If want to switch to continuous DDPM training, + # use the following config: + # preconditioning_type: DDPM + # preconditioning: + # loss_type: l2 + # pred_objective: noise + # noise_schedule: cosine + # timesteps: 1000 + + conditioning: + embed_dim: 1024 + token_length: 128 + drop_rate: 0.1 + precached_key: embeddings_t5_xxl + out_key: t5_text + + data: + num_workers: 16 + train: + dataset_path: + - datasets/laion_aesthetic/wdinfo-selene.pkl # 48,874,000 + - datasets/coyo-700m/wdinfo-selene.pkl # 627,172,000 + augmentations: + resize_smallest_side: 64 + center_crop_h_w: 64, 64 + horizontal_flip: False + filterings: null + + webdataset: + use_webdataset: True + object_store: False + infinite_sampler: False + local_root_path: /datasets + verbose: False + + optim: + # We need weight decay for large-scale odel + name: fused_adam + lr: 0.0001 + eps: 1e-8 + betas: [ 0.9, 0.999 ] + weight_decay: 0.01 + sched: + name: WarmupPolicy + warmup_steps: 10000 + warmup_ratio: null \ No newline at end of file diff --git a/examples/multimodal/text_to_image/imagen/conf/base64-500m-edm.yaml b/examples/multimodal/text_to_image/imagen/conf/base64-500m-edm.yaml new file mode 100644 index 000000000000..11224e3b84d2 --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/conf/base64-500m-edm.yaml @@ -0,0 +1,136 @@ +name: imagen-nemo # The name of your model +allow_tf32: True + +trainer: + devices: 1 # number of GPUs (0 for CPU), or list of the GPUs to use e.g. [0, 1] + num_nodes: 1 + max_epochs: -1 + max_steps: 2500000 # precedence over max_epochs + logger: False # Provided by exp_manager + precision: bf16 # Should be set to 16 for O1 and O2 to enable the AMP. + accelerator: gpu + log_every_n_steps: 5 # Interval of logging. + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + num_sanity_val_steps: 10 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + enable_checkpointing: False # Provided by exp_manager + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + + +exp_manager: + exp_dir: /train/imagen-base64 # exp_dir for your experiment, if None, defaults to "./nemo_experiments" + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: # Whether you want exp_manger to create a Wandb logger + name: imagen-base64-nf256 + project: imagen + group: nemo-imagen + resume: True + create_tensorboard_logger: True # Whether you want exp_manger to create a tb logger + create_checkpoint_callback: True # Whether you want exp_manager to create a modelcheckpoint callback + checkpoint_callback_params: + monitor: reduced_train_loss + save_top_k: 5 + every_n_epochs: 0 # Save checkpoint frequency. + every_n_train_steps: 100 # Mutually exclusive with every_n_epochs. It is recommended to set this if training on large-scale dataset. + filename: '${name}--{reduced_train_loss:.2f}-{step}-{consumed_samples}' + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + ema: + enable: True + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + + +model: + precision: ${trainer.precision} + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 48 # limited by GPU memory + global_batch_size: 48 # will use more micro batches to reach global batch size + inductor: False + inductor_cudagraphs: False + unet_type: base + + unet: + embed_dim: 256 + image_size: 64 + channels: 3 + num_res_blocks: 3 + channel_mult: [ 1, 2, 3, 4 ] + num_attn_heads: 4 + per_head_channels: 64 + cond_dim: 512 + attention_type: fused + feature_pooling_type: attention + learned_sinu_pos_emb_dim: 0 + attention_resolutions: [ 8, 16, 32 ] + dropout: False + use_null_token: False + init_conv_kernel_size: 3 + gradient_checkpointing: False + scale_shift_norm: True + stable_attention: False + flash_attention: False + resblock_updown: False + resample_with_conv: True + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + preconditioning_type: EDM + preconditioning: + loss_type: l2 + sigma_data: 0.5 + p_mean: -1.2 + p_std: 1.2 + + conditioning: + embed_dim: 1024 + token_length: 128 + drop_rate: 0.1 + precached_key: embeddings_t5_xxl + out_key: t5_text + + data: + num_workers: 16 + train: + dataset_path: + - datasets/laion_aesthetic/wdinfo-selene.pkl # 48,874,000 + - datasets/coyo-700m/wdinfo-selene.pkl # 627,172,000 + augmentations: + resize_smallest_side: 64 + center_crop_h_w: 64, 64 + horizontal_flip: False + filterings: null + + webdataset: + use_webdataset: True + object_store: False + infinite_sampler: False + local_root_path: /datasets + verbose: False + pbss_checkpoint_saving: + enable: False + pbss_credentials_file: pbss_credentials_joc.secret + save_frequency: 1000 + + optim: + # We need weight decay for large-scale odel + name: fused_adam + lr: 0.0001 + eps: 1e-8 + betas: [ 0.9, 0.999 ] + weight_decay: 0.01 + sched: + name: WarmupPolicy + warmup_steps: 10000 + warmup_ratio: null diff --git a/examples/multimodal/text_to_image/imagen/conf/base64-500m.yaml b/examples/multimodal/text_to_image/imagen/conf/base64-500m.yaml new file mode 100644 index 000000000000..eb66b5b36feb --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/conf/base64-500m.yaml @@ -0,0 +1,144 @@ +name: imagen-nemo # The name of your model +allow_tf32: True + +trainer: + devices: 1 # number of GPUs (0 for CPU), or list of the GPUs to use e.g. [0, 1] + num_nodes: 1 + max_epochs: -1 + max_steps: 2500000 # precedence over max_epochs + logger: False # Provided by exp_manager + precision: bf16 # Should be set to 16 for O1 and O2 to enable the AMP. + accelerator: gpu + limit_val_batches: 0 + log_every_n_steps: 5 # Interval of logging. + num_sanity_val_steps: 10 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + enable_checkpointing: False # Provided by exp_manager + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + + +exp_manager: + exp_dir: /train/imagen-base64 # exp_dir for your experiment, if None, defaults to "./nemo_experiments" + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: # Whether you want exp_manger to create a Wandb logger + name: imagen-base64-nf256 + project: imagen + group: nemo-imagen + resume: True + create_tensorboard_logger: True # Whether you want exp_manger to create a tb logger + create_checkpoint_callback: True # Whether you want exp_manager to create a modelcheckpoint callback + checkpoint_callback_params: + monitor: reduced_train_loss + save_top_k: 5 + every_n_epochs: 0 # Save checkpoint frequency. + every_n_train_steps: 1000 # Mutually exclusive with every_n_epochs. It is recommended to set this if training on large-scale dataset. + filename: '${name}--{reduced_train_loss:.2f}-{step}-{consumed_samples}' + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + ema: + enable: True + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + + +model: + precision: ${trainer.precision} + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 128 # limited by GPU memory + global_batch_size: 128 # will use more micro batches to reach global batch size + inductor: False + inductor_cudagraphs: False + unet_type: base + channels_last: True + + unet: + embed_dim: 256 + image_size: 64 + channels: 3 + num_res_blocks: 3 + channel_mult: [ 1, 2, 3, 4 ] + num_attn_heads: 4 + per_head_channels: 64 + cond_dim: 512 + attention_type: fused + feature_pooling_type: attention + learned_sinu_pos_emb_dim: 0 + attention_resolutions: [ 8, 16, 32 ] + dropout: False + use_null_token: False + init_conv_kernel_size: 3 + gradient_checkpointing: False + scale_shift_norm: True + stable_attention: False + flash_attention: True + resblock_updown: False + resample_with_conv: True + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + ddp_overlap: False # True for using PyTorch default DDP overlap. False for using Megatron's default configuration for async grad allreduce + + preconditioning_type: EDM + preconditioning: + loss_type: l2 + sigma_data: 0.5 + p_mean: -1.2 + p_std: 1.2 + # If want to switch to continuous DDPM training, + # use the following config: + # preconditioning_type: DDPM + # preconditioning: + # loss_type: l2 + # pred_objective: noise + # noise_schedule: cosine + # timesteps: 1000 + + conditioning: + embed_dim: 1024 + token_length: 128 + drop_rate: 0.1 + precached_key: embeddings_t5_xxl + out_key: t5_text + + data: + num_workers: 16 + synthetic_data: False + synthetic_data_length: 800000 + train: + dataset_path: + - datasets/laion_aesthetic/wdinfo-selene.pkl # 48,874,000 + - datasets/coyo-700m/wdinfo-selene.pkl # 627,172,000 + augmentations: + resize_smallest_side: 64 + center_crop_h_w: 64, 64 + horizontal_flip: False + filterings: null + + webdataset: + use_webdataset: True + object_store: False + infinite_sampler: False + local_root_path: /datasets + verbose: False + + optim: + # We need weight decay for large-scale odel + name: fused_adam + lr: 0.0001 + eps: 1e-8 + betas: [ 0.9, 0.999 ] + weight_decay: 0.01 + sched: + name: WarmupPolicy + warmup_steps: 10000 + warmup_ratio: null diff --git a/examples/multimodal/text_to_image/imagen/conf/base64-500m_online_encoding.yaml b/examples/multimodal/text_to_image/imagen/conf/base64-500m_online_encoding.yaml new file mode 100644 index 000000000000..efbab7bc1ca8 --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/conf/base64-500m_online_encoding.yaml @@ -0,0 +1,137 @@ +name: imagen-nemo # The name of your model +allow_tf32: True + +trainer: + devices: 1 # number of GPUs (0 for CPU), or list of the GPUs to use e.g. [0, 1] + num_nodes: 1 + max_epochs: -1 + max_steps: 2500000 # precedence over max_epochs + logger: False # Provided by exp_manager + precision: bf16 # Should be set to 16 for O1 and O2 to enable the AMP. + accelerator: gpu + log_every_n_steps: 5 # Interval of logging. + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + num_sanity_val_steps: 10 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + enable_checkpointing: False # Provided by exp_manager + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + + +exp_manager: + exp_dir: /train/imagen-base64 # exp_dir for your experiment, if None, defaults to "./nemo_experiments" + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: # Whether you want exp_manger to create a Wandb logger + name: imagen-base64-nf256 + project: imagen + group: nemo-imagen + resume: True + create_tensorboard_logger: True # Whether you want exp_manger to create a tb logger + create_checkpoint_callback: True # Whether you want exp_manager to create a modelcheckpoint callback + checkpoint_callback_params: + monitor: reduced_train_loss + save_top_k: 5 + every_n_epochs: 0 # Save checkpoint frequency. + every_n_train_steps: 100 # Mutually exclusive with every_n_epochs. It is recommended to set this if training on large-scale dataset. + filename: '${name}--{reduced_train_loss:.2f}-{step}-{consumed_samples}' + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + ema: + enable: True + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + + +model: + precision: ${trainer.precision} + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 48 # limited by GPU memory + global_batch_size: 48 # will use more micro batches to reach global batch size + + unet_type: base + unet: + embed_dim: 256 + image_size: 64 + channels: 3 + num_res_blocks: 3 + channel_mult: [ 1, 2, 3, 4 ] + num_attn_heads: 4 + per_head_channels: 64 + cond_dim: 512 + attention_type: fused + feature_pooling_type: attention + learned_sinu_pos_emb_dim: 0 + attention_resolutions: [ 8, 16, 32 ] + dropout: False + use_null_token: False + init_conv_kernel_size: 3 + gradient_checkpointing: False + scale_shift_norm: True + stable_attention: True + flash_attention: False + resblock_updown: False + resample_with_conv: True + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + preconditioning_type: DDPM + preconditioning: + loss_type: l2 + pred_objective: noise + noise_schedule: cosine + timesteps: 1000 + + conditioning: + online_encoding: True # defaults to False (use precached encodings) if not specified + # Online encoding increases training time by about 3-4x, and is only for users who want to do a quick dev run of + # Imagen, and/or those who do not have the disk space to store precached embeddings. + # Optionally specify encoder_path if online_encoding; else, specify precached_key and out_key + encoder_path: # folder path to t5xxl-encoder.bin, or leave empty to download (and cache) t5-11b weights + embed_dim: 1024 + token_length: 128 + drop_rate: 0.1 + + data: + num_workers: 16 + train: + dataset_path: + - datasets/laion_aesthetic/wdinfo-selene.pkl # 48,874,000 + - datasets/coyo-700m/wdinfo-selene.pkl # 627,172,000 + augmentations: + resize_smallest_side: 64 + center_crop_h_w: 64, 64 + horizontal_flip: False + filterings: null + + webdataset: + use_webdataset: True + object_store: False + infinite_sampler: False + local_root_path: /datasets + verbose: False + pbss_checkpoint_saving: + enable: False + pbss_credentials_file: pbss_credentials_joc.secret + save_frequency: 1000 + + optim: + # We need weight decay for large-scale odel + name: fused_adam + lr: 0.0001 + eps: 1e-8 + betas: [ 0.9, 0.999 ] + weight_decay: 0.01 + sched: + name: WarmupPolicy + warmup_steps: 10000 + warmup_ratio: null \ No newline at end of file diff --git a/examples/multimodal/text_to_image/imagen/conf/fid_inference.yaml b/examples/multimodal/text_to_image/imagen/conf/fid_inference.yaml new file mode 100644 index 000000000000..413da2b8eeac --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/conf/fid_inference.yaml @@ -0,0 +1,26 @@ +num_images_per_promt: 8 # The number of images generated for each promt text +model_name: null # Avaliable model_name defined in pretrained_models.yaml +run_ema_model: True # Whether load the reg/ema model when using pretrained models +customized_model: # Mutually exclusive with model_name + base_ckpt: /aot/exp/nemo-megatron-stacked-ddpm-16n/imagen-nemo/checkpoints/imagen-nemo--reduced_train_loss=0.03-step=100000-consumed_samples=512000000.0.ckpt # Either .ckpt or .nemo is accepatable + base_cfg: examples/multimodal/generative/imagen/conf/base64-500m.yaml # Must provided if loading .ckpt checkpoint + sr256_ckpt: null + sr256_cfg: examples/multimodal/generative/imagen/conf/sr256-400m.yaml + sr1024_ckpt: null + sr1024_cfg: null +target_resolution: 64 # in [64, 256, 1024] +inference_precision: '32' # [16, 32, AMP] +thresholding_method: 'dynamic' +output_path: 'output/imagen-megatron-pipeline-fid' # Save location +record_time: True # Whether to record inference time meta +encoder_path: '/ckpts/encoders' # Set to null if you wish to download encoders on the fly +samplings: + - + step: 250 + cfg: 7.5 + - + step: 20 + cfg: 7.5 + + + diff --git a/examples/multimodal/text_to_image/imagen/conf/imagen_fid_images.yaml b/examples/multimodal/text_to_image/imagen/conf/imagen_fid_images.yaml new file mode 100644 index 000000000000..5a5867cfae50 --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/conf/imagen_fid_images.yaml @@ -0,0 +1,57 @@ +name: imagen_fid_images + +fid: + classifier_free_guidance: + - 1 + - 1.5 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + nnodes_per_cfg: 1 + ntasks_per_node: 8 + local_task_id: null + num_images_to_eval: 30000 + coco_captions_path: /aot/datasets/coco2014/coco2014_val_sampled_30k/captions + coco_images_path: /aot/datasets/coco2014/coco2014_val/images_256 + save_path: output/fid-launcher-test + ncaptions_per_batch: 4 + save_all_res: False + save_text: False + +infer: + num_images_per_promt: 1 # The number of images generated for each promt text + model_name: null # Avaliable model_name defined in pretrained_models.yaml + run_ema_model: True # Whether load the reg/ema model when using pretrained models + customized_model: # Mutually exclusive with model_name + base_ckpt: /aot/exp/ckpts/imagen-megatron/edm-fused-1150k-ema.nemo # Either .ckpt or .nemo is accepatable + base_cfg: null # Must provided if loading .ckpt checkpoint + sr256_ckpt: /aot/exp/ckpts/imagen-megatron/sr-noise-aug-280k.nemo + sr256_cfg: null + sr1024_ckpt: null + sr1024_cfg: null + target_resolution: 256 # in [64, 256, 1024] + inference_precision: '32' # [16, 32, AMP] + thresholding_method: 'dynamic' + record_time: True # Whether to record inference time meta + encoder_path: '/ckpts/encoders' # Set to null if you wish to download encoders on the fly + samplings: + - + step: 30 + - + step: 20 + +models: + - + restore_from_path: /aot/exp/ckpts/imagen-megatron/edm-fused-1150k-ema.nemo + - + restore_from_path: /aot/exp/ckpts/imagen-megatron/sr-noise-aug-280k.nemo + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: 32 + logger: False # logger provided by exp_manager diff --git a/examples/multimodal/text_to_image/imagen/conf/inference_pipeline.yaml b/examples/multimodal/text_to_image/imagen/conf/inference_pipeline.yaml new file mode 100644 index 000000000000..1b4bbd9e5a17 --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/conf/inference_pipeline.yaml @@ -0,0 +1,42 @@ +num_images_per_promt: 4 # The number of images generated for each promt text +model_name: null # Avaliable model_name defined in pretrained_models.yaml +run_ema_model: True # Whether load the reg/ema model when using pretrained models +customized_model: # Mutually exclusive with model_name + base_ckpt: null # Either .ckpt or .nemo is accepatable + base_cfg: examples/multimodal/generative/imagen/conf/base64-500m.yaml # Must provided if loading .ckpt checkpoint + sr256_ckpt: null + sr256_cfg: examples/multimodal/generative/imagen/conf/sr256-400m.yaml + sr1024_ckpt: null + sr1024_cfg: examples/multimodal/generative/imagen/conf/sr1024-400m.yaml +target_resolution: 64 # in [64, 256, 1024] +inference_precision: 32 # [16, 32, AMP] +thresholding_method: dynamic +texts: + - 'a photograph of an astronaut riding a horse' + - 'a highly detailed digital painting of a portal in a mystic forest with many beautiful trees. A person is standing in front of the portal' + - A photo of a Shiba Inu dog with a backpack riding a bike. It is wearing sunglasses and a beach hat. + - A cute corgi lives in a house made out of sushi. + - A high contrast portrait of a very happy fuzzy panda dressed as a chef in a high end kitchen making dough. There is a painting of flowers on the wall behind him. + - A brain riding a rocketship heading towards the moon. + - One cat and two dogs sitting on the grass. + - A wine glass on top of a dog. + - A blue coloured pizza. + - A transparent sculpture of a duck made out of glass. There is a painting on the wall behind it. + - A raccoon wearing cowboy hat and black leather jacket is behind the backyard window. Rain droplets on the window. + +output_path: 'output/imagen_output' # Save location +record_time: True # Whether to record inference time meta +encoder_path: '/ckpts/encoders' # Set to null if you wish to download encoders on the fly +samplings: + - # Base64 + step: 30 + cfg: 7.5 + - # SR256 + step: 20 + cfg: 8 + - # SR1024 + step: 20 + cfg: 7.5 + + + diff --git a/examples/multimodal/text_to_image/imagen/conf/sr1024-600m.yaml b/examples/multimodal/text_to_image/imagen/conf/sr1024-600m.yaml new file mode 100644 index 000000000000..3652267193b1 --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/conf/sr1024-600m.yaml @@ -0,0 +1,145 @@ +name: imagen-nemo # The name of your model +allow_tf32: True + +trainer: + devices: 1 # number of GPUs (0 for CPU), or list of the GPUs to use e.g. [0, 1] + num_nodes: 1 + max_epochs: -1 + max_steps: 2500000 # precedence over max_epochs + logger: False # Provided by exp_manager + precision: bf16 # Should be set to 16 for O1 and O2 to enable the AMP. + accelerator: gpu + log_every_n_steps: 5 # Interval of logging. + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + num_sanity_val_steps: 10 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + enable_checkpointing: False # Provided by exp_manager + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + + +exp_manager: + exp_dir: /train/imagen-1024 # exp_dir for your experiment, if None, defaults to "./nemo_experiments" + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: # Whether you want exp_manger to create a Wandb logger + name: imagen-sr1024-nf128 + project: imagen + group: nemo-imagen + resume: True + create_tensorboard_logger: True # Whether you want exp_manger to create a tb logger + create_checkpoint_callback: True # Whether you want exp_manager to create a modelcheckpoint callback + checkpoint_callback_params: + monitor: reduced_train_loss + save_top_k: 5 + every_n_epochs: 0 # Save checkpoint frequency. + every_n_train_steps: 1000 # Mutually exclusive with every_n_epochs. It is recommended to set this if training on large-scale dataset. + filename: '${name}--{reduced_train_loss:.2f}-{step}-{consumed_samples}' + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + ema: + enable: True + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False +model: + + precision: ${trainer.precision} + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 64 # limited by GPU memory + global_batch_size: 64 # will use more micro batches to reach global batch size + inductor: True + inductor_cudagraphs: False + unet_type: sr + channels_last: True + + unet: + embed_dim: 128 + image_size: 1024 + channels: 3 + channel_mult: [ 1, 2, 4, 8, 8 ] + num_attn_heads: 8 + per_head_channels: 64 + attention_type: cross + atnn_enabled_at: [ 0, 0, 0, 1, 1 ] + feature_pooling_type: attention + stride: 2 + num_resblocks: [ 2, 4, 8, 8, 8 ] + learned_sinu_pos_emb_dim: 0 + use_null_token: False + init_conv_kernel_size: 3 + gradient_checkpointing: False + scale_shift_norm: True + stable_attention: True + flash_attention: False + skip_connection_scaling: True + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + ddp_overlap: True # True for using PyTorch default DDP overlap. False for using Megatron's default configuration for async grad allreduce + + noise_cond_aug: True + preconditioning_type: EDM + preconditioning: + loss_type: l2 + sigma_data: 0.5 + p_mean: -1.2 + p_std: 1.2 + # If want to switch to continuous DDPM training, + # use the following config: + # preconditioning_type: DDPM + # preconditioning: + # loss_type: l2 + # pred_objective: noise + # noise_schedule: cosine + # timesteps: 1000 + + conditioning: + embed_dim: 1024 + token_length: 128 + drop_rate: 0.1 + precached_key: embeddings_t5_xxl + out_key: t5_text + + data: + num_workers: 16 + train: + dataset_path: + - datasets/laion_aesthetic/wdinfo-selene.pkl # 48,874,000 + - datasets/coyo-700m/wdinfo-selene.pkl # 627,172,000 + augmentations: + resize_smallest_side: 1024 + center_crop_h_w: 256, 256 + horizontal_flip: False + filterings: + resolution: + method: larger + value: 1024 + estimated_portion: 0.2 # Estimated % of examples left after filtering. This is use to estimate # epoch + target_resolutions: [64, 256] + + webdataset: + use_webdataset: True + object_store: False + infinite_sampler: True + local_root_path: /datasets + verbose: False + + optim: + # We need weight decay for large-scale odel + name: fused_adam + lr: 0.0001 + eps: 1e-8 + betas: [ 0.9, 0.999 ] + weight_decay: 0.01 + sched: + name: WarmupPolicy + warmup_steps: 10000 + warmup_ratio: null \ No newline at end of file diff --git a/examples/multimodal/text_to_image/imagen/conf/sr256-400m-edm.yaml b/examples/multimodal/text_to_image/imagen/conf/sr256-400m-edm.yaml new file mode 100644 index 000000000000..22ab0672e577 --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/conf/sr256-400m-edm.yaml @@ -0,0 +1,222 @@ +name: imagen-nemo # The name of your model +allow_tf32: True + +trainer: + devices: 1 # number of GPUs (0 for CPU), or list of the GPUs to use e.g. [0, 1] + num_nodes: 1 + max_epochs: -1 + max_steps: 2500000 # precedence over max_epochs + logger: False # Provided by exp_manager + precision: bf16 # Should be set to 16 for O1 and O2 to enable the AMP. + accelerator: gpu + log_every_n_steps: 5 # Interval of logging. + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + num_sanity_val_steps: 10 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + enable_checkpointing: False # Provided by exp_manager + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + + +exp_manager: + exp_dir: /train/imagen-256 # exp_dir for your experiment, if None, defaults to "./nemo_experiments" + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: # Whether you want exp_manger to create a Wandb logger + name: imagen-sr256-nf128 + project: imagen + group: nemo-imagen + resume: True + create_tensorboard_logger: True # Whether you want exp_manger to create a tb logger + create_checkpoint_callback: True # Whether you want exp_manager to create a modelcheckpoint callback + checkpoint_callback_params: + monitor: reduced_train_loss + save_top_k: 5 + every_n_epochs: 0 # Save checkpoint frequency. + every_n_train_steps: 1000 # Mutually exclusive with every_n_epochs. It is recommended to set this if training on large-scale dataset. + filename: '${name}--{reduced_train_loss:.2f}-{step}-{consumed_samples}' + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + ema: + enable: True + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + + +model: + + precision: ${trainer.precision} + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 16 # limited by GPU memory + global_batch_size: 16 # will use more micro batches to reach global batch size + inductor: False + inductor_cudagraphs: False + + unet_type: sr-unet + unet: + embed_dim: 128 + image_size: 256 + channels: 3 + num_res_blocks: [2, 2, 3, 4, 3] + channel_mult: [ 1, 2, 4, 6, 6 ] + num_attn_heads: 4 + per_head_channels: 64 + cond_dim: 512 + attention_type: fused + feature_pooling_type: attention + learned_sinu_pos_emb_dim: 0 + attention_resolutions: [32, 16] + dropout: False + use_null_token: False + init_conv_kernel_size: 3 + gradient_checkpointing: False + scale_shift_norm: True + stable_attention: False + flash_attention: True + resblock_updown: False + resample_with_conv: True + low_res_cond: True + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + preconditioning_type: EDM + preconditioning: + loss_type: l2 + sigma_data: 0.5 + p_mean: -1.2 + p_std: 1.2 + + conditioning: + embed_dim: 1024 + token_length: 128 + drop_rate: 0.1 + precached_key: embeddings_t5_xxl + out_key: t5_text + + data: + num_workers: 16 + train: + dataset_path: + - datasets/laion_aesthetic/wdinfo-selene.pkl # 48,874,000 + - datasets/coyo-700m/wdinfo-selene.pkl # 627,172,000 + # - datasets/improved-aesthetic/wdinfo-selene.pkl + augmentations: + resize_smallest_side: 256 + center_crop_h_w: 256, 256 + horizontal_flip: False + filterings: + resolution: + method: larger + value: 256 + estimated_portion: 0.8 # Estimated % of examples left after filtering. This is use to estimate # epoch + corruption_aug: + target_resolution: [ 64, 256 ] + kernel_radius_dict: # used for blurring & resizing, otherwise, not necessary. + 8: 1 + 16: 2 + 32: 3 + 64: 6 + 128: 11 + 256: 22 + 512: 44 + 1024: 88 + 2048: 176 + 4096: 352 + + blur: + add_random_blur: True + blur_prob1: 0.2 + blur_prob2: 0.2 + + blur_sigma_dict: + 8: 0.25 + 16: 0.5 + 32: 0.75 + 64: 1.5 + 128: 3 + 256: 6 + 512: 12 + 1024: 24 + 2048: 48 + 4096: 96 + + resize: + add_random_resize: True + + resize_prob1: + up: 0.2 + down: 0.2 + keep: 0.6 + resize_prob2: + up: 0.2 + down: 0.2 + keep: 0.6 + + resize_range1: + - 0.8 + - 1.2 + resize_range2: + - 0.8 + - 1.2 + + noise: + add_random_noise: True + gaussian_noise_prob1: 1.0 # 0.5 + gaussian_noise_prob2: 1.0 # 0.5 + gray_noise_prob1: 0.0 # 0.4 + gray_noise_prob2: 0.0 # 0.4 + + gaussian_sigma_range1: + - 0 + - 3 + gaussian_sigma_range2: + - 0 + - 2.5 + + poisson_scale_range1: + - 0.005 + - 3 + poisson_scale_range2: + - 0.005 + - 2.5 + + jpeg: + add_random_compression: False + jpeg_range1: + - 75 + - 95 + jpeg_range2: + - 75 + - 95 + + webdataset: + use_webdataset: True + object_store: False + infinite_sampler: True + local_root_path: /datasets + verbose: False + pbss_checkpoint_saving: + enable: False + pbss_credentials_file: pbss_credentials_joc.secret + save_frequency: 1000 + + optim: + # We need weight decay for large-scale odel + name: fused_adam + lr: 0.0001 + eps: 1e-8 + betas: [ 0.9, 0.999 ] + weight_decay: 0.01 + sched: + name: WarmupPolicy + warmup_steps: 10000 + warmup_ratio: null \ No newline at end of file diff --git a/examples/multimodal/text_to_image/imagen/conf/sr256-400m.yaml b/examples/multimodal/text_to_image/imagen/conf/sr256-400m.yaml new file mode 100644 index 000000000000..984bddda2c55 --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/conf/sr256-400m.yaml @@ -0,0 +1,150 @@ +name: imagen-nemo # The name of your model +allow_tf32: True + +trainer: + devices: 1 # number of GPUs (0 for CPU), or list of the GPUs to use e.g. [0, 1] + num_nodes: 1 + max_epochs: -1 + max_steps: 2500000 # precedence over max_epochs + logger: False # Provided by exp_manager + precision: bf16 # Should be set to 16 for O1 and O2 to enable the AMP. + accelerator: gpu + log_every_n_steps: 5 # Interval of logging. + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + num_sanity_val_steps: 10 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + enable_checkpointing: False # Provided by exp_manager + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + + +exp_manager: + exp_dir: /train/imagen-256 # exp_dir for your experiment, if None, defaults to "./nemo_experiments" + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: # Whether you want exp_manger to create a Wandb logger + name: imagen-sr256-nf128 + project: imagen + group: nemo-imagen + resume: True + create_tensorboard_logger: True # Whether you want exp_manger to create a tb logger + create_checkpoint_callback: True # Whether you want exp_manager to create a modelcheckpoint callback + checkpoint_callback_params: + monitor: reduced_train_loss + save_top_k: 5 + every_n_epochs: 0 # Save checkpoint frequency. + every_n_train_steps: 1000 # Mutually exclusive with every_n_epochs. It is recommended to set this if training on large-scale dataset. + filename: '${name}--{reduced_train_loss:.2f}-{step}-{consumed_samples}' + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + ema: + enable: True + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + + +model: + + precision: ${trainer.precision} + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 16 # limited by GPU memory + global_batch_size: 16 # will use more micro batches to reach global batch size + inductor: True + inductor_cudagraphs: False + channels_last: True + + unet_type: sr-unet + unet: + embed_dim: 128 + image_size: 256 + channels: 3 + num_res_blocks: [2, 2, 3, 4, 3] + channel_mult: [ 1, 2, 4, 6, 6 ] + num_attn_heads: 4 + per_head_channels: 64 + cond_dim: 512 + attention_type: fused + feature_pooling_type: attention + learned_sinu_pos_emb_dim: 0 + attention_resolutions: [32, 16] + dropout: False + use_null_token: False + init_conv_kernel_size: 3 + gradient_checkpointing: False + scale_shift_norm: True + stable_attention: False + flash_attention: True + resblock_updown: False + resample_with_conv: True + low_res_cond: True + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + ddp_overlap: True # True for using PyTorch default DDP overlap. False for using Megatron's default configuration for async grad allreduce + + noise_cond_aug: True + preconditioning_type: EDM + preconditioning: + loss_type: l2 + sigma_data: 0.5 + p_mean: -1.2 + p_std: 1.2 + # If want to switch to continuous DDPM training, + # use the following config: + # preconditioning_type: DDPM + # preconditioning: + # loss_type: l2 + # pred_objective: noise + # noise_schedule: cosine + # timesteps: 1000 + + conditioning: + embed_dim: 1024 + token_length: 128 + drop_rate: 0.1 + precached_key: embeddings_t5_xxl + out_key: t5_text + + data: + num_workers: 16 + train: + dataset_path: + - datasets/laion_aesthetic/wdinfo-selene.pkl # 48,874,000 + - datasets/coyo-700m/wdinfo-selene.pkl # 627,172,000 + augmentations: + resize_smallest_side: 256 + center_crop_h_w: 256, 256 + horizontal_flip: False + filterings: + resolution: + method: larger + value: 256 + estimated_portion: 0.8 # Estimated % of examples left after filtering. This is use to estimate # epoch + target_resolutions: [ 64, 256 ] + + webdataset: + use_webdataset: True + object_store: False + infinite_sampler: True + local_root_path: /datasets + verbose: False + + optim: + # We need weight decay for large-scale odel + name: fused_adam + lr: 0.0001 + eps: 1e-8 + betas: [ 0.9, 0.999 ] + weight_decay: 0.01 + sched: + name: WarmupPolicy + warmup_steps: 10000 + warmup_ratio: null \ No newline at end of file diff --git a/examples/multimodal/text_to_image/imagen/conf/sr256-450m-edm.yaml b/examples/multimodal/text_to_image/imagen/conf/sr256-450m-edm.yaml new file mode 100644 index 000000000000..cbee92a40a58 --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/conf/sr256-450m-edm.yaml @@ -0,0 +1,222 @@ +name: imagen-nemo # The name of your model +allow_tf32: True + +trainer: + devices: 1 # number of GPUs (0 for CPU), or list of the GPUs to use e.g. [0, 1] + num_nodes: 1 + max_epochs: -1 + max_steps: 2500000 # precedence over max_epochs + logger: False # Provided by exp_manager + precision: bf16 # Should be set to 16 for O1 and O2 to enable the AMP. + accelerator: gpu + log_every_n_steps: 5 # Interval of logging. + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + num_sanity_val_steps: 10 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + enable_checkpointing: False # Provided by exp_manager + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + + +exp_manager: + exp_dir: /train/imagen-256 # exp_dir for your experiment, if None, defaults to "./nemo_experiments" + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: # Whether you want exp_manger to create a Wandb logger + name: imagen-sr256-nf128 + project: imagen + group: nemo-imagen + resume: True + create_tensorboard_logger: True # Whether you want exp_manger to create a tb logger + create_checkpoint_callback: True # Whether you want exp_manager to create a modelcheckpoint callback + checkpoint_callback_params: + monitor: reduced_train_loss + save_top_k: 5 + every_n_epochs: 0 # Save checkpoint frequency. + every_n_train_steps: 1000 # Mutually exclusive with every_n_epochs. It is recommended to set this if training on large-scale dataset. + filename: '${name}--{reduced_train_loss:.2f}-{step}-{consumed_samples}' + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + ema: + enable: True + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + + +model: + + precision: ${trainer.precision} + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 16 # limited by GPU memory + global_batch_size: 16 # will use more micro batches to reach global batch size + inductor: False + inductor_cudagraphs: False + + unet_type: sr-unet + unet: + embed_dim: 128 + image_size: 256 + channels: 3 + num_res_blocks: [2, 2, 3, 4, 3] + channel_mult: [ 1, 2, 4, 6, 6 ] + num_attn_heads: 4 + per_head_channels: 64 + cond_dim: 512 + attention_type: stacked + feature_pooling_type: attention + learned_sinu_pos_emb_dim: 0 + attention_resolutions: [32, 16] + dropout: False + use_null_token: False + init_conv_kernel_size: 3 + gradient_checkpointing: False + scale_shift_norm: True + stable_attention: False + flash_attention: True + resblock_updown: False + resample_with_conv: True + low_res_cond: True + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + preconditioning_type: EDM + preconditioning: + loss_type: l2 + sigma_data: 0.5 + p_mean: -1.2 + p_std: 1.2 + + conditioning: + embed_dim: 1024 + token_length: 128 + drop_rate: 0.1 + precached_key: embeddings_t5_xxl + out_key: t5_text + + data: + num_workers: 16 + train: + dataset_path: + - datasets/laion_aesthetic/wdinfo-selene.pkl # 48,874,000 + - datasets/coyo-700m/wdinfo-selene.pkl # 627,172,000 + # - datasets/improved-aesthetic/wdinfo-selene.pkl + augmentations: + resize_smallest_side: 256 + center_crop_h_w: 256, 256 + horizontal_flip: False + filterings: + resolution: + method: larger + value: 256 + estimated_portion: 0.8 # Estimated % of examples left after filtering. This is use to estimate # epoch + corruption_aug: + target_resolution: [ 64, 256 ] + kernel_radius_dict: # used for blurring & resizing, otherwise, not necessary. + 8: 1 + 16: 2 + 32: 3 + 64: 6 + 128: 11 + 256: 22 + 512: 44 + 1024: 88 + 2048: 176 + 4096: 352 + + blur: + add_random_blur: True + blur_prob1: 0.2 + blur_prob2: 0.2 + + blur_sigma_dict: + 8: 0.25 + 16: 0.5 + 32: 0.75 + 64: 1.5 + 128: 3 + 256: 6 + 512: 12 + 1024: 24 + 2048: 48 + 4096: 96 + + resize: + add_random_resize: True + + resize_prob1: + up: 0.2 + down: 0.2 + keep: 0.6 + resize_prob2: + up: 0.2 + down: 0.2 + keep: 0.6 + + resize_range1: + - 0.8 + - 1.2 + resize_range2: + - 0.8 + - 1.2 + + noise: + add_random_noise: True + gaussian_noise_prob1: 1.0 # 0.5 + gaussian_noise_prob2: 1.0 # 0.5 + gray_noise_prob1: 0.0 # 0.4 + gray_noise_prob2: 0.0 # 0.4 + + gaussian_sigma_range1: + - 0 + - 3 + gaussian_sigma_range2: + - 0 + - 2.5 + + poisson_scale_range1: + - 0.005 + - 3 + poisson_scale_range2: + - 0.005 + - 2.5 + + jpeg: + add_random_compression: False + jpeg_range1: + - 75 + - 95 + jpeg_range2: + - 75 + - 95 + + webdataset: + use_webdataset: True + object_store: False + infinite_sampler: True + local_root_path: /datasets + verbose: False + pbss_checkpoint_saving: + enable: False + pbss_credentials_file: pbss_credentials_joc.secret + save_frequency: 1000 + + optim: + # We need weight decay for large-scale odel + name: fused_adam + lr: 0.0001 + eps: 1e-8 + betas: [ 0.9, 0.999 ] + weight_decay: 0.01 + sched: + name: WarmupPolicy + warmup_steps: 10000 + warmup_ratio: null \ No newline at end of file diff --git a/examples/multimodal/text_to_image/imagen/conf/sr256-600m-edm-noise.yaml b/examples/multimodal/text_to_image/imagen/conf/sr256-600m-edm-noise.yaml new file mode 100644 index 000000000000..3e5318186961 --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/conf/sr256-600m-edm-noise.yaml @@ -0,0 +1,142 @@ +name: imagen-nemo # The name of your model +allow_tf32: True + +trainer: + devices: 1 # number of GPUs (0 for CPU), or list of the GPUs to use e.g. [0, 1] + num_nodes: 1 + max_epochs: -1 + max_steps: 2500000 # precedence over max_epochs + logger: False # Provided by exp_manager + precision: bf16 # Should be set to 16 for O1 and O2 to enable the AMP. + accelerator: gpu + log_every_n_steps: 5 # Interval of logging. + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + num_sanity_val_steps: 10 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + enable_checkpointing: False # Provided by exp_manager + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + + +exp_manager: + exp_dir: /train/imagen-256 # exp_dir for your experiment, if None, defaults to "./nemo_experiments" + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: # Whether you want exp_manger to create a Wandb logger + name: imagen-sr256-nf128 + project: imagen + group: nemo-imagen + resume: True + create_tensorboard_logger: True # Whether you want exp_manger to create a tb logger + create_checkpoint_callback: True # Whether you want exp_manager to create a modelcheckpoint callback + checkpoint_callback_params: + monitor: reduced_train_loss + save_top_k: 5 + every_n_epochs: 0 # Save checkpoint frequency. + every_n_train_steps: 1000 # Mutually exclusive with every_n_epochs. It is recommended to set this if training on large-scale dataset. + filename: '${name}--{reduced_train_loss:.2f}-{step}-{consumed_samples}' + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + ema: + enable: True + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + + +model: + + precision: ${trainer.precision} + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 32 # limited by GPU memory + global_batch_size: 32 # will use more micro batches to reach global batch size + inductor: False + inductor_cudagraphs: False + + unet_type: sr + unet: + embed_dim: 128 + image_size: 256 + channels: 3 + channel_mult: [ 1, 2, 4, 8, 8 ] + num_attn_heads: 8 + per_head_channels: 64 + attention_type: stacked + atnn_enabled_at: [ 0, 0, 0, 1, 1 ] + feature_pooling_type: attention + stride: 2 + num_resblocks: [ 2, 4, 8, 8, 8 ] + learned_sinu_pos_emb_dim: 0 + use_null_token: False + init_conv_kernel_size: 3 + gradient_checkpointing: False + scale_shift_norm: True + stable_attention: False + flash_attention: False + skip_connection_scaling: True + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + noise_cond_aug: True + preconditioning_type: EDM + preconditioning: + loss_type: l2 + sigma_data: 0.5 + p_mean: -1.2 + p_std: 1.2 + + conditioning: + embed_dim: 1024 + token_length: 128 + drop_rate: 0.1 + precached_key: embeddings_t5_xxl + out_key: t5_text + + data: + num_workers: 16 + train: + dataset_path: + - datasets/laion_aesthetic/wdinfo-selene.pkl # 48,874,000 + - datasets/coyo-700m/wdinfo-selene.pkl # 627,172,000 + augmentations: + resize_smallest_side: 256 + center_crop_h_w: 256, 256 + horizontal_flip: False + filterings: + resolution: + method: larger + value: 256 + estimated_portion: 0.8 # Estimated % of examples left after filtering. This is use to estimate # epoch + corruption_aug: + target_resolution: [ 64, 256 ] + + webdataset: + use_webdataset: True + object_store: False + infinite_sampler: True + local_root_path: /datasets + verbose: False + pbss_checkpoint_saving: + enable: False + pbss_credentials_file: pbss_credentials_joc.secret + save_frequency: 1000 + + optim: + # We need weight decay for large-scale odel + name: fused_adam + lr: 0.0001 + eps: 1e-8 + betas: [ 0.9, 0.999 ] + weight_decay: 0.01 + sched: + name: WarmupPolicy + warmup_steps: 10000 + warmup_ratio: null \ No newline at end of file diff --git a/examples/multimodal/text_to_image/imagen/conf/sr256-600m-edm.yaml b/examples/multimodal/text_to_image/imagen/conf/sr256-600m-edm.yaml new file mode 100644 index 000000000000..67f05c52ff6e --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/conf/sr256-600m-edm.yaml @@ -0,0 +1,219 @@ +name: imagen-nemo # The name of your model +allow_tf32: True + +trainer: + devices: 1 # number of GPUs (0 for CPU), or list of the GPUs to use e.g. [0, 1] + num_nodes: 1 + max_epochs: -1 + max_steps: 2500000 # precedence over max_epochs + logger: False # Provided by exp_manager + precision: bf16 # Should be set to 16 for O1 and O2 to enable the AMP. + accelerator: gpu + log_every_n_steps: 5 # Interval of logging. + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + num_sanity_val_steps: 10 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + enable_checkpointing: False # Provided by exp_manager + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + + +exp_manager: + exp_dir: /train/imagen-256 # exp_dir for your experiment, if None, defaults to "./nemo_experiments" + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: # Whether you want exp_manger to create a Wandb logger + name: imagen-sr256-nf128 + project: imagen + group: nemo-imagen + resume: True + create_tensorboard_logger: True # Whether you want exp_manger to create a tb logger + create_checkpoint_callback: True # Whether you want exp_manager to create a modelcheckpoint callback + checkpoint_callback_params: + monitor: reduced_train_loss + save_top_k: 5 + every_n_epochs: 0 # Save checkpoint frequency. + every_n_train_steps: 1000 # Mutually exclusive with every_n_epochs. It is recommended to set this if training on large-scale dataset. + filename: '${name}--{reduced_train_loss:.2f}-{step}-{consumed_samples}' + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + ema: + enable: True + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + + +model: + + precision: ${trainer.precision} + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 32 # limited by GPU memory + global_batch_size: 32 # will use more micro batches to reach global batch size + inductor: False + inductor_cudagraphs: False + + unet_type: sr + unet: + embed_dim: 128 + image_size: 256 + channels: 3 + channel_mult: [ 1, 2, 4, 8, 8 ] + num_attn_heads: 8 + per_head_channels: 64 + attention_type: stacked + atnn_enabled_at: [ 0, 0, 0, 1, 1 ] + feature_pooling_type: attention + stride: 2 + num_resblocks: [ 2, 4, 8, 8, 8 ] + learned_sinu_pos_emb_dim: 0 + use_null_token: False + init_conv_kernel_size: 3 + gradient_checkpointing: False + scale_shift_norm: True + stable_attention: False + flash_attention: False + skip_connection_scaling: True + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + preconditioning_type: EDM + preconditioning: + loss_type: l2 + sigma_data: 0.5 + p_mean: -1.2 + p_std: 1.2 + + conditioning: + embed_dim: 1024 + token_length: 128 + drop_rate: 0.1 + precached_key: embeddings_t5_xxl + out_key: t5_text + + data: + num_workers: 16 + train: + dataset_path: + - datasets/laion_aesthetic/wdinfo-selene.pkl # 48,874,000 + - datasets/coyo-700m/wdinfo-selene.pkl # 627,172,000 + # - datasets/improved-aesthetic/wdinfo-selene.pkl + augmentations: + resize_smallest_side: 256 + center_crop_h_w: 256, 256 + horizontal_flip: False + filterings: + resolution: + method: larger + value: 256 + estimated_portion: 0.8 # Estimated % of examples left after filtering. This is use to estimate # epoch + corruption_aug: + target_resolution: [ 64, 256 ] + kernel_radius_dict: # used for blurring & resizing, otherwise, not necessary. + 8: 1 + 16: 2 + 32: 3 + 64: 6 + 128: 11 + 256: 22 + 512: 44 + 1024: 88 + 2048: 176 + 4096: 352 + + blur: + add_random_blur: True + blur_prob1: 0.2 + blur_prob2: 0.2 + + blur_sigma_dict: + 8: 0.25 + 16: 0.5 + 32: 0.75 + 64: 1.5 + 128: 3 + 256: 6 + 512: 12 + 1024: 24 + 2048: 48 + 4096: 96 + + resize: + add_random_resize: True + + resize_prob1: + up: 0.2 + down: 0.2 + keep: 0.6 + resize_prob2: + up: 0.2 + down: 0.2 + keep: 0.6 + + resize_range1: + - 0.8 + - 1.2 + resize_range2: + - 0.8 + - 1.2 + + noise: + add_random_noise: True + gaussian_noise_prob1: 1.0 # 0.5 + gaussian_noise_prob2: 1.0 # 0.5 + gray_noise_prob1: 0.0 # 0.4 + gray_noise_prob2: 0.0 # 0.4 + + gaussian_sigma_range1: + - 0 + - 3 + gaussian_sigma_range2: + - 0 + - 2.5 + + poisson_scale_range1: + - 0.005 + - 3 + poisson_scale_range2: + - 0.005 + - 2.5 + + jpeg: + add_random_compression: False + jpeg_range1: + - 75 + - 95 + jpeg_range2: + - 75 + - 95 + + webdataset: + use_webdataset: True + object_store: False + infinite_sampler: True + local_root_path: /datasets + verbose: False + pbss_checkpoint_saving: + enable: False + pbss_credentials_file: pbss_credentials_joc.secret + save_frequency: 1000 + + optim: + # We need weight decay for large-scale odel + name: fused_adam + lr: 0.0001 + eps: 1e-8 + betas: [ 0.9, 0.999 ] + weight_decay: 0.01 + sched: + name: WarmupPolicy + warmup_steps: 10000 + warmup_ratio: null \ No newline at end of file diff --git a/examples/multimodal/text_to_image/imagen/conf/sr256-600m.yaml b/examples/multimodal/text_to_image/imagen/conf/sr256-600m.yaml new file mode 100644 index 000000000000..115e9dd3099c --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/conf/sr256-600m.yaml @@ -0,0 +1,146 @@ +name: imagen-nemo # The name of your model +allow_tf32: True + +trainer: + devices: 1 # number of GPUs (0 for CPU), or list of the GPUs to use e.g. [0, 1] + num_nodes: 1 + max_epochs: -1 + max_steps: 2500000 # precedence over max_epochs + logger: False # Provided by exp_manager + precision: bf16 # Should be set to 16 for O1 and O2 to enable the AMP. + accelerator: gpu + log_every_n_steps: 5 # Interval of logging. + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + num_sanity_val_steps: 10 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + enable_checkpointing: False # Provided by exp_manager + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + + +exp_manager: + exp_dir: /train/imagen-256 # exp_dir for your experiment, if None, defaults to "./nemo_experiments" + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: # Whether you want exp_manger to create a Wandb logger + name: imagen-sr256-nf128 + project: imagen + group: nemo-imagen + resume: True + create_tensorboard_logger: True # Whether you want exp_manger to create a tb logger + create_checkpoint_callback: True # Whether you want exp_manager to create a modelcheckpoint callback + checkpoint_callback_params: + monitor: reduced_train_loss + save_top_k: 5 + every_n_epochs: 0 # Save checkpoint frequency. + every_n_train_steps: 1000 # Mutually exclusive with every_n_epochs. It is recommended to set this if training on large-scale dataset. + filename: '${name}--{reduced_train_loss:.2f}-{step}-{consumed_samples}' + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + ema: + enable: True + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + + +model: + + precision: ${trainer.precision} + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 64 # limited by GPU memory + global_batch_size: 64 # will use more micro batches to reach global batch size + inductor: True + inductor_cudagraphs: False + channels_last: True + + unet_type: sr + unet: + embed_dim: 128 + image_size: 256 + channels: 3 + channel_mult: [ 1, 2, 4, 8, 8 ] + num_attn_heads: 8 + per_head_channels: 64 + attention_type: fused + atnn_enabled_at: [ 0, 0, 0, 1, 1 ] + feature_pooling_type: attention + stride: 2 + num_resblocks: [ 2, 4, 8, 8, 8 ] + learned_sinu_pos_emb_dim: 0 + use_null_token: False + init_conv_kernel_size: 3 + gradient_checkpointing: False + scale_shift_norm: True + stable_attention: False + flash_attention: True + skip_connection_scaling: True + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + ddp_overlap: True # True for using PyTorch default DDP overlap. False for using Megatron's default configuration for async grad allreduce + + noise_cond_aug: True + preconditioning_type: EDM + preconditioning: + loss_type: l2 + sigma_data: 0.5 + p_mean: -1.2 + p_std: 1.2 + # If want to switch to continuous DDPM training, + # use the following config: + # preconditioning_type: DDPM + # preconditioning: + # loss_type: l2 + # pred_objective: noise + # noise_schedule: cosine + # timesteps: 1000 + + conditioning: + embed_dim: 1024 + token_length: 128 + drop_rate: 0.1 + precached_key: embeddings_t5_xxl + out_key: t5_text + + data: + num_workers: 16 + train: + dataset_path: + - datasets/laion_aesthetic/wdinfo-selene.pkl # 48,874,000 + - datasets/coyo-700m/wdinfo-selene.pkl # 627,172,000 + augmentations: + resize_smallest_side: 256 + center_crop_h_w: 256, 256 + horizontal_flip: False + filterings: + resolution: + method: larger + value: 256 + estimated_portion: 0.8 # Estimated % of examples left after filtering. This is use to estimate # epoch + target_resolutions: [64, 256] + webdataset: + use_webdataset: True + object_store: False + infinite_sampler: True + local_root_path: /datasets + verbose: False + + optim: + # We need weight decay for large-scale odel + name: fused_adam + lr: 0.0001 + eps: 1e-8 + betas: [ 0.9, 0.999 ] + weight_decay: 0.01 + sched: + name: WarmupPolicy + warmup_steps: 10000 + warmup_ratio: null \ No newline at end of file diff --git a/examples/multimodal/text_to_image/imagen/generate_fid_images.py b/examples/multimodal/text_to_image/imagen/generate_fid_images.py new file mode 100644 index 000000000000..ea743e3e1d06 --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/generate_fid_images.py @@ -0,0 +1,116 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch +from pytorch_lightning import Trainer + +from nemo.collections.multimodal.models.text_to_image.imagen.imagen_pipeline import ImagenPipeline +from nemo.core.config import hydra_runner + + +@hydra_runner(config_path='conf', config_name='imagen_fid_images') +def main(cfg): + # Read configuration parameters + nnodes_per_cfg = cfg.fid.nnodes_per_cfg + ntasks_per_node = cfg.fid.ntasks_per_node + local_task_id = cfg.fid.local_task_id + num_images_to_eval = cfg.fid.num_images_to_eval + path = cfg.fid.coco_captions_path + save_text = cfg.fid.save_text + + node_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0)) + node_id_per_cfg = node_id % nnodes_per_cfg + + current_node_cfg = cfg.fid.classifier_free_guidance[node_id // nnodes_per_cfg] + save_path = os.path.join(cfg.fid.save_path, str(current_node_cfg)) + + # Read and store captions + captions = [] + caption_files = sorted(os.listdir(path)) + assert len(caption_files) >= num_images_to_eval + for file in caption_files[:num_images_to_eval]: + with open(os.path.join(path, file), 'r') as f: + captions += f.readlines() + print(f"The total number of captions to generate is: {len(captions)}") + + # Calculate partition sizes and select the partition for the current node + partition_size_per_node = num_images_to_eval // nnodes_per_cfg + start_idx = node_id_per_cfg * partition_size_per_node + end_idx = (node_id_per_cfg + 1) * partition_size_per_node if node_id_per_cfg != nnodes_per_cfg - 1 else None + captions = captions[start_idx:end_idx] + print(f"Current node {node_id} will generate images from {start_idx} to {end_idx}") + + local_task_id = int(local_task_id) if local_task_id is not None else int(os.environ.get("SLURM_LOCALID", 0)) + partition_size_per_task = int(len(captions) // ntasks_per_node) + + # Select the partition for the current task + start_idx = local_task_id * partition_size_per_task + end_idx = (local_task_id + 1) * partition_size_per_task if local_task_id != ntasks_per_node - 1 else None + input = captions[start_idx:end_idx] + chunk_size = len(input) + + print(f"Current worker {node_id}:{local_task_id} will generate {len(input)} images") + os.makedirs(save_path, exist_ok=True) + + trainer = Trainer() + pipeline = ImagenPipeline.from_pretrained(cfg=cfg.infer, trainer=trainer, megatron_loading=True, megatron_cfg=cfg) + + # Generate images using the model and save them + batch_idx = 0 + batch_size = cfg.fid.ncaptions_per_batch + while True: + if batch_idx * batch_size >= len(input): + break + batch_captions = input[batch_idx * batch_size : (batch_idx + 1) * batch_size] + # Different seed for every image + seeds = [local_task_id * chunk_size + batch_idx * batch_size + idx for idx in range(len(batch_captions))] + with torch.no_grad(): + images, all_res_images, *_ = pipeline( + prompts=batch_captions, seed=seeds, single_batch_mode=True, classifier_free_guidance=current_node_cfg, + ) + + if cfg.fid.save_all_res: + all_res = [f'_RES{model.image_size}' for model in pipeline.models] + outpaths = [] + # for the highest resolution we save as its original name so that + # we can automate the CLIP & FID calculation process from Megatron-Launcher + all_res[-1] = '' + for res in all_res: + outpath = f"{save_path}{res}" + os.makedirs(outpath, exist_ok=True) + outpaths.append(outpath) + for outpath, one_res in zip(outpaths, all_res_images): + for idx, (caption, image) in enumerate(zip(batch_captions, one_res[0])): + image_idx = local_task_id * chunk_size + batch_idx * batch_size + idx + image.save(os.path.join(outpath, f'image{image_idx:06d}.png')) + if save_text: + with open(os.path.join(outpath, f'image{image_idx:06d}.txt'), 'w') as f: + f.writelines(caption) + else: + for idx, (caption, image) in enumerate(zip(batch_captions, images[0])): + image_idx = local_task_id * chunk_size + batch_idx * batch_size + idx + image.save(os.path.join(save_path, f'image{image_idx:06d}.png')) + if save_text: + with open(os.path.join(save_path, f'image{image_idx:06d}.txt'), 'w') as f: + f.writelines(caption) + print( + f'Save {len(images[0])} images to {save_path} with name from image{(local_task_id*chunk_size+batch_idx*batch_size):06d}.png to image{image_idx:06d}.png' + ) + batch_idx += 1 + + +if __name__ == "__main__": + main() diff --git a/examples/multimodal/text_to_image/imagen/imagen_generate_images.py b/examples/multimodal/text_to_image/imagen/imagen_generate_images.py new file mode 100644 index 000000000000..bc002052a989 --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/imagen_generate_images.py @@ -0,0 +1,79 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pickle + +import torch +from omegaconf import OmegaConf +from pytorch_lightning import Trainer + +from nemo.collections.multimodal.models.text_to_image.imagen.imagen_pipeline import ( + ImagenPipeline, + ImagenPipelineConfig, +) +from nemo.core.config import hydra_runner + + +@hydra_runner(config_path='conf', config_name='fid_inference.yaml') +def main(inference_config): + inference_config: ImagenPipelineConfig = OmegaConf.merge(ImagenPipelineConfig(), inference_config) + captions = pickle.load(open('coco_captions.pkl', 'rb')) + ntasks = 8 + if os.environ.get('CUDA_VISIBLE_DEVICES'): + # Multi-GPU + task_id = int(os.environ.get("CUDA_VISIBLE_DEVICES", 0)) + else: + # Single GPU + task_id = 0 + chuncksize = int(len(captions) // ntasks) + if task_id != ntasks - 1: + input = captions[task_id * chuncksize : (task_id + 1) * chuncksize] + else: + input = captions[task_id * chuncksize :] + captions = input + + trainer = Trainer() + pipeline = ImagenPipeline.from_pretrained(cfg=inference_config, trainer=trainer) + batch_size = 16 + batch_idx = 0 + + possible_res = [64, 256] # [64, 256] + outpaths = [] + for res in possible_res: + outpath = f'{inference_config.output_path}_RES{res}' + os.makedirs(outpath, exist_ok=True) + outpaths.append(outpath) + while True: + if batch_idx * batch_size >= len(captions): + break + batch_captions = captions[batch_idx * batch_size : (batch_idx + 1) * batch_size] + + # Different seed for every image + seeds = [task_id * chuncksize + batch_idx * batch_size + idx for idx in range(len(batch_captions))] + seed = batch_idx + chuncksize + + with torch.no_grad(): + images, all_res_images, throughput = pipeline(prompts=batch_captions, seed=seeds, single_batch_mode=True,) + + for outpath, one_res in zip(outpaths, all_res_images): + for idx, (caption, image) in enumerate(zip(batch_captions, one_res[0])): + image.save(os.path.join(outpath, f'image_{task_id*chuncksize+batch_idx*batch_size+idx}.png')) + with open(os.path.join(outpath, f'image_{task_id*chuncksize+batch_idx*batch_size+idx}.txt'), 'w') as f: + f.writelines(caption) + batch_idx += 1 + + +if __name__ == '__main__': + main() diff --git a/examples/multimodal/text_to_image/imagen/imagen_infer.py b/examples/multimodal/text_to_image/imagen/imagen_infer.py new file mode 100644 index 000000000000..0fb291729596 --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/imagen_infer.py @@ -0,0 +1,50 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from omegaconf import OmegaConf +from pytorch_lightning import Trainer + +from nemo.collections.multimodal.models.text_to_image.imagen.imagen_pipeline import ( + ImagenPipeline, + ImagenPipelineConfig, +) +from nemo.core.config import hydra_runner + + +@hydra_runner(config_path='conf', config_name='inference_pipeline.yaml') +def main(inference_config): + if inference_config.get('infer'): + # invoking from launcher + trainer = Trainer(**inference_config.trainer) + inference_config = inference_config.infer + else: + trainer = Trainer() + inference_config: ImagenPipelineConfig = OmegaConf.merge(ImagenPipelineConfig(), inference_config) + pipeline = ImagenPipeline.from_pretrained(cfg=inference_config, trainer=trainer) + + # Texts are passed in the config files + images, all_res, throughput = pipeline() + + # Save images + outpath = inference_config.output_path + os.makedirs(outpath, exist_ok=True) + for text, pils in zip(inference_config.texts, images): + for idx, image in enumerate(pils): + image.save(os.path.join(outpath, f'{text}_{idx}.png')) + + +if __name__ == '__main__': + main() diff --git a/examples/multimodal/text_to_image/imagen/imagen_training.py b/examples/multimodal/text_to_image/imagen/imagen_training.py new file mode 100644 index 000000000000..61e879ebb063 --- /dev/null +++ b/examples/multimodal/text_to_image/imagen/imagen_training.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytorch_lightning as pl +import torch +from omegaconf.omegaconf import OmegaConf, open_dict +from torch._dynamo import disable +from torch._inductor import config as inductor_config + +from nemo.collections.multimodal.models.text_to_image.imagen import MegatronImagen +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path='conf', config_name='base64-500m') +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + trainer = MegatronTrainerBuilder(cfg).create_trainer() + exp_manager(trainer, cfg.exp_manager) + + # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams + with open_dict(cfg): + cfg.model.precision = cfg.trainer.precision + + model = MegatronImagen(cfg.model, trainer) + + if cfg.model.get("inductor", False): + # Temporary hack to get rid of TorchDynamo issue with DDP + # TODO: remove these if https://github.com/pytorch/pytorch/issues/94574 fixed + torch.arange = disable(torch.arange) + torch.ones = disable(torch.ones) + torch.zeros = disable(torch.zeros) + + # TODO: remove this if latest TorchDynamo fixed `t.uniform_(0, 1)` failure + torch.Tensor.uniform_ = disable(torch.Tensor.uniform_) + + # Disable TorchDynamo for unsupported function + pl.core.LightningModule.log = disable(pl.core.LightningModule.log) + + # TorchInductor with CUDA graph can lead to OOM + inductor_config.triton.cudagraphs = cfg.model.inductor_cudagraphs + model.model.model.unet = torch.compile(model.model.model.unet) + + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/examples/multimodal/text_to_image/instruct_pix2pix/conf/sd_edit.yaml b/examples/multimodal/text_to_image/instruct_pix2pix/conf/sd_edit.yaml new file mode 100644 index 000000000000..75eed9d9b6bf --- /dev/null +++ b/examples/multimodal/text_to_image/instruct_pix2pix/conf/sd_edit.yaml @@ -0,0 +1,23 @@ +edit: + resolution: 256 + steps: 100 + input: path/to/input/picture + outpath: path/to/output/folder + prompt: "" + cfg_text: 7.5 + cfg_image: 1.2 + num_images_per_prompt: 8 + combine_images: [ 2, 4 ] # [row, column] + seed: 1234 + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + logger: False # logger provided by exp_manager + precision: 16 # 16, 32, or bf16 + +model: + restore_from_path: null # Path to a trained instruct pix2pix .nemo file + precision: ${trainer.precision} + diff --git a/examples/multimodal/text_to_image/instruct_pix2pix/conf/sd_finetune.yaml b/examples/multimodal/text_to_image/instruct_pix2pix/conf/sd_finetune.yaml new file mode 100644 index 000000000000..34ef1f436cd6 --- /dev/null +++ b/examples/multimodal/text_to_image/instruct_pix2pix/conf/sd_finetune.yaml @@ -0,0 +1,168 @@ +name: instruct-pix2pix-train + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: 16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice, max_steps will be reached first. + max_steps: 10000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 1 + val_check_interval: 100 + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: instruct-pix2pix + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + create_checkpoint_callback: True + create_tensorboard_logger: True + checkpoint_callback_params: + save_top_k: 4 + mode: min + monitor: val/loss + filename: 'instruct-pix2pix--{val/loss:.4f}-{step}-{consumed_samples}' + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + ema: + enable: False + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + +model: + ckpt_path: null # load checkpoint weights from previous stages for fine-tuning + precision: ${trainer.precision} + micro_batch_size: 32 + global_batch_size: 32 # `= micro_batch_size * total_devices` fake global batch size for sampler + + linear_start: 0.00085 + linear_end: 0.012 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: edited + cond_stage_key: edit # txt for cifar, caption for pbss + image_size: 32 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + scale_by_std: False + + ignore_keys: [ ] + parameterization: eps + clip_denoised: True + load_only_unet: False + cosine_s: 8e-3 + given_betas: + original_elbo_weight: 0 + v_posterior: 0 + l_simple_weight: 1 + use_positional_encodings: False + learn_logvar: False + logvar_init: 0 + beta_schedule: linear + loss_type: l2 + concat_mode: True + cond_stage_forward: + text_embedding_dropout_rate: 0 + fused_opt: True + inductor: False + inductor_cudagraphs: False + + unet_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel.UNetModel + from_pretrained: + image_size: 32 # unused + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + - 4 + num_heads: 8 + use_spatial_transformer: true + transformer_depth: 1 + context_dim: 768 + use_checkpoint: False + legacy: False + use_flash_attention: False + + first_stage_config: + _target_: nemo.collections.multimodal.models.stable_diffusion.ldm.autoencoder.AutoencoderKL + from_pretrained: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder + version: openai/clip-vit-large-patch14 + device: cuda + max_length: 77 + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0. + betas: + - 0.9 + - 0.999 + sched: + name: WarmupHoldPolicy + warmup_steps: 100 + hold_steps: 10000000000000 # Incredibly large value to hold the lr as constant + + data: + # Path to instruct-pix2pix dataset must be specified by the user. + # https://github.com/timothybrooks/instruct-pix2pix#generated-dataset + data_path: ??? + num_workers: 2 + dataloader_type: cyclic # cyclic + validation_drop_last: True # Set to false if the last partial validation samples is to be consumed diff --git a/examples/multimodal/text_to_image/instruct_pix2pix/sd_edit_cli.py b/examples/multimodal/text_to_image/instruct_pix2pix/sd_edit_cli.py new file mode 100644 index 000000000000..f33540601848 --- /dev/null +++ b/examples/multimodal/text_to_image/instruct_pix2pix/sd_edit_cli.py @@ -0,0 +1,168 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import math +import os +import random + +import einops +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange, repeat +from omegaconf import OmegaConf, open_dict +from PIL import Image, ImageOps + +from nemo.collections.multimodal.models.text_to_image.instruct_pix2pix.ldm.ddpm_edit import MegatronLatentDiffusionEdit +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.samplers.k_diffusion import ( + DiscreteEpsDDPMDenoiser, + sample_euler_ancestral, +) +from nemo.collections.multimodal.parts.utils import setup_trainer_and_model_for_inference +from nemo.core.config import hydra_runner +from nemo.utils import logging + + +class CFGDenoiser(nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + + def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale): + cfg_z = einops.repeat(z, "b ... -> (n b) ...", n=3) + cfg_sigma = einops.repeat(sigma, "b ... -> (n b) ...", n=3) + cfg_cond = { + "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])], + "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])], + } + out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3) + out = out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond) + return out + + +@hydra_runner(config_path='conf', config_name='sd_edit') +def main(cfg): + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + with open_dict(cfg): + edit_cfg = cfg.pop("edit") + + def model_cfg_modifier(model_cfg): + model_cfg.precision = cfg.trainer.precision + model_cfg.ckpt_path = None + model_cfg.inductor = False + + trainer, megatron_diffusion_model = setup_trainer_and_model_for_inference( + model_provider=MegatronLatentDiffusionEdit, cfg=cfg, model_cfg_modifier=model_cfg_modifier, + ) + + # inference use the latent diffusion part of megatron wrapper + model = megatron_diffusion_model.model + model_wrap = DiscreteEpsDDPMDenoiser(model) + model_wrap_cfg = CFGDenoiser(model_wrap) + null_token = model.get_learned_conditioning([""]) + + seed = random.randint(0, 100000) if edit_cfg.seed is None else edit_cfg.seed + input_image = Image.open(edit_cfg.input).convert("RGB") + width, height = input_image.size + factor = edit_cfg.resolution / max(width, height) + factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height) + width = int((width * factor) // 64) * 64 + height = int((height * factor) // 64) * 64 + input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS) + + if edit_cfg.prompt == "": + input_image.save(edit_cfg.output) + return + + # get autocast_dtype + if trainer.precision in ['bf16', 'bf16-mixed']: + autocast_dtype = torch.bfloat16 + elif trainer.precision in [32, '32', '32-true']: + autocast_dtype = torch.float + elif trainer.precision in [16, '16', '16-mixed']: + autocast_dtype = torch.half + else: + raise ValueError('precision must be in ["32-true", "16-mixed", "bf16-mixed"]') + + num_images_per_prompt = edit_cfg.num_images_per_prompt + with torch.no_grad(), torch.cuda.amp.autocast( + enabled=autocast_dtype in (torch.half, torch.bfloat16), dtype=autocast_dtype, + ): + cond = {} + cond["c_crossattn"] = [ + repeat(model.get_learned_conditioning([edit_cfg.prompt]), "1 ... -> n ...", n=num_images_per_prompt) + ] + input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1 + input_image = rearrange(input_image, "h w c -> 1 c h w").cuda(non_blocking=True) + cond["c_concat"] = [ + repeat(model.encode_first_stage(input_image).mode(), "1 ... -> n ...", n=num_images_per_prompt) + ] + + uncond = {} + uncond["c_crossattn"] = [repeat(null_token, "1 ... -> n ...", n=num_images_per_prompt)] + uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])] + + sigmas = model_wrap.get_sigmas(edit_cfg.steps) + + extra_args = { + "cond": cond, + "uncond": uncond, + "text_cfg_scale": edit_cfg.cfg_text, + "image_cfg_scale": edit_cfg.cfg_image, + } + torch.manual_seed(seed) + z = torch.randn_like(cond["c_concat"][0]) + z = z * sigmas[0] + z = sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args) + x = model.decode_first_stage(z) + x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0) + x = 255.0 * rearrange(x, "n c h w -> n h w c") + + os.makedirs(edit_cfg.outpath, exist_ok=True) + if edit_cfg.get("combine_images") is None: + for idx, image in enumerate(x): + edited_image = Image.fromarray(image.type(torch.uint8).cpu().numpy()) + save_path = os.path.join( + edit_cfg.outpath, + f'{edit_cfg.prompt.replace(" ", "_")}_{edit_cfg.cfg_text}_{edit_cfg.cfg_image}_{seed}_{idx}.jpg', + ) + edited_image.save(save_path) + logging.info(f"Edited image saved to: {save_path}") + else: + row, column = edit_cfg.combine_images + width, height = x.size(2), x.size(1) + total_width, total_height = width * column, height * row + edited_image = Image.new('RGB', (total_width, total_height)) + x_offset = 0 + y_offset = 0 + for idx, image in enumerate(x): + image = Image.fromarray(image.type(torch.uint8).cpu().numpy()) + edited_image.paste(image, (x_offset, y_offset)) + x_offset += image.size[0] + if (idx + 1) % column == 0: + x_offset = 0 + y_offset += height + save_path = os.path.join( + edit_cfg.outpath, + f'{edit_cfg.prompt.replace(" ", "_")}_{edit_cfg.cfg_text}_{edit_cfg.cfg_image}_{seed}_combine.jpg', + ) + edited_image.save(save_path) + logging.info(f"Edited image saved to: {save_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/multimodal/text_to_image/instruct_pix2pix/sd_finetune.py b/examples/multimodal/text_to_image/instruct_pix2pix/sd_finetune.py new file mode 100644 index 000000000000..c7244de6113f --- /dev/null +++ b/examples/multimodal/text_to_image/instruct_pix2pix/sd_finetune.py @@ -0,0 +1,43 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from omegaconf.omegaconf import OmegaConf, open_dict + +from nemo.collections.multimodal.models.text_to_image.instruct_pix2pix.ldm.ddpm_edit import MegatronLatentDiffusionEdit +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf", config_name="sd_finetune") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + trainer = MegatronTrainerBuilder(cfg).create_trainer() + exp_manager(trainer, cfg.exp_manager) + + # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams + with open_dict(cfg): + cfg.model.precision = cfg.trainer.precision + + model = MegatronLatentDiffusionEdit(cfg.model, trainer) + + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd2_train.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd2_train.yaml new file mode 100644 index 000000000000..3cfc822f8462 --- /dev/null +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd2_train.yaml @@ -0,0 +1,192 @@ +name: stable-diffusion2-train + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: 16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice, max_steps will be reached first. + max_steps: 140000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + limit_val_batches: 0 + +exp_manager: + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: stable-diffusion + group: nemo-sd + name: ${name} + resume: True + create_checkpoint_callback: True + create_tensorboard_logger: True + checkpoint_callback_params: + every_n_train_steps: 1000 + every_n_epochs: 0 + monitor: reduced_train_loss + filename: '${name}--{reduced_train_loss:.2f}-{step}-{consumed_samples}' + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + ema: + enable: True + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + + +model: + precision: ${trainer.precision} + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 4 # limited by GPU memory + global_batch_size: 16 # will use more micro batches to reach global batch size + + linear_start: 0.00085 + linear_end: 0.012 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: images + cond_stage_key: captions # txt for cifar, caption for pbss + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn # check + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + scale_by_std: False + ckpt_path: + ignore_keys: [] + parameterization: eps + clip_denoised: True + load_only_unet: False + cosine_s: 8e-3 + given_betas: + original_elbo_weight: 0 + v_posterior: 0 + l_simple_weight: 1 + use_positional_encodings: False + learn_logvar: False + logvar_init: 0 + beta_schedule: linear + loss_type: l2 + + concat_mode: True + cond_stage_forward: + text_embedding_dropout_rate: 0.1 + fused_opt: True + inductor: True + inductor_cudagraphs: False + capture_cudagraph_iters: -1 # -1 to disable + channels_last: True + + unet_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel.UNetModel + from_pretrained: + from_NeMo: #Must be specified when from pretrained is not None, False means loading unet from HF ckpt + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + - 4 + num_head_channels: 64 + use_spatial_transformer: true + use_linear_in_transformer: true + transformer_depth: 1 + context_dim: 1024 + use_checkpoint: False + legacy: False + use_flash_attention: False + + first_stage_config: + _target_: nemo.collections.multimodal.models.stable_diffusion.ldm.autoencoder.AutoencoderKL + from_pretrained: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 #Never used + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenMegatronCLIPEmbedder + restore_from_path: /path/to/clip.nemo + device: cuda + freeze: True + layer: "penultimate" + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0. + betas: + - 0.9 + - 0.999 + sched: + name: WarmupHoldPolicy + warmup_steps: 10000 + hold_steps: 10000000000000 # Incredibly large value to hold the lr as constant + + # Nsys profiling options + nsys_profile: + enabled: False + start_step: 10 # Global batch to start profiling + end_step: 10 # Global batch to end profiling + ranks: [ 0 ] # Global rank IDs to profile + gen_shape: False # Generate model and kernel details including input shapes + + data: + num_workers: 16 + synthetic_data: False # dataset_path and local_root_path can be empty when using synthetic data + synthetic_data_length: 10000 + train: + dataset_path: + - /datasets/coyo/test.pkl + augmentations: + resize_smallest_side: 512 + center_crop_h_w: 512, 512 + horizontal_flip: False + filterings: + + webdataset: + infinite_sampler: False + local_root_path: /datasets/coyo diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_fid_images.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_fid_images.yaml new file mode 100644 index 000000000000..e526bc52d673 --- /dev/null +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_fid_images.yaml @@ -0,0 +1,45 @@ +name: stable-diffusion-train + +fid: + classifier_free_guidance: + - 1.5 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + nnodes_per_cfg: 1 + ntasks_per_node: 8 + local_task_id: null + num_images_to_eval: 30000 + coco_captions_path: /coco2014/coco2014_val_sampled_30k/captions + coco_images_path: /coco2014/coco2014_val/images_256 + save_path: output + +infer: + unconditional_guidance_scale: null + num_images_per_prompt: 1 + height: 512 + width: 512 + down_factor: 8 + inference_steps: 50 + sampler_type: 'PLMS' + eta: 0 + output_type: 'pil' + save_to_file: False # We need to rename and maintain the order of images for clip score calculation, so we will save it outside the inference pipeline + out_path: ${fid.save_path} + seed: 123 + prompts: + +trainer: + devices: ${fid.ntasks_per_node} + num_nodes: 1 + accelerator: gpu + precision: 32 + logger: False # logger provided by exp_manager + +model: + restore_from_path: null + precision: ${trainer.precision} \ No newline at end of file diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_infer.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_infer.yaml new file mode 100644 index 000000000000..dbe384dd2566 --- /dev/null +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_infer.yaml @@ -0,0 +1,31 @@ +name: stable-diffusion-train + +infer: + unconditional_guidance_scale: 7.5 + num_images_per_prompt: 4 + height: 512 + width: 512 + down_factor: 8 + inference_steps: 25 + sampler_type: 'DPM' + eta: 0 + output_type: 'pil' + save_to_file: True + out_path: 'stable-diffusion' + seed: 123 + prompts: + - 'A photo of a Shiba Inu dog with a backpack riding a bike. It is wearing sunglasses and a beach hat.' + - 'A cute corgi lives in a house made out of sushi.' + - 'A high contrast portrait of a very happy fuzzy panda dressed as a chef in a high end kitchen making dough. There is a painting of flowers on the wall behind him.' + - 'A brain riding a rocketship heading towards the moon.' + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: 16 + logger: False # logger provided by exp_manager + +model: + restore_from_path: null + precision: ${trainer.precision} \ No newline at end of file diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml new file mode 100644 index 000000000000..6c07d460670c --- /dev/null +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml @@ -0,0 +1,208 @@ +name: stable-diffusion-train + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: 16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 2 # PTL default. In practice, max_steps will be reached first. + max_steps: -1 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + limit_val_batches: 0 + + +exp_manager: + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: stable-diffusion + group: nemo-sd + name: ${name} + resume: True + create_checkpoint_callback: True + create_tensorboard_logger: True + checkpoint_callback_params: + every_n_train_steps: 1000 + every_n_epochs: 0 + monitor: reduced_train_loss + filename: '${name}--{reduced_train_loss:.2f}-{step}-{consumed_samples}' + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + ema: + enable: True + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + + +model: + precision: ${trainer.precision} + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 1 # limited by GPU memory + global_batch_size: 1 # will use more micro batches to reach global batch size + native_amp_init_scale: 65536.0 # Init scale for grad scaler used at fp16 + + + linear_start: 0.00085 + linear_end: 0.012 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: images + cond_stage_key: captions # txt for cifar, caption for pbss + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn # check + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + scale_by_std: False + ckpt_path: + ignore_keys: [] + parameterization: eps + clip_denoised: True + load_only_unet: False + cosine_s: 8e-3 + given_betas: + original_elbo_weight: 0 + v_posterior: 0 + l_simple_weight: 1 + use_positional_encodings: False + learn_logvar: False + logvar_init: 0 + beta_schedule: linear + loss_type: l2 + + concat_mode: True + cond_stage_forward: + text_embedding_dropout_rate: 0.1 + fused_opt: True + inductor: False + inductor_cudagraphs: False + capture_cudagraph_iters: -1 # -1 to disable + channels_last: True + + unet_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel.UNetModel + from_pretrained: #/ckpts/nemo-v1-2.ckpt + from_NeMo: True #Must be specified when from pretrained is not None, False means loading unet from HF ckpt + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + - 4 + num_heads: 8 + use_spatial_transformer: true + transformer_depth: 1 + context_dim: 768 + use_checkpoint: False + legacy: False + use_flash_attention: True + enable_amp_o2_fp16: True + resblock_gn_groups: 32 + + first_stage_config: + _target_: nemo.collections.multimodal.models.stable_diffusion.ldm.autoencoder.AutoencoderKL + from_pretrained: /ckpts/vae.bin + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 #Never used + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + capture_cudagraph_iters: ${model.capture_cudagraph_iters} + + cond_stage_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenMegatronCLIPEmbedder + restore_from_path: /ckpts/openai.nemo + device: cuda + freeze: True + layer: "last" + # For compatibility of history version that uses HF clip model + # _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder + # version: openai/clip-vit-large-patch14 + # device: cuda + # max_length: 77 + # capture_cudagraph_iters: ${model.capture_cudagraph_iters} + + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + ddp_overlap: True # True for using PyTorch DDP overlap. + + optim: + name: megatron_fused_adam + lr: null + weight_decay: 0. + betas: + - 0.9 + - 0.999 + sched: + name: WarmupHoldPolicy + warmup_steps: 10000 + hold_steps: 10000000000000 # Incredibly large value to hold the lr as constant + capturable: True + master_weights: True + max_norm: ${trainer.gradient_clip_val} + + # Nsys profiling options + nsys_profile: + enabled: False + start_step: 10 # Global batch to start profiling + end_step: 10 # Global batch to end profiling + ranks: [ 0 ] # Global rank IDs to profile + gen_shape: False # Generate model and kernel details including input shapes + + data: + num_workers: 16 + synthetic_data: False # dataset_path and local_root_path can be empty when using synthetic data + synthetic_data_length: 10000 + train: + dataset_path: + - /datasets/coyo/wdinfo.pkl + augmentations: + resize_smallest_side: 512 + center_crop_h_w: 512, 512 + horizontal_flip: False + filterings: + + webdataset: + infinite_sampler: False + local_root_path: /datasets/coyo diff --git a/examples/multimodal/text_to_image/stable_diffusion/generate_fid_images.py b/examples/multimodal/text_to_image/stable_diffusion/generate_fid_images.py new file mode 100644 index 000000000000..d04a1d2b18af --- /dev/null +++ b/examples/multimodal/text_to_image/stable_diffusion/generate_fid_images.py @@ -0,0 +1,97 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch +from omegaconf.omegaconf import open_dict + +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import MegatronLatentDiffusion +from nemo.collections.multimodal.parts.stable_diffusion.pipeline import pipeline +from nemo.collections.multimodal.parts.utils import setup_trainer_and_model_for_inference +from nemo.core.config import hydra_runner + + +@hydra_runner(config_path='conf', config_name='sd_fid_images') +def main(cfg): + # Read configuration parameters + nnodes_per_cfg = cfg.fid.nnodes_per_cfg + ntasks_per_node = cfg.fid.ntasks_per_node + local_task_id = cfg.fid.local_task_id + num_images_to_eval = cfg.fid.num_images_to_eval + path = cfg.fid.coco_captions_path + + node_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", 0)) + node_id_per_cfg = node_id % nnodes_per_cfg + + current_node_cfg = cfg.fid.classifier_free_guidance[node_id // nnodes_per_cfg] + with open_dict(cfg): + cfg.infer.unconditional_guidance_scale = current_node_cfg + save_path = os.path.join(cfg.fid.save_path, str(current_node_cfg)) + + # Read and store captions + captions = [] + caption_files = sorted(os.listdir(path)) + assert len(caption_files) >= num_images_to_eval + for file in caption_files[:num_images_to_eval]: + with open(os.path.join(path, file), 'r') as f: + captions += f.readlines() + + # Calculate partition sizes and select the partition for the current node + partition_size_per_node = num_images_to_eval // nnodes_per_cfg + start_idx = node_id_per_cfg * partition_size_per_node + end_idx = (node_id_per_cfg + 1) * partition_size_per_node if node_id_per_cfg != nnodes_per_cfg - 1 else None + captions = captions[start_idx:end_idx] + + local_task_id = int(local_task_id) if local_task_id is not None else int(os.environ.get("SLURM_LOCALID", 0)) + partition_size_per_task = int(len(captions) // ntasks_per_node) + + # Select the partition for the current task + start_idx = local_task_id * partition_size_per_task + end_idx = (local_task_id + 1) * partition_size_per_task if local_task_id != ntasks_per_node - 1 else None + input = captions[start_idx:end_idx] + + print(f"Current worker {node_id}:{local_task_id} will generate {len(input)} images") + + os.makedirs(save_path, exist_ok=True) + + # Modify the model configuration + def model_cfg_modifier(model_cfg): + model_cfg.precision = cfg.trainer.precision + model_cfg.ckpt_path = None + model_cfg.inductor = False + model_cfg.unet_config.use_flash_attention = False + model_cfg.unet_config.from_pretrained = None + model_cfg.first_stage_config.from_pretrained = None + model_cfg.global_batch_size = model_cfg.micro_batch_size * ntasks_per_node + + # Set up the trainer and model for inference + trainer, megatron_diffusion_model = setup_trainer_and_model_for_inference( + model_provider=MegatronLatentDiffusion, cfg=cfg, model_cfg_modifier=model_cfg_modifier + ) + model = megatron_diffusion_model.model + model.cuda().eval() + + # Generate images using the model and save them + for i, prompt in enumerate(input): + cfg.infer.prompts = [prompt] + rng = torch.Generator().manual_seed(cfg.infer.seed + local_task_id * 10 + node_id_per_cfg * 100 + i * 1000) + output = pipeline(model, cfg, rng=rng) + for image in output[0]: + image_num = i + partition_size_per_node * node_id_per_cfg + partition_size_per_task * local_task_id + image.save(os.path.join(save_path, f'image{image_num:06d}.png')) + + +if __name__ == "__main__": + main() diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_infer.py b/examples/multimodal/text_to_image/stable_diffusion/sd_infer.py new file mode 100644 index 000000000000..f1e5e2872ea7 --- /dev/null +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_infer.py @@ -0,0 +1,44 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import MegatronLatentDiffusion +from nemo.collections.multimodal.parts.stable_diffusion.pipeline import pipeline +from nemo.collections.multimodal.parts.utils import setup_trainer_and_model_for_inference +from nemo.core.config import hydra_runner + + +@hydra_runner(config_path='conf', config_name='sd_infer') +def main(cfg): + def model_cfg_modifier(model_cfg): + model_cfg.precision = cfg.trainer.precision + model_cfg.ckpt_path = None + model_cfg.inductor = False + model_cfg.unet_config.use_flash_attention = False + model_cfg.unet_config.from_pretrained = None + model_cfg.first_stage_config.from_pretrained = None + + torch.backends.cuda.matmul.allow_tf32 = True + trainer, megatron_diffusion_model = setup_trainer_and_model_for_inference( + model_provider=MegatronLatentDiffusion, cfg=cfg, model_cfg_modifier=model_cfg_modifier + ) + model = megatron_diffusion_model.model + model.cuda().eval() + + rng = torch.Generator().manual_seed(cfg.infer.seed) + pipeline(model, cfg, rng=rng) + + +if __name__ == "__main__": + main() diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_train.py b/examples/multimodal/text_to_image/stable_diffusion/sd_train.py new file mode 100644 index 000000000000..9259b4960734 --- /dev/null +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_train.py @@ -0,0 +1,85 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch +from omegaconf.omegaconf import OmegaConf + +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import MegatronLatentDiffusion +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +class MegatronStableDiffusionTrainerBuilder(MegatronTrainerBuilder): + """Builder for SD model Trainer with overrides.""" + + def _training_strategy(self) -> NLPDDPStrategy: + """ + Returns a ddp strategy passed to Trainer.strategy. + """ + ddp_overlap = self.cfg.model.get('ddp_overlap', True) + if ddp_overlap: + return NLPDDPStrategy( + no_ddp_communication_hook=False, + gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view, + find_unused_parameters=True, + bucket_cap_mb=256, + ) + else: + return NLPDDPStrategy( + no_ddp_communication_hook=True, + gradient_as_bucket_view=self.cfg.model.gradient_as_bucket_view, + find_unused_parameters=False, + ) + + +@hydra_runner(config_path='conf', config_name='sd_train') +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + torch.backends.cuda.matmul.allow_tf32 = True + + if cfg.model.capture_cudagraph_iters >= 0: + # Required by CUDA graph with DDP + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" + + # Hack to avoid CUDA graph issue with AMP, PyTorch Lightning doesn't support + # changing autocast arguments for now. + # https://github.com/pytorch/pytorch/blob/v1.13.1/torch/cuda/graphs.py#L234 + def amp_autocast_init(self, *args, **kwargs): + if "cache_enabled" not in kwargs: + kwargs["cache_enabled"] = False + return self.__orig_init__(*args, **kwargs) + + torch.cuda.amp.autocast.__orig_init__ = torch.cuda.amp.autocast.__init__ + torch.cuda.amp.autocast.__init__ = amp_autocast_init + torch.autocast.__orig_init__ = torch.autocast.__init__ + torch.autocast.__init__ = amp_autocast_init + + trainer = MegatronStableDiffusionTrainerBuilder(cfg).create_trainer() + + exp_manager(trainer, cfg.exp_manager) + + model = MegatronLatentDiffusion(cfg.model, trainer) + + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py b/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py index f487667922f5..4ac99a951f0d 100644 --- a/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py +++ b/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py @@ -38,7 +38,7 @@ from pytorch_lightning.trainer.trainer import Trainer from transformers import CLIPModel -from nemo.collections.multimodal.models.vision_language_foundation.clip import MegatronCLIPModel +from nemo.collections.multimodal.models.vision_language_foundation.clip.megatron_clip_models import MegatronCLIPModel from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector from nemo.utils import AppState, logging from nemo.utils.distributed import initialize_distributed diff --git a/examples/multimodal/vision_language_foundation/clip/megatron_clip_imagenet_zeroshot.py b/examples/multimodal/vision_language_foundation/clip/megatron_clip_imagenet_zeroshot.py index bb1d659dc3c3..ae481cf13545 100644 --- a/examples/multimodal/vision_language_foundation/clip/megatron_clip_imagenet_zeroshot.py +++ b/examples/multimodal/vision_language_foundation/clip/megatron_clip_imagenet_zeroshot.py @@ -18,7 +18,7 @@ from tqdm import tqdm from nemo.collections.multimodal.data.clip.clip_dataset import build_imagenet_validation_dataloader -from nemo.collections.multimodal.models.vision_language_foundation.clip import MegatronCLIPModel +from nemo.collections.multimodal.models.vision_language_foundation.clip.megatron_clip_models import MegatronCLIPModel from nemo.collections.multimodal.parts.utils import setup_trainer_and_model_for_inference from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision diff --git a/examples/multimodal/vision_language_foundation/clip/megatron_clip_infer.py b/examples/multimodal/vision_language_foundation/clip/megatron_clip_infer.py index 5bc7655cfd0c..d77802e5a010 100644 --- a/examples/multimodal/vision_language_foundation/clip/megatron_clip_infer.py +++ b/examples/multimodal/vision_language_foundation/clip/megatron_clip_infer.py @@ -17,7 +17,7 @@ from PIL import Image from nemo.collections.multimodal.data.clip.clip_dataset import get_preprocess_fns -from nemo.collections.multimodal.models.vision_language_foundation.clip import MegatronCLIPModel +from nemo.collections.multimodal.models.vision_language_foundation.clip.megatron_clip_models import MegatronCLIPModel from nemo.collections.multimodal.parts.utils import setup_trainer_and_model_for_inference from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision from nemo.core.config import hydra_runner diff --git a/examples/multimodal/vision_language_foundation/clip/megatron_clip_pretrain.py b/examples/multimodal/vision_language_foundation/clip/megatron_clip_pretrain.py index 605469abd2b3..cc2f13df8d0f 100644 --- a/examples/multimodal/vision_language_foundation/clip/megatron_clip_pretrain.py +++ b/examples/multimodal/vision_language_foundation/clip/megatron_clip_pretrain.py @@ -15,7 +15,7 @@ from omegaconf.omegaconf import OmegaConf -from nemo.collections.multimodal.models.vision_language_foundation.clip import MegatronCLIPModel +from nemo.collections.multimodal.models.vision_language_foundation.clip.megatron_clip_models import MegatronCLIPModel from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder from nemo.core.config import hydra_runner from nemo.utils import logging diff --git a/examples/multimodal/x_to_nerf/benchmark_callback.py b/examples/multimodal/x_to_nerf/benchmark_callback.py new file mode 100644 index 000000000000..fd7d5afdc5bc --- /dev/null +++ b/examples/multimodal/x_to_nerf/benchmark_callback.py @@ -0,0 +1,96 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from typing import Optional + +from pytorch_lightning import Callback, LightningModule, Trainer + +from nemo.utils import logging + + +class BenchmarkCallback(Callback): + def __init__( + self, + start_benchmark_at_step: int = 0, + stop_benchmark_at_step: Optional[int] = None, + log_every_n_steps: int = 10, + ): + super().__init__() + self.start_benchmark_at_step = start_benchmark_at_step + self.stop_benchmark_at_step = stop_benchmark_at_step + self.log_every_n_steps = log_every_n_steps + self.train_times = [] + self.val_times = [] + self.train_steps_times = [] + self.val_steps_times = [] + + def should_benchmark(self, trainer: Trainer): + if self.stop_benchmark_at_step is None: + return trainer.global_step >= self.start_benchmark_at_step + return self.start_benchmark_at_step <= trainer.global_step <= self.stop_benchmark_at_step + + def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule): + self.epoch_start_time = time.time() + + def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule): + if self.should_benchmark(trainer): + epoch_time = time.time() - self.epoch_start_time + self.train_times.append(epoch_time) + logging.info(f'Training-Epoch-{trainer.current_epoch}-Time: {epoch_time} [sec]') + + def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch, batch_idx: int): + self.step_start_time = time.time() + + def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx: int): + if self.should_benchmark(trainer): + step_time = time.time() - self.step_start_time + self.train_steps_times.append(step_time) + if trainer.global_step % self.log_every_n_steps == 0: + logging.info(f'Training-Step-{trainer.global_step}-Time: {step_time} [sec]') + + def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule): + self.val_start_time = time.time() + + def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule): + if self.should_benchmark(trainer): + val_time = time.time() - self.val_start_time + self.val_times.append(val_time) + logging.info(f'Validation-Epoch-{trainer.current_epoch}-Time: {val_time} [sec]') + + def on_validation_batch_start( + self, trainer: Trainer, pl_module: LightningModule, batch, batch_idx: int, dataloader_idx: int + ): + self.val_step_start_time = time.time() + + def on_validation_batch_end( + self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx: int, dataloader_idx: int + ): + if self.should_benchmark(trainer): + val_step_time = time.time() - self.val_step_start_time + self.val_steps_times.append(val_step_time) + if trainer.global_step % self.log_every_n_steps == 0: + logging.info(f'Validation-Step-{trainer.global_step}-Time: {val_step_time} [sec]') + + def on_fit_end(self, trainer: Trainer, pl_module: LightningModule): + if self.should_benchmark(trainer): + avg_train_time = sum(self.train_times) / len(self.train_times) + avg_val_time = sum(self.val_times) / len(self.val_times) + avg_train_step_time = sum(self.train_steps_times) / len(self.train_steps_times) + avg_val_step_time = sum(self.val_steps_times) / len(self.val_steps_times) + + logging.info(f'Average-Training-Epoch-Time: {avg_train_time} [sec]') + logging.info(f'Average-Validation-Epoch-Time: {avg_val_time} [sec]') + logging.info(f'Average-Training-Step-Time: {avg_train_step_time} [sec]') + logging.info(f'Average-Validation-Step-Time: {avg_val_step_time} [sec]') diff --git a/examples/multimodal/x_to_nerf/config/config.yaml b/examples/multimodal/x_to_nerf/config/config.yaml new file mode 100644 index 000000000000..1adcbae72c26 --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/config.yaml @@ -0,0 +1,52 @@ +defaults: + - model: dreamfusion + - _self_ + +name: DreamFusion +seed: 2023 +mode: fit # fit, validate, test, export-mesh + +# export-mesh options +mesh_fname: /results/mesh.obj # mesh file name when mode=export-mesh +mesh_resolution: 128 # Mesh resolution when mode=export-mesh + +# benchmark options +enable_benchmark: False +benchmark_callback: + _target_: benchmark_callback.BenchmarkCallback + log_every_n_steps: 1 + +trainer: + devices: 1 + num_nodes: 1 + precision: 16 + max_steps: 10000 # example configs: dreamfuions=10000, dmtet=5000 + accelerator: gpu + enable_checkpointing: False + logger: False + log_every_n_steps: 1 + val_check_interval: 100 + accumulate_grad_batches: 1 + benchmark: False + enable_model_summary: True + +exp_manager: + name: ${name} + exp_dir: /results + create_tensorboard_logger: False + create_wandb_logger: False + wandb_logger_kwargs: + project: dreamfusion + group: nemo-df + name: ${name} + resume: True + create_checkpoint_callback: True + checkpoint_callback_params: + every_n_epochs: 0 + every_n_train_steps: 1000 # TODO(ahmadki): being ignored ? + monitor: loss + filename: '${name}-{step}' + save_top_k: -1 + always_save_nemo: False + resume_if_exists: True + resume_ignore_no_checkpoint: True diff --git a/examples/multimodal/x_to_nerf/config/model/background/random.yaml b/examples/multimodal/x_to_nerf/config/model/background/random.yaml new file mode 100644 index 000000000000..9cfb09fc6eca --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/background/random.yaml @@ -0,0 +1,3 @@ +_target_: nemo.collections.multimodal.modules.nerf.background.random_background.RandomBackground +base_background: [1, 1, 1] +random_ratio: 0.5 diff --git a/examples/multimodal/x_to_nerf/config/model/background/static.yaml b/examples/multimodal/x_to_nerf/config/model/background/static.yaml new file mode 100644 index 000000000000..eb82f9944991 --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/background/static.yaml @@ -0,0 +1,2 @@ +_target_: nemo.collections.multimodal.modules.nerf.background.static_background.StaticBackground +background: [0, 0, 1] # rgb diff --git a/examples/multimodal/x_to_nerf/config/model/background/tcnn.yaml b/examples/multimodal/x_to_nerf/config/model/background/tcnn.yaml new file mode 100644 index 000000000000..8daf7bcd8349 --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/background/tcnn.yaml @@ -0,0 +1,19 @@ +_target_: nemo.collections.multimodal.modules.nerf.background.tcnn_background.TCNNBackground +bound: 1 +encoder_num_input_dims: 3 # 3 directions +encoder_cfg: + otype: "HashGrid" + n_levels: 16 + n_features_per_level: 2 + log2_hashmap_size: 19 + base_resolution: 16 + interpolation: "Smoothstep" + per_level_scale: # default is np.exp2(np.log2(2048 * bound / 16) / (16 - 1)) + +background_net_num_output_dims: 3 # rgb +background_net_cfg: + otype: "FullyFusedMLP" + activation: "ReLU" + output_activation: "None" + n_neurons: 32 + n_hidden_layers: 2 diff --git a/examples/multimodal/x_to_nerf/config/model/background/torchngp.yaml b/examples/multimodal/x_to_nerf/config/model/background/torchngp.yaml new file mode 100644 index 000000000000..b77778099e79 --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/background/torchngp.yaml @@ -0,0 +1,11 @@ +_target_: nemo.collections.multimodal.modules.nerf.background.torchngp_background.TorchNGPBackground + +encoder_type: "frequency" +encoder_input_dims: 3 +encoder_multi_res: 6 + +num_output_dims: 3 +net_cfg: + num_hidden_dims: 32 + num_layers: 2 + bias: True diff --git a/examples/multimodal/x_to_nerf/config/model/data/data.yaml b/examples/multimodal/x_to_nerf/config/model/data/data.yaml new file mode 100644 index 000000000000..0b5f88b9f1fb --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/data/data.yaml @@ -0,0 +1,41 @@ +_target_: data.AggregatorDataModule + +train_batch_size: 1 +train_shuffle: false +train_dataset: + _target_: nemo.collections.multimodal.data.nerf.random_poses.RandomPosesDataset + internal_batch_size: 100 + width: 64 + height: 64 + radius_range: [3.0, 3.5] + theta_range: [45, 105] + phi_range: [-180, 180] + fovx_range: [10, 30] + fovy_range: [10, 30] + jitter: False + jitter_center: 0.2 + jitter_target: 0.2 + jitter_up: 0.02 + uniform_sphere_rate: 0 + angle_overhead: 30 + angle_front: 60 + +val_batch_size: 1 +val_shuffle: false +val_dataset: + _target_: nemo.collections.multimodal.data.nerf.circle_poses.CirclePosesDataset + size: 5 + width: 800 + height: 800 + angle_overhead: 30 + angle_front: 60 + +test_batch_size: 1 +test_shuffle: false +test_dataset: + _target_: nemo.collections.multimodal.data.nerf.circle_poses.CirclePosesDataset + size: 100 + width: 800 + height: 800 + angle_overhead: 30 + angle_front: 60 diff --git a/examples/multimodal/x_to_nerf/config/model/dreamfusion-dmtet.yaml b/examples/multimodal/x_to_nerf/config/model/dreamfusion-dmtet.yaml new file mode 100644 index 000000000000..bfadd4f426b3 --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/dreamfusion-dmtet.yaml @@ -0,0 +1,40 @@ +_target_: nemo.collections.multimodal.models.nerf.dreamfusion.DreamFusion # TODO(ahmadki): dreamfusion-dmetet should have it's own class +defaults: + - nerf: torchngp + - background: torchngp + - material: basic_shading + - renderer: nvdiffrast + - guidance: sd_huggingface + - optim: adan + - loss: dmtet + - data: data + - _self_ + +### model options +resume_from_checkpoint: +prompt: 'a hamburger' +negative_prompt: '' +front_prompt: ', front view' +side_prompt: ', side view' +back_prompt: ', back view' +update_extra_interval: 16 +guidance_scale: 100 +export_video: False + +iters: ${trainer.max_steps} +# TODO(ahmadki): move to database +latent_iter_ratio: 0.0 +albedo_iter_ratio: 0 +min_ambient_ratio: 0.1 +textureless_ratio: 0.2 + +data: + train_dataset: + width: 512 + height: 512 + val_dataset: + width: 800 + height: 800 + test_dataset: + width: 800 + height: 800 diff --git a/examples/multimodal/x_to_nerf/config/model/dreamfusion.yaml b/examples/multimodal/x_to_nerf/config/model/dreamfusion.yaml new file mode 100644 index 000000000000..a67393341b53 --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/dreamfusion.yaml @@ -0,0 +1,40 @@ +_target_: nemo.collections.multimodal.models.nerf.dreamfusion.DreamFusion +defaults: + - nerf: torchngp + - background: static + - material: basic_shading + - renderer: torchngp_raymarching + - guidance: sd_huggingface + - optim: adan + - loss: dreamfusion + - data: data + - _self_ + +### model options +resume_from_checkpoint: +prompt: 'a hamburger' +negative_prompt: '' +front_prompt: ', front view' +side_prompt: ', side view' +back_prompt: ', back view' +update_extra_interval: 16 +guidance_scale: 100 +export_video: False + +iters: ${trainer.max_steps} +# TODO(ahmadki): move to database +latent_iter_ratio: 0.2 +albedo_iter_ratio: 0.0 +min_ambient_ratio: 0.1 +textureless_ratio: 0.2 + +data: + train_dataset: + width: 64 + height: 64 + val_dataset: + width: 800 + height: 800 + test_dataset: + width: 800 + height: 800 diff --git a/examples/multimodal/x_to_nerf/config/model/guidance/sd_huggingface.yaml b/examples/multimodal/x_to_nerf/config/model/guidance/sd_huggingface.yaml new file mode 100644 index 000000000000..a8b7adca3c55 --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/guidance/sd_huggingface.yaml @@ -0,0 +1,4 @@ +_target_: nemo.collections.multimodal.modules.nerf.guidance.stablediffusion_huggingface_pipeline.StableDiffusion +precision: ${trainer.precision} +model_key: stabilityai/stable-diffusion-2-1-base +t_range: [0.02, 0.98] diff --git a/examples/multimodal/x_to_nerf/config/model/guidance/sd_nemo.yaml b/examples/multimodal/x_to_nerf/config/model/guidance/sd_nemo.yaml new file mode 100644 index 000000000000..fd4517ec1f7c --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/guidance/sd_nemo.yaml @@ -0,0 +1,4 @@ +_target_: nemo.collections.multimodal.modules.nerf.guidance.stablediffusion_nemo_pipeline.StableDiffusion +checkpoint: /sd_checkpoints/nemo-1.5/sd-1.5.nemo +sampler_type: 'DDIM' +t_range: [0.02, 0.98] diff --git a/examples/multimodal/x_to_nerf/config/model/guidance/sd_trt.yaml b/examples/multimodal/x_to_nerf/config/model/guidance/sd_trt.yaml new file mode 100644 index 000000000000..45c1e2ac8fb5 --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/guidance/sd_trt.yaml @@ -0,0 +1,5 @@ +_target_: nemo.collections.multimodal.modules.nerf.guidance.stablediffusion_trt_pipeline.StableDiffusion +checkpoint: /sd_checkpoints/nemo-1.5/sd-1.5.nemo +plan_dir: /sd_checkpoints/nemo-1.5/plan +sampler_type=: DDIM" +t_range: [0.02, 0.98] diff --git a/examples/multimodal/x_to_nerf/config/model/loss/dmtet.yaml b/examples/multimodal/x_to_nerf/config/model/loss/dmtet.yaml new file mode 100644 index 000000000000..188c1034fc27 --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/loss/dmtet.yaml @@ -0,0 +1,8 @@ +lambda_sds: 1.0 +lambda_opacity: 0.0 +lambda_entropy: 0.0 +lambda_orientation: 0.0 +lambda_2d_normal_smooth: 0.0 +lambda_3d_normal_smooth: 0.0 +lambda_mesh_normal: 0.5 +lambda_mesh_laplacian: 0.5 diff --git a/examples/multimodal/x_to_nerf/config/model/loss/dreamfusion.yaml b/examples/multimodal/x_to_nerf/config/model/loss/dreamfusion.yaml new file mode 100644 index 000000000000..8cfd4b47eb51 --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/loss/dreamfusion.yaml @@ -0,0 +1,8 @@ +lambda_sds: 1.0 +lambda_opacity: 0.0 +lambda_entropy: 1e-3 +lambda_orientation: 1e-2 +lambda_2d_normal_smooth: 0.0 +lambda_3d_normal_smooth: 0.0 +lambda_mesh_normal: 0.0 +lambda_mesh_laplacian: 0.0 diff --git a/examples/multimodal/x_to_nerf/config/model/material/basic_shading.yaml b/examples/multimodal/x_to_nerf/config/model/material/basic_shading.yaml new file mode 100644 index 000000000000..802defad1637 --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/material/basic_shading.yaml @@ -0,0 +1 @@ +_target_: nemo.collections.multimodal.modules.nerf.materials.basic_shading.BasicShading diff --git a/examples/multimodal/x_to_nerf/config/model/nerf/tcnn.yaml b/examples/multimodal/x_to_nerf/config/model/nerf/tcnn.yaml new file mode 100644 index 000000000000..0bf5ed6c5e2f --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/nerf/tcnn.yaml @@ -0,0 +1,32 @@ +_target_: nemo.collections.multimodal.modules.nerf.geometry.tcnn_nerf.TCNNNerf +num_input_dims: 3 # 3D space +bound: 1 +density_activation: softplus # softplus, exp +blob_radius: 0.5 +blob_density: 10 +normal_type: central_finite_difference + +encoder_cfg: + otype: "HashGrid" + n_levels: 16 + n_features_per_level: 2 + log2_hashmap_size: 19 + base_resolution: 16 + interpolation: "Smoothstep" + per_level_scale: # default is np.exp2(np.log2(2048 * bound / 16) / (16 - 1)) + +sigma_net_num_output_dims: 1 # density +sigma_net_cfg: + otype: "FullyFusedMLP" + activation: "ReLU" + output_activation: "None" + n_neurons: 64 + n_hidden_layers: 3 + +features_net_num_output_dims: 3 # rgb +features_net_cfg: + otype: "FullyFusedMLP" + activation: "ReLU" + output_activation: "None" + n_neurons: 64 + n_hidden_layers: 3 diff --git a/examples/multimodal/x_to_nerf/config/model/nerf/torchngp.yaml b/examples/multimodal/x_to_nerf/config/model/nerf/torchngp.yaml new file mode 100644 index 000000000000..48877dcfa871 --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/nerf/torchngp.yaml @@ -0,0 +1,26 @@ +_target_: nemo.collections.multimodal.modules.nerf.geometry.torchngp_nerf.TorchNGPNerf +num_input_dims: 3 # 3D space +bound: 1 +density_activation: exp # softplus, exp +blob_radius: 0.2 +blob_density: 5 +normal_type: central_finite_difference + +encoder_cfg: + encoder_type: 'hashgrid' + encoder_max_level: + log2_hashmap_size: 19 + desired_resolution: 2048 + interpolation: smoothstep + +sigma_net_num_output_dims: 1 # density +sigma_net_cfg: + num_hidden_dims: 64 + num_layers: 3 + bias: True + +features_net_num_output_dims: 3 # rgb +features_net_cfg: + num_hidden_dims: 64 + num_layers: 3 + bias: True diff --git a/examples/multimodal/x_to_nerf/config/model/optim/adan.yaml b/examples/multimodal/x_to_nerf/config/model/optim/adan.yaml new file mode 100644 index 000000000000..885c13fcca8a --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/optim/adan.yaml @@ -0,0 +1,6 @@ +name: adan +lr: 5e-3 +eps: 1e-8 +weight_decay: 2e-5 +max_grad_norm: 5.0 +foreach: False diff --git a/examples/multimodal/x_to_nerf/config/model/renderer/nerfacc.yaml b/examples/multimodal/x_to_nerf/config/model/renderer/nerfacc.yaml new file mode 100644 index 000000000000..73f48a7a0ca9 --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/renderer/nerfacc.yaml @@ -0,0 +1,8 @@ +_target_: nemo.collections.multimodal.modules.nerf.renderers.nerfacc_volume_renderer.NerfaccVolumeBaseRenderer +grid_resolution: 128 +grid_levels: 3 +bound: ${model.nerf.bound} +render_step_size: 1.e-3 +near_plane: 0.2 +cone_angle: 0.004 +alpha_thre: 1.e-2 diff --git a/examples/multimodal/x_to_nerf/config/model/renderer/nvdiffrast.yaml b/examples/multimodal/x_to_nerf/config/model/renderer/nvdiffrast.yaml new file mode 100644 index 000000000000..fefc217f4aec --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/renderer/nvdiffrast.yaml @@ -0,0 +1,6 @@ +_target_: nemo.collections.multimodal.modules.nerf.renderers.nvdiffrast_renderer.NVDiffRastRenderer +bound: ${model.nerf.bound} +grid_resolution: 128 +density_thresh: 10.0 +update_interval: 16 +quartet_file: "/results/tets/128_tets.npz" diff --git a/examples/multimodal/x_to_nerf/config/model/renderer/torchngp_raymarching.yaml b/examples/multimodal/x_to_nerf/config/model/renderer/torchngp_raymarching.yaml new file mode 100644 index 000000000000..5075a5fbc85c --- /dev/null +++ b/examples/multimodal/x_to_nerf/config/model/renderer/torchngp_raymarching.yaml @@ -0,0 +1,7 @@ +_target_: nemo.collections.multimodal.modules.nerf.renderers.torchngp_volume_renderer.TorchNGPVolumeRenderer +bound: ${model.nerf.bound} +update_interval: 16 +grid_resolution: 128 +density_thresh: 10 +max_steps: 1024 +dt_gamma: 0 diff --git a/examples/multimodal/x_to_nerf/data.py b/examples/multimodal/x_to_nerf/data.py new file mode 100644 index 000000000000..fe7c47abc64b --- /dev/null +++ b/examples/multimodal/x_to_nerf/data.py @@ -0,0 +1,86 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytorch_lightning as pl +from omegaconf.omegaconf import DictConfig +from torch.utils.data import DataLoader + + +# TODO(ahmadki): multi-GPU needs more work, we currently don't shard data +# across GPUs, which is OK for trainnig, but needs fixing for validation and testing. +class AggregatorDataModule(pl.LightningDataModule): + def __init__( + self, + train_dataset: DictConfig = None, + train_batch_size: int = 1, + train_shuffle: bool = False, + val_dataset: DictConfig = None, + val_batch_size: int = 1, + val_shuffle: bool = False, + test_dataset: DictConfig = None, + test_batch_size: int = 1, + test_shuffle: bool = False, + ): + super().__init__() + + self.train_dataset = train_dataset + self.train_batch_size = train_batch_size + self.train_shuffle = train_shuffle + self.val_dataset = val_dataset + self.val_batch_size = val_batch_size + self.val_shuffle = val_shuffle + self.test_dataset = test_dataset + self.test_batch_size = test_batch_size + self.test_shuffle = test_shuffle + + # TODO(ahmadki): lazy init + # def setup(self, stage=None) -> None: + # if stage in [None, "fit"]: + # self.train_dataset = instantiate(self.train_dataset) + # if stage in [None, "fit", "validate"]: + # self.val_dataset = instantiate(self.val_dataset) + # if stage in [None, "test", "predict"]: + # self.test_dataset = instantiate(self.test_dataset) + + def train_dataloader(self) -> DataLoader: + loader = DataLoader( + self.train_dataset, + batch_size=self.train_batch_size, + collate_fn=self.train_dataset.collate_fn, + pin_memory=True, + num_workers=4, + ) + return loader + + def val_dataloader(self) -> DataLoader: + loader = DataLoader( + self.val_dataset, + batch_size=self.val_batch_size, + collate_fn=self.val_dataset.collate_fn, + shuffle=self.val_shuffle, + pin_memory=True, + num_workers=0, + ) + return loader + + def test_dataloader(self) -> DataLoader: + loader = DataLoader( + self.test_dataset, + batch_size=self.test_batch_size, + collate_fn=self.test_dataset.collate_fn, + shuffle=self.test_shuffle, + pin_memory=True, + num_workers=0, + ) + return loader diff --git a/examples/multimodal/x_to_nerf/main.py b/examples/multimodal/x_to_nerf/main.py new file mode 100644 index 000000000000..5d7f616a3165 --- /dev/null +++ b/examples/multimodal/x_to_nerf/main.py @@ -0,0 +1,70 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hydra.utils import get_class, instantiate +from omegaconf.omegaconf import DictConfig, OmegaConf +from pytorch_lightning import Trainer, seed_everything + +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path='config', config_name='config') +def main(cfg: DictConfig) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + seed_everything(cfg.seed, workers=True) + + mode = cfg.mode + logging.info(f"{mode=}") + + model = None + model_cls = get_class(cfg.model._target_) + if cfg.model.resume_from_checkpoint is None: + model = model_cls(cfg=cfg.model) + else: + logging.info(f"Loading model from checkpoint: {cfg.model.resume_from_checkpoint}") + model = model_cls.load_from_checkpoint(cfg.model.resume_from_checkpoint, strict=False, cfg=cfg.model) + + if mode == "export-mesh": + mesh = model.mesh(resolution=cfg.mesh_resolution) + mesh.export(cfg.mesh_fname) + return + + # Prepare callbacks + callbacks = [] + if cfg.enable_benchmark: + callbacks.append(instantiate(cfg.benchmark_callback)) + + # Setup trainer + trainer = Trainer(callbacks=callbacks, **cfg.trainer) + exp_manager(trainer, cfg.exp_manager) + + # Setup datamodule + dm = instantiate(cfg.model.data) + + if mode == "fit": + trainer.fit(model, datamodule=dm) + elif mode == "validate": + trainer.validate(model, datamodule=dm) + elif mode == "test": + trainer.test(model, datamodule=dm) + else: + raise ValueError(f"Invalid mode: {mode}") + + +if __name__ == '__main__': + main() diff --git a/nemo/collections/multimodal/data/clip/augmentations/augmentations.py b/nemo/collections/multimodal/data/clip/augmentations/augmentations.py index de728edc41c3..a0a96d39de04 100644 --- a/nemo/collections/multimodal/data/clip/augmentations/augmentations.py +++ b/nemo/collections/multimodal/data/clip/augmentations/augmentations.py @@ -20,16 +20,22 @@ import torch import torch.nn as nn -import torchvision.transforms.functional as F -from torchvision.transforms import ( - CenterCrop, - Compose, - InterpolationMode, - Normalize, - RandomResizedCrop, - Resize, - ToTensor, -) + +try: + import torchvision.transforms.functional as F + from torchvision.transforms import ( + CenterCrop, + Compose, + InterpolationMode, + Normalize, + RandomResizedCrop, + Resize, + ToTensor, + ) + + TORCHVISION_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + TORCHVISION_AVAILABLE = False OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) @@ -52,6 +58,7 @@ def forward(self, img): width, height = img.size scale = self.max_size / float(max(height, width)) if scale != 1.0: + assert TORCHVISION_AVAILABLE, "Torchvision imports failed but they are required." new_size = tuple(round(dim * scale) for dim in (height, width)) img = F.resize(img, new_size, self.interpolation) pad_h = self.max_size - new_size[0] @@ -72,6 +79,7 @@ def image_transform( resize_longest_max: bool = False, fill_color: int = 0, ): + assert TORCHVISION_AVAILABLE, "Torchvision imports failed but they are required." mean = mean or OPENAI_DATASET_MEAN if not isinstance(mean, (list, tuple)): mean = (mean,) * 3 diff --git a/nemo/collections/multimodal/data/controlnet/__init__.py b/nemo/collections/multimodal/data/controlnet/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/data/controlnet/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/data/controlnet/controlnet_dataset.py b/nemo/collections/multimodal/data/controlnet/controlnet_dataset.py new file mode 100644 index 000000000000..3bf7b76709e5 --- /dev/null +++ b/nemo/collections/multimodal/data/controlnet/controlnet_dataset.py @@ -0,0 +1,145 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from nemo.collections.multimodal.data.common.webdataset import WebDatasetCommon +from nemo.collections.multimodal.data.stable_diffusion.augmentation.augmentations import ( + construct_image_augmentations, + identical_transform, +) +from nemo.core.classes import Dataset as NeMoDataset + + +class ControlNetSyntheticDataset(NeMoDataset): + def __init__( + self, + image_H, + image_W, + fake_len=100000, + image_key='images', + txt_key='txt', + control_key='hint', + seq_len=80, + context_dim=768, + ): + super().__init__() + self.fake_len = fake_len + self.H = image_H + self.W = image_W + self.image_key = image_key + self.txt_key = txt_key + self.control_key = control_key + self.seq_len = seq_len + self.context_dim = context_dim + + def __getitem__(self, index): + item = {} + item[self.image_key] = torch.randn(self.H, self.W, 3) + item[self.txt_key] = f'This is meaningless fake text No.{index}' + item[self.control_key] = torch.randn(self.H, self.W, 3) + return item + + def __len__(self): + return self.fake_len + + +def build_train_valid_datasets( + model_cfg, consumed_samples, +): + data_cfg = model_cfg.data + + # This function maps data that are tuples to dictionary. + def tuple_to_dict(inp): + for input in inp: + out_dict = dict() + out_dict['images'] = input[0].permute(1, 2, 0) + out_dict['captions'] = input[1] + out_dict['hint'] = input[2].permute(1, 2, 0) + yield out_dict + + def transform_fn(sample): + + image, text, hint = sample["jpg"], sample["txt"], sample["png"] + # TODO : If no agumentations just return the image ? + img_transform = construct_image_augmentations(data_cfg.train.get("augmentations", None)) + text_transform = identical_transform + return img_transform(image), text_transform(text), img_transform(hint) + + if data_cfg.get('synthetic_data', False): + H, W = data_cfg.train.augmentations.center_crop_h_w.split(',') + train_data = ControlNetSyntheticDataset( + int(H), + int(W), + image_key=model_cfg.first_stage_key, + txt_key=model_cfg.cond_stage_key, + control_key=model_cfg.control_key, + context_dim=model_cfg.unet_config.context_dim, + fake_len=data_cfg.synthetic_data_length, + ) + else: + train_data = WebDatasetCommon( + dataset_cfg=data_cfg, + consumed_samples=consumed_samples, + map_fn=transform_fn, + compose_fn=tuple_to_dict, + is_train=True, + ) + + val_data = None + if data_cfg.get("validation") is not None and data_cfg.validation.get("data_path"): + val_data = WebDatasetCommon( + dataset_cfg=data_cfg, + consumed_samples=consumed_samples, + map_fn=transform_fn, + compose_fn=tuple_to_dict, + is_train=False, + ) + return train_data, val_data + + +def build_train_valid_precached_datasets( + model_cfg, consumed_samples, +): + data_cfg = model_cfg.data + + # This function maps data that are tuples to dictionary. + def tuple_to_dict(inp): + for input in inp: + out_dict = dict() + out_dict[model_cfg.first_stage_key] = torch.tensor(input['autoencoderkl_image']) + out_dict[model_cfg.cond_stage_key] = torch.tensor(input['clip-vit-large-patch14_text']) + yield out_dict + + def transform_fn(sample): + return sample['pickle'] + + train_data = WebDatasetCommon( + dataset_cfg=data_cfg, + consumed_samples=consumed_samples, + map_fn=transform_fn, + compose_fn=tuple_to_dict, + is_train=True, + ) + + val_data = None + if data_cfg.get("validation") is not None and data_cfg.validation.get("data_path"): + val_data = WebDatasetCommon( + dataset_cfg=data_cfg, + consumed_samples=consumed_samples, + map_fn=transform_fn, + compose_fn=tuple_to_dict, + is_train=False, + ) + + return train_data, val_data diff --git a/nemo/collections/multimodal/data/dreambooth/__init__.py b/nemo/collections/multimodal/data/dreambooth/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/data/dreambooth/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/data/dreambooth/dreambooth_dataset.py b/nemo/collections/multimodal/data/dreambooth/dreambooth_dataset.py new file mode 100644 index 000000000000..1c39b1a72216 --- /dev/null +++ b/nemo/collections/multimodal/data/dreambooth/dreambooth_dataset.py @@ -0,0 +1,164 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from pathlib import Path + +import torch +from PIL import Image +from pytorch_lightning.utilities import rank_zero_only +from torch.utils.data import Dataset +from tqdm import tqdm + +try: + from torchvision import transforms + + TORCHVISION_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + TORCHVISION_AVAILABLE = False + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + + :param instance_data_root: required, a directory with images files of the object + :param instance_prompt: captions with special token associated with instance images + :param with_prior_preservation: whether to regularize the model finetuning with the original inference output from the backbone + :param reg_data_root: a directory to save inference images from the backbone + :param reg_prompt: prompt used to generate regularization images + :param size: resizing images for training data pipeline + :param center_crop: whether performing center cropping on input images + :param load_cache_latents: when set to True, images will be converted to cached latents which will be directly loaded for training + :param vae: vae instance to encode imamges from pixel space to latent space + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + with_prior_preservation=False, + reg_data_root=None, + reg_prompt=None, + size=512, + center_crop=True, + repeat=10000, + load_cache_latents=False, + cached_instance_data_root=None, + cached_reg_data_root=None, + vae=None, + text_encoder=None, + ): + self.size = size + self.center_crop = center_crop + + assert instance_data_root or cached_instance_data_root, "must provide instance images to start training." + self.instance_data_root = Path(instance_data_root) + self.cached_instance_data_root = cached_instance_data_root + self.cached_reg_data_root = cached_reg_data_root + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images * repeat + self.load_cache_latents = load_cache_latents + self.with_prior_preservation = with_prior_preservation + + if reg_data_root is not None: + self.reg_data_root = Path(reg_data_root) + self.reg_images_path = list(self.reg_data_root.iterdir()) + self.num_reg_images = len(self.reg_images_path) + self.reg_prompt = reg_prompt + else: + self.reg_data_root = None + + assert TORCHVISION_AVAILABLE, "Torchvision imports failed but they are required." + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + if self.load_cache_latents: + if (self.cached_instance_data_root is None) or ( + self.with_prior_preservation and self.cached_reg_data_root is None + ): + self.cache_latents(vae, text_encoder) + + self.cached_instance_data_root = f'{self.instance_data_root}_cached' + self.cached_reg_data_root = f'{self.reg_data_root}_cached' + self.instance_images_path = list(Path(self.cached_instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + + if self.with_prior_preservation: + self.reg_images_path = list(Path(self.cached_reg_data_root).iterdir()) + self.num_reg_images = len(self.reg_images_path) + + if self.cached_instance_data_root: + self.instance_images_path = list(Path(self.cached_instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + if self.with_prior_preservation and self.cached_reg_data_root: + self.reg_images_path = list(Path(self.cached_reg_data_root).iterdir()) + self.num_reg_images = len(self.reg_images_path) + + def __len__(self): + return self._length + + def get_image(self, path): + image = Image.open(path) + if not image.mode == "RGB": + image = image.convert("RGB") + image = self.image_transforms(image) + return image + + def __getitem__(self, index): + example = {} + if self.load_cache_latents: + example["instance_images"] = torch.load(self.instance_images_path[index % self.num_instance_images]) + else: + example["instance_images"] = self.get_image(self.instance_images_path[index % self.num_instance_images]) + example["instance_prompt"] = self.instance_prompt + + if self.reg_data_root: + if self.load_cache_latents: + example["reg_images"] = torch.load(self.reg_images_path[index % self.num_reg_images]) + else: + example["reg_images"] = self.get_image(self.reg_images_path[index % self.num_reg_images]) + example["reg_prompt"] = self.reg_prompt + + return example + + @rank_zero_only + def cache_latents(self, vae, text_encoder): + os.makedirs(f'{self.instance_data_root}_cached', exist_ok=True) + self.cached_instance_data_root = f'{self.instance_data_root}_cached' + self.cached_reg_data_root = f'{self.reg_data_root}_cached' + if self.instance_data_root and (len(os.listdir(self.cached_instance_data_root)) < self.num_instance_images): + for i in tqdm(range(self.num_instance_images)): + x = torch.Tensor(self.get_image(self.instance_images_path[i % self.num_instance_images])) + x = torch.unsqueeze(x, dim=0) + params = vae.encode(x).parameters.squeeze(dim=0) + torch.save(params, f'{self.instance_data_root}_cached/instance_image_cache_{i}.pt') + + if self.with_prior_preservation: + os.makedirs(f'{self.reg_data_root}_cached', exist_ok=True) + if self.reg_data_root and (len(os.listdir(self.cached_reg_data_root)) < self.num_reg_images): + for i in tqdm(range(self.num_reg_images)): + x = torch.Tensor(self.get_image(self.reg_images_path[i % self.num_reg_images])) + x = torch.unsqueeze(x, dim=0) + params = vae.encode(x).parameters.squeeze(dim=0) + torch.save(params, f'{self.reg_data_root}_cached/reg_image_cache_{i}.pt') diff --git a/nemo/collections/multimodal/data/imagen/__init__.py b/nemo/collections/multimodal/data/imagen/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/data/imagen/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/data/imagen/augmentations/__init__.py b/nemo/collections/multimodal/data/imagen/augmentations/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/data/imagen/augmentations/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/data/imagen/augmentations/augmentations.py b/nemo/collections/multimodal/data/imagen/augmentations/augmentations.py new file mode 100644 index 000000000000..23f481bc8720 --- /dev/null +++ b/nemo/collections/multimodal/data/imagen/augmentations/augmentations.py @@ -0,0 +1,76 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional + +import torch + +from nemo.utils import logging + + +def build_resolution_filter(value=None, method='larger', image_idx=0): + """ + Filter image based on its resolution. + value: filter threshold + method: Either larger or smaller + image_idx: idx of the image in the tuple input + """ + assert method == 'larger' or method == 'smaller' + if method == 'larger': + logging.info(f'Only Selecting images with resolution >= {value}') + return lambda x: x[image_idx].size[0] >= value and x[image_idx].size[1] >= value + + logging.info(f'Only Selecting images with resolution <= {value}') + return lambda x: x[image_idx].size[0] <= value and x[image_idx].size[1] <= value + + +class PickleTransform: + """ + Convert encodings stored in the pickle file to encoding and mask. + Transform the pad and resize the embedding to match the generator config. + """ + + def __init__(self, encoding_lengths: List[int], encoding_keys: List[str], out_keys: Optional[List[str]] = None): + assert len(encoding_keys) == len(encoding_lengths) + self.encoding_lengths = encoding_lengths + self.encoding_keys = encoding_keys + self.out_keys = out_keys if out_keys is not None else encoding_keys + + def _pad_and_resize(self, arr, ntokens): + # Function for padding and resizing a numpy array + + arr = torch.tensor(arr) + embed_dim = arr.shape[1] + + arr_padded = torch.zeros(ntokens, embed_dim, device=arr.device, dtype=torch.float32) + + # If the input text is larger than num_text_tokens, clip it. + if arr.shape[0] > ntokens: + arr = arr[0:ntokens] + + mask = torch.LongTensor(ntokens).zero_() + if len(arr.shape) > 1: + mask[0 : arr.shape[0]] = 1 + + if len(arr.shape) > 1: + arr_padded[0 : arr.shape[0]] = arr + + return arr_padded, mask + + def __call__(self, data): + out_dict = dict() + for token_length, encoding_key, out_key in zip(self.encoding_lengths, self.encoding_keys, self.out_keys): + embed, mask = self._pad_and_resize(data[encoding_key]['encodings'], token_length) + out_dict[f'{out_key}_embeddings'] = embed + out_dict[f'{out_key}_mask'] = mask + return out_dict diff --git a/nemo/collections/multimodal/data/imagen/augmentations/corruption.py b/nemo/collections/multimodal/data/imagen/augmentations/corruption.py new file mode 100644 index 000000000000..2d6a25bae314 --- /dev/null +++ b/nemo/collections/multimodal/data/imagen/augmentations/corruption.py @@ -0,0 +1,39 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + import torchvision.transforms.functional as torchvision_F + + TORCHVISION_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + TORCHVISION_AVAILABLE = False + + +class ImagePyramidNoCorruptions: + r""" + Only downsample image without any additional corruption. + """ + + def __init__(self, target_resolutions): + self.resolutions = target_resolutions + + def obtain_image_pyramid(self, image): + assert TORCHVISION_AVAILABLE, "Torchvision imports failed but they are required." + # Downsampling + data_dict = dict() + for res in self.resolutions: + image_downsampled = torchvision_F.resize( + image, res, interpolation=torchvision_F.InterpolationMode.BICUBIC, antialias=True + ) + data_dict[f'images_{res}'] = image_downsampled + return data_dict diff --git a/nemo/collections/multimodal/data/imagen/imagen_dataset.py b/nemo/collections/multimodal/data/imagen/imagen_dataset.py new file mode 100644 index 000000000000..c3db3b3a4612 --- /dev/null +++ b/nemo/collections/multimodal/data/imagen/imagen_dataset.py @@ -0,0 +1,156 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from nemo.collections.multimodal.data.common.webdataset import WebDatasetCommon +from nemo.collections.multimodal.data.imagen.augmentations.augmentations import ( + PickleTransform, + build_resolution_filter, +) +from nemo.collections.multimodal.data.imagen.augmentations.corruption import ImagePyramidNoCorruptions +from nemo.collections.multimodal.data.stable_diffusion.augmentation.augmentations import ( + construct_image_augmentations, + identical_transform, +) +from nemo.core.classes import Dataset as NeMoDataset +from nemo.utils import logging + + +class ImagenSyntheticDataset(NeMoDataset): + def __init__( + self, res, conditioning_cfg, fake_len=100000, no_embedding=False, + ): + super().__init__() + self.fake_len = fake_len + self.res = res + self.no_embedding = no_embedding + if not no_embedding: + self.out_key = conditioning_cfg.out_key if conditioning_cfg.out_key else conditioning_cfg.precached_key + self.token_length = conditioning_cfg.token_length + self.embed_dim = conditioning_cfg.embed_dim + + def __getitem__(self, index): + item = {} + if isinstance(self.res, list): + for resolution in self.res: + image_key = f'images_{resolution}' + item[image_key] = torch.randn(3, resolution, resolution) + else: + item['images'] = torch.randn(3, self.res, self.res) + + item['raw_text'] = f'fake text {index}' + if not self.no_embedding: + item[f'{self.out_key}_embeddings'] = torch.randn(self.token_length, self.embed_dim) + item[f'{self.out_key}_mask'] = torch.ones(self.token_length, dtype=torch.long) + return item + + def __len__(self): + return self.fake_len + + +def _build_functions_with_pickles(data_cfg, condition_cfg): + def tuple_to_dict(inp): + for input in inp: + out_dict = dict() + out_dict['images'] = input[0] + + # Output from pickle transform is already a dictionary + out_dict.update(input[1]) + + out_dict['raw_text'] = input[2] + yield out_dict + + def transform_fn(sample): + image, encodings, text = sample['jpg'], sample['pickle'], sample['txt'] + img_transform = construct_image_augmentations(data_cfg.train.get('augmentations'), normalize=True) + pickle_transform = PickleTransform( + encoding_keys=[condition_cfg.precached_key], + encoding_lengths=[condition_cfg.token_length], + out_keys=[condition_cfg.out_key], + ) + text_transform = identical_transform + return img_transform(image), pickle_transform(encodings), text_transform(text) + + return tuple_to_dict, transform_fn + + +def _build_functions_no_pickles(data_cfg): + def tuple_to_dict(inp): + for input in inp: + out_dict = dict() + out_dict['images'] = input[0] + out_dict['raw_text'] = input[1] + yield out_dict + + def transform_fn(sample): + image, text = sample['jpg'], sample['txt'] + img_transform = construct_image_augmentations(data_cfg.train.get('augmentations'), normalize=True) + text_transform = identical_transform + return img_transform(image), text_transform(text) + + return tuple_to_dict, transform_fn + + +def build_train_valid_datasets( + model_cfg, consumed_samples, +): + data_cfg = model_cfg.data + condition_cfg = model_cfg.conditioning + + if data_cfg.get('synthetic_data', False): + logging.info(f'Creating Synthetic Datasaet.') + train_data = ImagenSyntheticDataset( + res=data_cfg.train.get('target_resolutions', 64), + conditioning_cfg=condition_cfg, + fake_len=data_cfg.get('synthetic_data_length', 10000), + no_embedding=condition_cfg.get("online_encoding", False), + ) + return train_data, None + # This function maps data that are tuples to dictionary. + if condition_cfg.get("online_encoding", False): + tuple_to_dict, transform_fn = _build_functions_no_pickles(data_cfg) + else: + tuple_to_dict, transform_fn = _build_functions_with_pickles(data_cfg, condition_cfg) + + filter_cfg = data_cfg.train.get('filterings', None) + + # For adding corruptions and obtaining image pyramid + if model_cfg.unet_type.startswith('sr'): + assert data_cfg.train.get('target_resolutions'), 'SR model requires multiple resolution for training' + logging.info(f'Resizing input images into the follow resolutions: {data_cfg.train.target_resolutions}') + corruption_gen = ImagePyramidNoCorruptions(target_resolutions=data_cfg.train.target_resolutions) + else: + corruption_gen = None + + # This function is used for obtaining image pyramid + # in SR models for Imagen, we need to use low-res image as conditioning. + def obtain_image_pyramid(inp): + for data_dict in inp: + data_pyramid = corruption_gen.obtain_image_pyramid(data_dict['images']) + data_dict.update(data_pyramid) + yield data_dict + + compose_fn = [tuple_to_dict] + if corruption_gen: + compose_fn.append(obtain_image_pyramid) + + train_data = WebDatasetCommon( + dataset_cfg=data_cfg, + consumed_samples=consumed_samples, + map_fn=transform_fn, + compose_fn=compose_fn, + filter_fn=build_resolution_filter(**filter_cfg.resolution, image_idx='jpg') if filter_cfg else None, + is_train=True, + ) + return train_data, None diff --git a/nemo/collections/multimodal/data/instruct_pix2pix/__init__.py b/nemo/collections/multimodal/data/instruct_pix2pix/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/data/instruct_pix2pix/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/data/instruct_pix2pix/edit_dataset.py b/nemo/collections/multimodal/data/instruct_pix2pix/edit_dataset.py new file mode 100644 index 000000000000..e1ff1966d3c3 --- /dev/null +++ b/nemo/collections/multimodal/data/instruct_pix2pix/edit_dataset.py @@ -0,0 +1,137 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import json +import math +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from einops import rearrange +from PIL import Image +from torch.utils.data import Dataset + +try: + import torchvision + + TORCHVISION_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + TORCHVISION_AVAILABLE = False + + +class EditDataset(Dataset): + def __init__( + self, + path: str, + split: str = "train", + splits: tuple[float, float, float] = (0.95, 0.04, 0.01), + min_resize_res: int = 256, + max_resize_res: int = 256, + crop_res: int = 256, + flip_prob: float = 0.0, + ): + assert split in ("train", "val", "test") + assert sum(splits) == 1 + self.path = path + self.min_resize_res = min_resize_res + self.max_resize_res = max_resize_res + self.crop_res = crop_res + self.flip_prob = flip_prob + + with open(Path(self.path, "seeds.json")) as f: + self.seeds = json.load(f) + + split_0, split_1 = { + "train": (0.0, splits[0]), + "val": (splits[0], splits[0] + splits[1]), + "test": (splits[0] + splits[1], 1.0), + }[split] + + idx_0 = math.floor(split_0 * len(self.seeds)) + idx_1 = math.floor(split_1 * len(self.seeds)) + self.seeds = self.seeds[idx_0:idx_1] + + def __len__(self) -> int: + return len(self.seeds) + + def __getitem__(self, i: int) -> dict[str, Any]: + name, seeds = self.seeds[i] + propt_dir = Path(self.path, name) + seed = seeds[torch.randint(0, len(seeds), ()).item()] + with open(propt_dir.joinpath("prompt.json")) as fp: + prompt = json.load(fp)["edit"] + + image_0 = Image.open(propt_dir.joinpath(f"{seed}_0.jpg")) + image_1 = Image.open(propt_dir.joinpath(f"{seed}_1.jpg")) + + resize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item() + image_0 = image_0.resize((resize_res, resize_res), Image.Resampling.LANCZOS) + image_1 = image_1.resize((resize_res, resize_res), Image.Resampling.LANCZOS) + + image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w") + image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w") + + assert TORCHVISION_AVAILABLE, "Torchvision imports failed but they are required." + crop = torchvision.transforms.RandomCrop(self.crop_res) + flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob)) + image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2) + + return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt)) + + +class EditDatasetEval(Dataset): + def __init__( + self, path: str, split: str = "train", splits: tuple[float, float, float] = (0.9, 0.05, 0.05), res: int = 256, + ): + assert split in ("train", "val", "test") + assert sum(splits) == 1 + self.path = path + self.res = res + + with open(Path(self.path, "seeds.json")) as f: + self.seeds = json.load(f) + + split_0, split_1 = { + "train": (0.0, splits[0]), + "val": (splits[0], splits[0] + splits[1]), + "test": (splits[0] + splits[1], 1.0), + }[split] + + idx_0 = math.floor(split_0 * len(self.seeds)) + idx_1 = math.floor(split_1 * len(self.seeds)) + self.seeds = self.seeds[idx_0:idx_1] + + def __len__(self) -> int: + return len(self.seeds) + + def __getitem__(self, i: int) -> dict[str, Any]: + name, seeds = self.seeds[i] + propt_dir = Path(self.path, name) + seed = seeds[torch.randint(0, len(seeds), ()).item()] + with open(propt_dir.joinpath("prompt.json")) as fp: + prompt = json.load(fp) + edit = prompt["edit"] + input_prompt = prompt["input"] + output_prompt = prompt["output"] + + image_0 = Image.open(propt_dir.joinpath(f"{seed}_0.jpg")) + + reize_res = torch.randint(self.res, self.res + 1, ()).item() + image_0 = image_0.resize((reize_res, reize_res), Image.Resampling.LANCZOS) + + image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w") + + return dict(image_0=image_0, input_prompt=input_prompt, edit=edit, output_prompt=output_prompt) diff --git a/nemo/collections/multimodal/data/nerf/__init__.py b/nemo/collections/multimodal/data/nerf/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/data/nerf/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/data/nerf/cameras.py b/nemo/collections/multimodal/data/nerf/cameras.py new file mode 100644 index 000000000000..72dbf698380c --- /dev/null +++ b/nemo/collections/multimodal/data/nerf/cameras.py @@ -0,0 +1,192 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import List + +import numpy as np +import torch + + +class Camera(ABC): + """ + Abstract base class for Camera models. + """ + + def __init__(self, width: int, height: int, device: torch.device = 'cuda') -> None: + """ + Initializes the Camera instance with given dimensions and device. + + Parameters: + width: int - Width of the camera frame. + height: int - Height of the camera frame. + device: torch.device - The device where tensor computations will be performed. + """ + self.width = width + self.height = height + self.device = device + + @abstractmethod + def compute_intrinsics(self) -> None: + """ + Abstract method to compute camera intrinsics. + """ + pass + + @abstractmethod + def compute_projection_matrix(self) -> None: + """ + Abstract method to compute the projection matrix. + """ + pass + + +class OrthographicCamera(Camera): + """ + Class for Orthographic Camera models. + """ + + def compute_projection_matrix(self) -> torch.Tensor: + """ + Computes the projection matrix for an Orthographic camera. + + Returns: + torch.Tensor: The projection matrix. + """ + projection = torch.tensor( + [[2 / self.width, 0, 0, 0], [0, -2 / self.height, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], + dtype=torch.float32, + device=self.device, + ).unsqueeze(0) + return projection + + +class PinholeCamera(Camera): + """ + Class for Pinhole Camera models. + """ + + def __init__(self, width: int, height: int, near: float, far: float, device: torch.device = 'cuda') -> None: + """ + Initializes the Pinhole Camera instance with given parameters. + + Parameters: + width: int - Width of the camera frame. + height: int - Height of the camera frame. + near: float - Near clipping plane. + far: float - Far clipping plane. + device: torch.device - The device where tensor computations will be performed. + """ + super().__init__(width, height, device) + self.near = near + self.far = far + + def compute_intrinsics(self, fovx: float, fovy: float) -> np.ndarray: + """ + Computes the intrinsic matrix for the camera based on field of views. + + Parameters: + fovx: float - Field of view in X direction. + fovy: float - Field of view in Y direction. + + Returns: + np.ndarray: The intrinsic matrix. + """ + focal_x = self.width / (2 * np.tan(np.deg2rad(fovx) / 2)) + focal_y = self.height / (2 * np.tan(np.deg2rad(fovy) / 2)) + cx, cy = self.width / 2, self.height / 2 + return np.array([focal_x, focal_y, cx, cy]) + + def compute_projection_matrix(self, focal_x: float, focal_y: float) -> torch.Tensor: + """ + Computes the projection matrix for the camera. + + Parameters: + focal_x: float - Focal length in X direction. + focal_y: float - Focal length in Y direction. + + Returns: + torch.Tensor: The projection matrix. + """ + projection = torch.tensor( + [ + [2 * focal_x / self.width, 0, 0, 0], + [0, -2 * focal_y / self.height, 0, 0], + [ + 0, + 0, + -(self.far + self.near) / (self.far - self.near), + -(2 * self.far * self.near) / (self.far - self.near), + ], + [0, 0, -1, 0], + ], + dtype=torch.float32, + device=self.device, + ).unsqueeze(0) + return projection + + +class CubeCamera(Camera): + """ + Class for Cube Camera models, which is essentially six pinhole cameras. + """ + + def __init__( + self, width: int, height: int, near: float = 0.01, far: float = 1000, device: torch.device = 'cuda' + ) -> None: + """ + Initializes the Cube Camera instance with given parameters. + + Parameters: + width: int - Width of each camera face. + height: int - Height of each camera face. + near: float - Near clipping plane. + far: float - Far clipping plane. + device: torch.device - The device where tensor computations will be performed. + """ + self.width = width + self.height = height + self.near = near + self.far = far + self.device = device + + def compute_intrinsics(self) -> List[np.ndarray]: + """ + Computes the intrinsic matrices for the six faces of the cube using a Pinhole camera model. + + Returns: + List[np.ndarray]: List of 6 intrinsic matrices, one for each face. + """ + # Similar to Pinhole but repeated six times for six faces of the cube + return [ + PinholeCamera( + width=self.width, height=self.height, near=self.near, far=self.far, device=self.device + ).compute_intrinsics(90, 90) + for _ in range(6) + ] + + def compute_projection_matrix(self) -> List[torch.Tensor]: + """ + Computes the projection matrices for the six faces of the cube using a Pinhole camera model. + + Returns: + List[torch.Tensor]: List of 6 projection matrices, one for each face. + """ + # Similar to Pinhole but repeated six times for six faces of the cube + return [ + PinholeCamera( + width=self.width, height=self.height, near=self.near, far=self.far, device=self.device + ).compute_projection_matrix(1, 1) + for _ in range(6) + ] diff --git a/nemo/collections/multimodal/data/nerf/circle_poses.py b/nemo/collections/multimodal/data/nerf/circle_poses.py new file mode 100644 index 000000000000..93f1c968a018 --- /dev/null +++ b/nemo/collections/multimodal/data/nerf/circle_poses.py @@ -0,0 +1,228 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Union + +import numpy as np +import torch +from torch.utils.data import Dataset + +from nemo.collections.multimodal.data.nerf.cameras import PinholeCamera +from nemo.collections.multimodal.data.nerf.utils import ( + compute_look_at_vectors, + construct_poses, + get_rays, + get_view_direction, +) + + +def circle_poses( + radius: torch.Tensor = torch.tensor([3.2]), + theta: torch.Tensor = torch.tensor([60]), + phi: torch.Tensor = torch.tensor([0]), + angle_overhead: float = 30, + angle_front: float = 60, + return_dirs: bool = False, + device: torch.device = "cuda", +) -> torch.Tensor: + """ + Generate camera poses based on a circular arrangement. + + Parameters: + radius: torch.Tensor - Radii for the camera positions. + theta: torch.Tensor - Theta angles for the camera positions. + phi: torch.Tensor - Phi angles for the camera positions. + angle_overhead: float - Angle range of the overhead view. + angle_front: float - Angle range of the front view. + return_dirs: bool - Whether to return the view directions. + device: str - The device to allocate the tensor on (e.g., 'cuda' or 'cpu'). + + Returns: + Tuple: Contains the following: + - poses (torch.Tensor): Generated poses, shape [size, 4, 4]. + - dirs (torch.Tensor, optional): View directions, if requested. + """ + # Convert degrees to radians for theta and phi + theta = theta / 180 * np.pi + phi = phi / 180 * np.pi + angle_overhead = angle_overhead / 180 * np.pi + angle_front = angle_front / 180 * np.pi + + # Calculate camera centers in Cartesian coordinates + centers = torch.stack( + [ + radius * torch.sin(theta) * torch.sin(phi), + radius * torch.cos(theta), + radius * torch.sin(theta) * torch.cos(phi), + ], + dim=-1, + ) # [B, 3] + + # Compute camera look-at matrix + forward_vector, up_vector, right_vector = compute_look_at_vectors(centers=centers, device=device) + + # Construct the 4x4 pose matrices + poses = construct_poses( + centers=centers, right_vector=right_vector, up_vector=up_vector, forward_vector=forward_vector, device=device + ) + + dirs = get_view_direction(theta, phi, angle_overhead, angle_front) if return_dirs else None + + return poses, dirs + + +class CirclePosesDataset(Dataset): + """ + A dataset class to generate circle poses. + """ + + def __init__( + self, + size: int = 100, + height: int = 256, + width: int = 256, + default_fovx: float = 20.0, + default_fovy: float = 20.0, + default_radius: float = 3.2, + default_polar: float = 90.0, + default_azimuth: float = 0.0, + angle_overhead: float = 30.0, + angle_front: float = 60.0, + near: float = 0.01, + far: float = 1000.0, + device: torch.device = 'cpu', + ) -> None: + """ + Initializes a new CirclePosesDataset instance. + + Parameters: + size (int): Number of samples in the dataset. + height (int): Height of the image. + width (int): Width of the image. + default_fovx (float): Default field of view in x-direction. + default_fovy (float): Default field of view in y-direction. + default_radius (float): Default radius of the circle. + default_polar (float): Default polar angle. + default_azimuth (float): Default azimuth angle. + angle_overhead (float): Overhead angle. + angle_front (float): Frontal angle. + near (float): Near clipping distance. + far (float): Far clipping distance. + device (torch.device): Device to generate data on. + """ + super().__init__() + self.size = size + self.height = height + self.width = width + + self.default_fovx = default_fovx + self.default_fovy = default_fovy + self.default_radius = default_radius + self.default_polar = default_polar + self.default_azimuth = default_azimuth + + self.angle_overhead = angle_overhead + self.angle_front = angle_front + self.near = near + self.far = far + + self.device = device + + # TODO(ahmadki): make camera type a parameter + self.camera = PinholeCamera( + width=self.width, height=self.height, near=self.near, far=self.far, device=self.device + ) + + def __len__(self) -> int: + """Returns the number of samples in the dataset.""" + return self.size + + def __getitem__(self, idx: int) -> Dict[str, Union[int, torch.Tensor]]: + """Get an item from the dataset. + + Args: + idx (int): Index of the item to retrieve. + + Returns: + dict: Data dictionary containing the following: + - height (int): Height of the image. + - width (int): Width of the image. + - rays_o (torch.Tensor): Ray origins, shape [height, width, 3]. + - rays_d (torch.Tensor): Ray directions, shape [height, width, 3]. + - dir (torch.Tensor): View direction, shape [3]. + - mvp (torch.Tensor): Model-view-projection matrix, shape [4, 4]. + - azimuth (torch.Tensor): Azimuth angle, shape [1]. + """ + # Initialize circle pose parameters + thetas = torch.FloatTensor([self.default_polar]).to(self.device) + phis = torch.FloatTensor([(idx / self.size) * 360]).to(self.device) + radius = torch.FloatTensor([self.default_radius]).to(self.device) + + # Generate circle poses and directions + poses, dirs = circle_poses( + radius=radius, + theta=thetas, + phi=phis, + angle_overhead=self.angle_overhead, + angle_front=self.angle_front, + return_dirs=True, + device=self.device, + ) + + # Compute camera intrinsics + intrinsics = self.camera.compute_intrinsics(fovx=self.default_fovx, fovy=self.default_fovy) + + # Compute projection matrix + projection = self.camera.compute_projection_matrix(focal_x=intrinsics[0], focal_y=intrinsics[1]) + mvp = projection @ torch.inverse(poses) # [1, 4, 4] + + # Sample rays + rays_o, rays_d = get_rays( + poses=poses, intrinsics=intrinsics, height=self.height, width=self.width, device=poses.device + ) + + # Compute azimuth delta + delta_azimuth = phis - self.default_azimuth + delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180] + + data = { + 'height': self.height, + 'width': self.width, + 'rays_o': rays_o, + 'rays_d': rays_d, + 'dir': dirs, + 'mvp': mvp, + 'azimuth': delta_azimuth, + } + + return data + + def collate_fn(self, batch: list) -> Dict[str, Union[int, torch.Tensor]]: + """Collate function to combine multiple data points into batches. + + Args: + batch (list): List of data dictionaries. + + Returns: + dict: Collated data. + """ + return { + 'height': self.height, + 'width': self.width, + 'rays_o': torch.cat([item['rays_o'] for item in batch], dim=0), + 'rays_d': torch.cat([item['rays_d'] for item in batch], dim=0), + 'mvp': torch.cat([item['mvp'] for item in batch], dim=0), + 'dir': torch.cat([item['dir'] for item in batch], dim=0), + 'azimuth': torch.cat([item['azimuth'] for item in batch], dim=0), + } diff --git a/nemo/collections/multimodal/data/nerf/random_poses.py b/nemo/collections/multimodal/data/nerf/random_poses.py new file mode 100644 index 000000000000..7ecc56228698 --- /dev/null +++ b/nemo/collections/multimodal/data/nerf/random_poses.py @@ -0,0 +1,450 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.data import IterableDataset + +from nemo.collections.multimodal.data.nerf.cameras import PinholeCamera +from nemo.collections.multimodal.data.nerf.utils import ( + compute_look_at_vectors, + construct_poses, + get_rays, + get_view_direction, +) + + +def linear_normalization(x: float, lower_bound: float, upper_bound: float) -> float: + """ + Linearly normalize a value between lower_bound and upper_bound to a value between 0 and 1. + + Parameters: + x: The value to normalize. + lower_bound: The lower bound of the range of x. + upper_bound: The upper bound of the range of x. + + Returns: + The normalized value between 0 and 1. + """ + return min(1, max(0, (x - lower_bound) / (upper_bound - lower_bound))) + + +def rand_poses( + size: int, + radius_range: List[float] = [1, 1.5], + theta_range: List[float] = [0, 120], + phi_range: List[float] = [0, 360], + angle_overhead: float = 30, + angle_front: float = 60, + uniform_sphere_rate: float = 0.5, + jitter: bool = False, + jitter_center: float = 0.2, + jitter_target: float = 0.2, + jitter_up: float = 0.02, + return_dirs: bool = False, + device: torch.device = "cuda", +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Generate random poses from an orbit camera. + + Args: + size (int): Number of poses to generate. + radius_range (List[float]): Min and max radii for camera [min, max]. + theta_range (List[float]): Elevation angle range in degrees [min, max]. + phi_range (List[float]): Azimuth angle range in degrees [min, max]. + angle_overhead (float): Overhead angle in degrees. + angle_front (float): Front angle in degrees. + uniform_sphere_rate (float): The probability of sampling from a uniform sphere. + jitter (bool): Whether to add noise to the poses. + jitter_center (float): Noise range for the camera center. + jitter_target (float): Noise range for the camera target. + jitter_up (float): Noise range for the camera up vector. + return_dirs (bool): Whether to return the view directions. + device (torch.device): The device on which to allocate tensors. + + Returns: + Tuple: Contains the following: + - poses (torch.Tensor): Generated poses, shape [size, 4, 4]. + - thetas (torch.Tensor): Elevation angles in degrees, shape [size]. + - phis (torch.Tensor): Azimuth angles in degrees, shape [size]. + - radius (torch.Tensor): Radii of the camera orbits, shape [size]. + - dirs (torch.Tensor, optional): View directions, if requested. + """ + + # Convert angles from degrees to radians + theta_range = np.radians(theta_range) + phi_range = np.radians(phi_range) + angle_overhead = np.radians(angle_overhead) + angle_front = np.radians(angle_front) + + # Generate radius for each pose + radius = torch.rand(size, device=device) * (radius_range[1] - radius_range[0]) + radius_range[0] + + # Generate camera center positions + if random.random() < uniform_sphere_rate: + centers, thetas, phis = sample_uniform_sphere(size=size, radius=radius, device=device) + else: + centers, thetas, phis = sample_orbit( + size=size, radius=radius, theta_range=theta_range, phi_range=phi_range, device=device + ) + + # Initialize targets to 0 (assuming 0 is a point in 3D space that cameras are looking at) + targets = torch.zeros_like(centers) + + # Apply jitter + if jitter: + centers += torch.rand_like(centers) * jitter_center - jitter_center / 2.0 + targets = torch.randn_like(centers) * jitter_target + + # Compute camera look-at matrix + forward_vector, up_vector, right_vector = compute_look_at_vectors( + centers=centers - targets, jitter_up=jitter_up if jitter else 0, device=device + ) + + # Construct the 4x4 pose matrices + poses = construct_poses( + centers=centers, right_vector=right_vector, up_vector=up_vector, forward_vector=forward_vector, device=device + ) + + # Optionally compute view directions + dirs = get_view_direction(thetas, phis, angle_overhead, angle_front) if return_dirs else None + + # Convert back to degrees for thetas and phis + thetas, phis = torch.rad2deg(thetas), torch.rad2deg(phis) + + return poses, thetas, phis, radius, dirs + + +def sample_uniform_sphere( + size: int, radius: torch.Tensor, device: torch.device +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sample points uniformly on a sphere. + + Args: + size (int): Number of points to sample. + device (torch.device): Device to allocate tensors on. + radius (torch.Tensor): Radii for the points. + + Returns: + Tuple: Contains the following: + - centers (torch.Tensor): The Cartesian coordinates of the sampled points. + - thetas (torch.Tensor): Elevation angles in radians. + - phis (torch.Tensor): Azimuth angles in radians. + """ + # Generate unit vectors + unit_centers = F.normalize( + torch.stack( + [ + torch.randn(size, device=device), + torch.abs(torch.randn(size, device=device)), + torch.randn(size, device=device), + ], + dim=-1, + ), + p=2, + dim=1, + ) + # Generate radii and scale unit vectors + centers = unit_centers * radius.unsqueeze(-1) + # Calculate spherical coordinates + thetas = torch.acos(unit_centers[:, 1]) + phis = torch.atan2(unit_centers[:, 0], unit_centers[:, 2]) + phis[phis < 0] += 2 * np.pi + + return centers, thetas, phis + + +def sample_orbit( + size: int, radius: torch.Tensor, theta_range: np.ndarray, phi_range: np.ndarray, device: torch.device = "cuda" +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sample points on a spherical orbit. + + Args: + size (int): Number of points to sample. + radius (torch.Tensor): Radii for the points. + theta_range (np.ndarray): Elevation angle range in radians [min, max]. + phi_range (np.ndarray): Azimuth angle range in radians [min, max]. + device (torch.device): Device to allocate tensors on. + + Returns: + Tuple: Contains the following: + - centers (torch.Tensor): The Cartesian coordinates of the sampled points. + - thetas (torch.Tensor): Elevation angles in radians. + - phis (torch.Tensor): Azimuth angles in radians. + """ + thetas = torch.rand(size, device=device) * (theta_range[1] - theta_range[0]) + theta_range[0] + phis = torch.rand(size, device=device) * (phi_range[1] - phi_range[0]) + phi_range[0] + phis[phis < 0] += 2 * np.pi + + x = radius * torch.sin(thetas) * torch.sin(phis) + y = radius * torch.cos(thetas) + z = radius * torch.sin(thetas) * torch.cos(phis) + + centers = torch.stack([x, y, z], dim=-1) + + return centers, thetas, phis + + +class RandomPosesDataset(IterableDataset): + """ + A dataset class to generate random poses. + """ + + def __init__( + self, + internal_batch_size: int = 100, + height: int = 256, + width: int = 256, + radius_range: Tuple[float, float] = [3.0, 3.5], + theta_range: Tuple[float, float] = [45.0, 105.0], + phi_range: Tuple[float, float] = [-180.0, 180.0], + fovx_range: Tuple[float, float] = [10.0, 30.0], + default_fovx: float = 20.0, + fovy_range: Tuple[float, float] = [10.0, 30.0], + default_fovy: float = 20.0, + default_radius: float = 3.2, + default_polar: float = 90.0, + default_azimuth: float = 0.0, + jitter: bool = False, + jitter_center: float = 0.2, + jitter_target: float = 0.2, + jitter_up: float = 0.02, + angle_overhead: float = 30.0, + angle_front: float = 60.0, + uniform_sphere_rate: float = 0.0, + near: float = 0.01, + far: float = 1000.0, + device: torch.device = 'cpu', + ) -> None: + """ + Initializes a new RandomPosesDataset instance. + + Parameters: + internal_batch_size (int): Number of samples to pre-generate internally. + height (int): Height of the image. + width (int): Width of the image. + radius_range (Tuple[float, float]): Range of generated radii. + theta_range (Tuple[float, float]): Range of generated theta angles. + phi_range (Tuple[float, float]): Range of generated phi angles. + fovx_range (Tuple[float, float]): Range of generated field of view in x-direction. + default_fovx (float): Default field of view in x-direction. + fovy_range (Tuple[float, float]): Range of generated field of view angles in y-direction. + default_fovy (float): Default field of view in y-direction. + default_radius (float): Default radius of the circle. + default_polar (float): Default polar angle. + default_azimuth (float): Default azimuth angle. + jitter (bool): Whether to jitter the poses. + jitter_center (float): Jittering center range. + jitter_target (float): Jittering target range. + jitter_up (float): Jittering up range. + angle_overhead (float): Overhead angle. + angle_front (float): Frontal angle. + uniform_sphere_rate (float): Rate of sampling uniformly on a sphere. + near (float): Near clipping distance. + far (float): Far clipping distance. + device (torch.device): Device to generate data on. + """ + + super().__init__() + self.height = height + self.width = width + self.internal_batch_size = internal_batch_size + + # TODO(ahmadki): expose for models other than dreamfusion + self.progressive_view = False + self.progressive_view_start_step = 0 + self.progressive_view_end_step = 500 + + self.default_fovx = default_fovx + self.default_fovy = default_fovy + self.default_radius = default_radius + self.default_polar = default_polar + self.default_azimuth = default_azimuth + self.same_fov_random = True + + self.radius_range = radius_range + self.theta_range = theta_range + self.phi_range = phi_range + self.fovx_range = fovx_range + self.fovy_range = fovy_range + + self.current_radius_range = radius_range + self.current_theta_range = theta_range + self.current_phi_range = phi_range + self.current_fovx_range = fovx_range + self.current_fovy_range = fovy_range + + self.angle_overhead = angle_overhead + self.angle_front = angle_front + self.uniform_sphere_rate = uniform_sphere_rate + self.jitter = jitter + self.jitter_center = jitter_center + self.jitter_target = jitter_target + self.jitter_up = jitter_up + + self.near = near + self.far = far + + self.device = device + + # TODO(ahmadki): make camera type a parameter + self.camera = PinholeCamera( + width=self.width, height=self.height, near=self.near, far=self.far, device=self.device + ) + + def update_step(self, epoch: int, global_step: int) -> None: + """ + Update the dataset at the beginning of each epoch. + + Parameters: + epoch (int): Current epoch. + global_step (int): Current global step. + + """ + if self.progressive_view: + self.progressive_view_update_step(global_step=global_step) + + def progressive_view_update_step(self, global_step: int) -> None: + """ + progressively relaxing view range + + Parameters: + global_step (int): Current global step. + """ + # TODO(ahmadki): support non-linear progressive_views + r = linear_normalization( + x=global_step, lower_bound=self.progressive_view_start_step, upper_bound=self.progressive_view_end_step + ) + self.current_phi_range = [ + (1 - r) * self.default_azimuth + r * self.phi_range[0], + (1 - r) * self.default_azimuth + r * self.phi_range[1], + ] + self.current_theta_range = [ + (1 - r) * self.default_polar + r * self.theta_range[0], + (1 - r) * self.default_polar + r * self.theta_range[1], + ] + self.current_radius_range = [ + (1 - r) * self.default_radius + r * self.radius_range[0], + (1 - r) * self.default_radius + r * self.radius_range[1], + ] + self.current_fovy_range = [ + (1 - r) * self.default_fovy + r * self.fovy_range[0], + (1 - r) * self.default_fovy + r * self.fovy_range[1], + ] + + def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: + """ + Returns an iterator over the dataset. + + Returns: + Iterator: An iterator over the dataset. + + """ + while True: + # Generate samples + rays_o, rays_d, dirs, mvp, delta_azimuth = self.generate_samples() + for i in range(self.internal_batch_size): + # Yield one sample at a time from the internal batch + yield { + 'height': self.height, + 'width': self.width, + 'rays_o': rays_o[i].unsqueeze(0), + 'rays_d': rays_d[i].unsqueeze(0), + 'dir': dirs[i].unsqueeze(0), + 'mvp': mvp[i].unsqueeze(0), + 'azimuth': delta_azimuth[i].unsqueeze(0), + } + + def generate_samples(self): + """ + Generate a batch of random poses. + + Returns: + Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + A tuple containing: + - rays (Dict[str, torch.Tensor]): A dictionary containing the origin and direction of the rays. + - dirs (torch.Tensor): A tensor containing the directions of the rays. + - mvp (torch.Tensor): A tensor containing the model-view-projection matrix. + - azimuth (torch.Tensor): A A tensor containing the azimuth angle. + """ + # Generate random poses and directions + poses, dirs, thetas, phis, radius = rand_poses( + size=self.internal_batch_size, + radius_range=self.current_radius_range, + theta_range=self.current_theta_range, + phi_range=self.current_phi_range, + angle_overhead=self.angle_overhead, + angle_front=self.angle_front, + uniform_sphere_rate=self.uniform_sphere_rate, + jitter=self.jitter, + jitter_center=self.jitter_center, + jitter_target=self.jitter_target, + jitter_up=self.jitter_up, + return_dirs=True, + device=self.device, + ) + + # random focal + if self.same_fov_random: + fovx_random = random.random() + fovy_random = fovx_random + else: + fovx_random = random.random() + fovy_random = random.random() + fovx = fovx_random * (self.current_fovx_range[1] - self.current_fovx_range[0]) + self.current_fovx_range[0] + fovy = fovy_random * (self.current_fovy_range[1] - self.current_fovy_range[0]) + self.current_fovy_range[0] + + # Compute camera intrinsics + intrinsics = self.camera.compute_intrinsics(fovx=fovx, fovy=fovy) + + # Compute projection matrix + projection = self.camera.compute_projection_matrix(focal_x=intrinsics[0], focal_y=intrinsics[1]) + mvp = projection @ torch.inverse(poses) # [internal batch size, 4, 4] + + # Sample rays + rays_o, rays_d = get_rays( + poses=poses, intrinsics=intrinsics, height=self.height, width=self.width, device=poses.device + ) + + # Compute azimuth delta + delta_azimuth = phis - self.default_azimuth + delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180] + + return rays_o, rays_d, dirs, mvp, delta_azimuth + + def collate_fn(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Collate function to bundle multiple samples into a single batch. + + Args: + batch (List[Dict]): List of samples to collate. + + Returns: + Dict: A dictionary containing the collated batch. + """ + return { + 'height': self.height, + 'width': self.width, + 'rays_o': torch.cat([item['rays_o'] for item in batch], dim=0), + 'rays_d': torch.cat([item['rays_d'] for item in batch], dim=0), + 'mvp': torch.cat([item['mvp'] for item in batch], dim=0), + 'dir': torch.cat([item['dir'] for item in batch], dim=0), + 'azimuth': torch.cat([item['azimuth'] for item in batch], dim=0), + } diff --git a/nemo/collections/multimodal/data/nerf/utils.py b/nemo/collections/multimodal/data/nerf/utils.py new file mode 100644 index 000000000000..306aeb546f57 --- /dev/null +++ b/nemo/collections/multimodal/data/nerf/utils.py @@ -0,0 +1,217 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional + +import numpy as np +import torch +import torch.nn.functional as F + + +def get_view_direction(thetas: torch.Tensor, phis: torch.Tensor, overhead: float, front: float) -> torch.Tensor: + """ + Get the view direction based on given theta and phi values. + + Parameters: + - thetas (torch.Tensor): Array of theta values with shape [B,] + - phis (torch.Tensor): Array of phi values with shape [B,] + - overhead (float): Threshold for determining top and bottom views. + - front (float): Threshold for determining front, back and side views. + + Returns: + - torch.Tensor: Array of view directions. Values can be: + 0: front + 1: side (camera left) + 2: back + 3: side (camera right) + 4: top + 5: bottom + + Notes: + - Phi and theta values are assumed to be in radians. + """ + + num_samples = thetas.shape[0] + res = torch.zeros(num_samples, dtype=torch.long) + + # Normalize phis values to [0, 2*pi] + phis = phis % (2 * np.pi) + + # Determine direction based on phis + res[(phis < front / 2) | (phis >= 2 * np.pi - front / 2)] = 0 + res[(phis >= front / 2) & (phis < np.pi - front / 2)] = 1 + res[(phis >= np.pi - front / 2) & (phis < np.pi + front / 2)] = 2 + res[(phis >= np.pi + front / 2) & (phis < 2 * np.pi - front / 2)] = 3 + + # Override directions based on thetas for top and bottom views + res[thetas <= overhead] = 4 + res[thetas >= (np.pi - overhead)] = 5 + + return res + + +def compute_look_at_vectors(centers: torch.Tensor, jitter_up: Optional[float] = None, device: torch.device = "cuda"): + """ + Compute the look-at vectors for camera poses. + + Parameters: + centers: The centers of the cameras. + jitter_up: The noise range for the up vector of the camera. + device: Device to allocate the output tensor. + + Returns: + Tuple: Contains the following: + - forward_vector: The forward vectors of the cameras, shape [B, 3]. + - up_vector: The up vectors of the cameras, shape [B, 3]. + - right_vector: The right vectors of the cameras, shape [B, 3]. + """ + forward_vector = F.normalize(centers) + up_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(len(centers), 1) + right_vector = F.normalize(torch.cross(forward_vector, up_vector, dim=-1)) + up_noise = torch.randn_like(up_vector) * jitter_up if jitter_up is not None else 0 + up_vector = F.normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise) + + return forward_vector, up_vector, right_vector + + +def construct_poses( + centers: torch.Tensor, + right_vector: torch.Tensor, + up_vector: torch.Tensor, + forward_vector: torch.Tensor, + device: torch.device, +) -> torch.Tensor: + """ + Construct the 4x4 pose matrices. + + Args: + size (int): Number of pose matrices to construct. + centers (torch.Tensor): The Cartesian coordinates of the camera centers. + right_vector (torch.Tensor): The right vectors of the cameras. + up_vector (torch.Tensor): The up vectors of the cameras. + forward_vector (torch.Tensor): The forward vectors of the cameras. + device (torch.device): Device to allocate tensors on. + + Returns: + torch.Tensor: The pose matrices, shape [size, 4, 4]. + """ + poses = torch.eye(4, dtype=torch.float32, device=device).unsqueeze(0).repeat(len(centers), 1, 1) + poses[:, :3, :3] = torch.stack([right_vector, up_vector, forward_vector], dim=-1) + poses[:, :3, 3] = centers + + return poses + + +@torch.cuda.amp.autocast(enabled=False) +def get_rays( + poses: torch.Tensor, + intrinsics: torch.Tensor, + height: int, + width: int, + num_samples: Optional[int] = None, + error_map: Optional[torch.Tensor] = None, + device: torch.device = "cuda", +) -> Dict[str, torch.Tensor]: + """ + Generates rays from camera poses and intrinsics. + + Args: + poses (torch.Tensor): Camera poses, shape [B, 4, 4] (cam2world). + intrinsics (torch.Tensor): Intrinsic camera parameters [fx, fy, cx, cy]. + height (int): Height of the image. + width (int): Width of the image. + num_samples: Number of rays to sample, default is None for all rays. + error_map: Optional tensor to use for non-uniform sampling of rays. + device (torch.device): Device on which to generate the rays. + + Returns: + Dict[str, torch.Tensor]: A dictionary containing the following keys: + - 'rays_o': Origin of the rays, shape [B, N, 3] + - 'rays_d': Directions of the rays, shape [B, N, 3] + - 'inds': Indices of the rays, shape [B, N] (if N > 0) + - 'inds_coarse': Coarse indices of the rays, shape [B, N] (if error_map is not None) + """ + + batch_size = poses.shape[0] + fx, fy, cx, cy = intrinsics + + i, j = torch.meshgrid( + torch.linspace(0, width - 1, width, device=device), + torch.linspace(0, height - 1, height, device=device), + indexing='ij', + ) + i = i.t().reshape([1, height * width]).expand([batch_size, height * width]) + 0.5 + j = j.t().reshape([1, height * width]).expand([batch_size, height * width]) + 0.5 + + results = {} + + if num_samples is not None: + num_samples = min(num_samples, height * width) + + if error_map is None: + sampled_indices = torch.randint(0, height * width, size=[num_samples], device=device) + sampled_indices = sampled_indices.expand([batch_size, num_samples]) + else: + sampled_indices, sampled_indices_coarse = non_uniform_sampling( + error_map=error_map, num_samples=num_samples, height=height, width=width, device=device + ) + results['sampled_indices_coarse'] = sampled_indices_coarse + + i = torch.gather(i, -1, sampled_indices) + j = torch.gather(j, -1, sampled_indices) + results['sampled_indices'] = sampled_indices + else: + sampled_indices = torch.arange(height * width, device=device).expand([batch_size, height * width]) + + zs = torch.full_like(i, -1.0) + xs = -(i - cx) / fx * zs + ys = (j - cy) / fy * zs + directions = torch.stack((xs, ys, zs), dim=-1) + + rays_d = directions @ poses[:, :3, :3].transpose(-1, -2) + rays_o = poses[..., :3, 3].unsqueeze(-2).expand_as(rays_d) + + rays_o = rays_o.view(-1, height, width, 3) + rays_d = rays_d.view(-1, height, width, 3) + + return rays_o, rays_d + + +def non_uniform_sampling( + error_map: torch.Tensor, batch_size: int, num_samples: int, height: int, width: int, device: torch.device = "cuda" +) -> torch.Tensor: + """ + Perform non-uniform sampling based on the provided error_map. + + Parameters: + error_map: The error map for non-uniform sampling. + batch_size (int): Batch size of the generated samples. + num_samples (int): Number of samples to pick. + height (int): Height of the image. + width (int): Width of the image. + device: Device on which tensors are stored. + + Returns: + A tensor containing the sampled indices. + """ + + sampled_indices_coarse = torch.multinomial(error_map.to(device), num_samples, replacement=False) + inds_x, inds_y = sampled_indices_coarse // 128, sampled_indices_coarse % 128 + sx, sy = height / 128, width / 128 + + inds_x = (inds_x * sx + torch.rand(batch_size, num_samples, device=device) * sx).long().clamp(max=height - 1) + inds_y = (inds_y * sy + torch.rand(batch_size, num_samples, device=device) * sy).long().clamp(max=width - 1) + sampled_indices = inds_x * width + inds_y + + return sampled_indices, sampled_indices_coarse diff --git a/nemo/collections/multimodal/data/neva/conversation.py b/nemo/collections/multimodal/data/neva/conversation.py index 6fd87f712b7a..4e53eb5190f6 100644 --- a/nemo/collections/multimodal/data/neva/conversation.py +++ b/nemo/collections/multimodal/data/neva/conversation.py @@ -72,15 +72,6 @@ def get_prompt(self): ret += role + ": " + message + seps[i % 2] else: ret += role + ":" - elif self.sep_style == SeparatorStyle.MPT: - ret = self.system + self.sep - for role, message in messages: - if message: - if type(message) is tuple: - message, _, _ = message - ret += role + message + self.sep - else: - ret += role elif self.sep_style == SeparatorStyle.LLAMA_2: wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" wrap_inst = lambda msg: f"[INST] {msg} [/INST]" diff --git a/nemo/collections/multimodal/data/stable_diffusion/__init__.py b/nemo/collections/multimodal/data/stable_diffusion/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/data/stable_diffusion/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/data/stable_diffusion/augmentation/__init__.py b/nemo/collections/multimodal/data/stable_diffusion/augmentation/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/data/stable_diffusion/augmentation/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/data/stable_diffusion/augmentation/augmentations.py b/nemo/collections/multimodal/data/stable_diffusion/augmentation/augmentations.py new file mode 100644 index 000000000000..3fb8a1d3959f --- /dev/null +++ b/nemo/collections/multimodal/data/stable_diffusion/augmentation/augmentations.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + import torchvision.transforms as transforms + + TORCHVISION_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + TORCHVISION_AVAILABLE = False + + +def construct_clip_augmentations(n_px=224): + def _convert_image_to_rgb(image): + return image.convert("RGB") + + assert TORCHVISION_AVAILABLE, "Torchvision imports failed but they are required." + return transforms.Compose( + [ + transforms.Resize(n_px, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(n_px), + _convert_image_to_rgb, + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ] + ) + + +def construct_image_augmentations(augmentation_dict, normalize=True): + train_img_transform = [] + for aug in augmentation_dict: + if aug == 'resize_smallest_side': + img_size = int(augmentation_dict[aug]) + train_img_transform.append( + transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True) + ) + + elif aug == 'center_crop_h_w': + img_w, img_h = augmentation_dict[aug].split(',') + img_w = int(img_w) + img_h = int(img_h) + train_img_transform.append(transforms.CenterCrop((img_w, img_h))) + + elif aug == 'random_crop_h_w': + img_w, img_h = augmentation_dict[aug].split(',') + img_w = int(img_w) + img_h = int(img_h) + train_img_transform.append(transforms.RandomCrop((img_w, img_h))) + + elif aug == 'horizontal_flip': + enabled = augmentation_dict[aug] + if enabled: + train_img_transform.append(transforms.RandomHorizontalFlip(p=0.5)) + else: + raise ValueError('Augmentation not supported') + + # Always need to convert data to tensor + train_img_transform.append(transforms.ToTensor()) + if normalize: + train_img_transform.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) + train_img_transform = transforms.Compose(train_img_transform) + return train_img_transform + + +def identical_transform(x): + return x diff --git a/nemo/collections/multimodal/data/stable_diffusion/stable_diffusion_dataset.py b/nemo/collections/multimodal/data/stable_diffusion/stable_diffusion_dataset.py new file mode 100644 index 000000000000..445932124718 --- /dev/null +++ b/nemo/collections/multimodal/data/stable_diffusion/stable_diffusion_dataset.py @@ -0,0 +1,185 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from nemo.collections.multimodal.data.common.webdataset import WebDatasetCommon +from nemo.collections.multimodal.data.stable_diffusion.augmentation.augmentations import ( + construct_image_augmentations, + identical_transform, +) +from nemo.core.classes import Dataset as NeMoDataset +from nemo.utils import logging + + +class SDSyntheticDataset(NeMoDataset): + def __init__( + self, image_H, image_W, fake_len=100000, image_key='images', txt_key='txt', seq_len=80, context_dim=768 + ): + super().__init__() + self.fake_len = fake_len + self.H = image_H + self.W = image_W + self.image_key = image_key + self.txt_key = txt_key + assert image_key.endswith('encoded') == txt_key.endswith( + 'encoded' + ), 'In precached mode, first and second stage key must both end with "encoded"' + self.precached = self.image_key.endswith('encoded') + self.seq_len = seq_len + self.context_dim = context_dim + + def __getitem__(self, index): + item = {} + if self.precached: + item[self.image_key] = torch.randn(8, self.H // 8, self.W // 8) + item[self.txt_key] = torch.randn(self.seq_len, self.context_dim) + else: + item[self.image_key] = torch.randn(self.H, self.W, 3) + item[self.txt_key] = f'This is meaningless fake text No.{index}' + + return item + + def __len__(self): + return self.fake_len + + +def build_train_valid_datasets( + model_cfg, consumed_samples, +): + data_cfg = model_cfg.data + + def build_resolution_filter(value=None, method='larger'): + assert method == 'larger' or method == 'smaller' + if method == 'larger': + logging.info(f'Only Selecting images with resolution >= {value}') + return lambda x: x['jpg'].size[0] >= value and x['jpg'].size[1] >= value + logging.info(f'Only Selecting images with resolution <= {value}') + return lambda x: x['jpg'].size[0] <= value and x['jpg'].size[1] <= value + + # This function maps data that are tuples to dictionary. + def tuple_to_dict(inp): + for input in inp: + out_dict = dict() + out_dict[model_cfg.first_stage_key] = input[0].permute(1, 2, 0) + out_dict[model_cfg.cond_stage_key] = input[1] + yield out_dict + + def transform_fn(sample): + image, text = sample["jpg"], sample["txt"] + # TODO : If no agumentations just return the image ? + img_transform = construct_image_augmentations(data_cfg.train.get("augmentations", None)) + text_transform = identical_transform + return img_transform(image), text_transform(text) + + if data_cfg.get('synthetic_data', False): + H, W = data_cfg.train.augmentations.center_crop_h_w.split(',') + train_data = SDSyntheticDataset( + int(H), + int(W), + image_key=model_cfg.first_stage_key, + txt_key=model_cfg.cond_stage_key, + context_dim=model_cfg.unet_config.context_dim, + fake_len=data_cfg.synthetic_data_length, + ) + + else: + filter_cfg = data_cfg.train.get('filterings', None) + filter_fn = build_resolution_filter(**filter_cfg.resolution) if filter_cfg else None + train_data = WebDatasetCommon( + dataset_cfg=data_cfg, + consumed_samples=consumed_samples, + map_fn=transform_fn, + compose_fn=tuple_to_dict, + filter_fn=filter_fn, + is_train=True, + ) + + val_data = None + if data_cfg.get("validation") is not None and data_cfg.validation.get("data_path"): + if data_cfg.get('synthetic_data', False): + val_data = SDSyntheticDataset( + int(H), + int(W), + image_key=model_cfg.first_stage_key, + txt_key=model_cfg.cond_stage_key, + context_dim=model_cfg.unet_config.context_dim, + ) + else: + val_data = WebDatasetCommon( + dataset_cfg=data_cfg, + consumed_samples=consumed_samples, + map_fn=transform_fn, + compose_fn=tuple_to_dict, + filter_fn=filter_fn, + is_train=False, + ) + + return train_data, val_data + + +def build_train_valid_precached_datasets( + model_cfg, consumed_samples, +): + data_cfg = model_cfg.data + + # This function maps data that are tuples to dictionary. + def tuple_to_dict(inp): + for input in inp: + out_dict = dict() + out_dict[model_cfg.first_stage_key] = torch.tensor(input['autoencoderkl_image']) + out_dict[model_cfg.cond_stage_key] = torch.tensor(input['clip-vit-large-patch14_text']) + yield out_dict + + def transform_fn(sample): + return sample['pickle'] + + if data_cfg.get('synthetic_data', False): + H, W = data_cfg.train.augmentations.center_crop_h_w.split(',') + train_data = SDSyntheticDataset( + int(H), + int(W), + image_key=model_cfg.first_stage_key, + txt_key=model_cfg.cond_stage_key, + context_dim=model_cfg.unet_config.context_dim, + ) + else: + train_data = WebDatasetCommon( + dataset_cfg=data_cfg, + consumed_samples=consumed_samples, + map_fn=transform_fn, + compose_fn=tuple_to_dict, + is_train=True, + ) + + val_data = None + if data_cfg.get("validation") is not None and data_cfg.validation.get("data_path"): + if data_cfg.get('synthetic_data', False): + H, W = data_cfg.train.augmentations.center_crop_h_w.split(',') + train_data = SDSyntheticDataset( + int(H), + int(W), + image_key=model_cfg.first_stage_key, + txt_key=model_cfg.cond_stage_key, + context_dim=model_cfg.unet_config.context_dim, + ) + else: + val_data = WebDatasetCommon( + dataset_cfg=data_cfg, + consumed_samples=consumed_samples, + map_fn=transform_fn, + compose_fn=tuple_to_dict, + is_train=False, + ) + + return train_data, val_data diff --git a/nemo/collections/multimodal/models/nerf/__init__.py b/nemo/collections/multimodal/models/nerf/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/models/nerf/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/models/nerf/base.py b/nemo/collections/multimodal/models/nerf/base.py new file mode 100644 index 000000000000..e07d09b81c21 --- /dev/null +++ b/nemo/collections/multimodal/models/nerf/base.py @@ -0,0 +1,36 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.core.classes.common import Serialization +from nemo.core.classes.modelPT import ModelPT + + +class NerfModelBase(ModelPT, Serialization): + def __init__(self, cfg): + super().__init__(cfg=cfg) + self.save_hyperparameters() + self._cfg = cfg + + @staticmethod + def is_module_updatable(module): + return hasattr(module, 'update_step') and callable(module.update_step) + + def list_available_models(self): + pass + + def setup_training_data(self, config): + pass + + def setup_validation_data(self, config): + pass diff --git a/nemo/collections/multimodal/models/nerf/dreamfusion.py b/nemo/collections/multimodal/models/nerf/dreamfusion.py new file mode 100644 index 000000000000..27877a68abb2 --- /dev/null +++ b/nemo/collections/multimodal/models/nerf/dreamfusion.py @@ -0,0 +1,325 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import random + +import cv2 +import imageio +import numpy as np +import torch + +from nemo.collections.multimodal.models.nerf.txt2nerf_base import Txt2NerfBase +from nemo.collections.multimodal.modules.nerf.loss.laplacian_smooth_loss import LaplacianSmoothLoss +from nemo.collections.multimodal.modules.nerf.loss.normal_consistency_loss import NormalConsistencyLoss +from nemo.collections.multimodal.modules.nerf.materials.materials_base import ShadingEnum +from nemo.core import optim + + +# TODO(ahmadki): split dmtet from dreamfusion +class DreamFusion(Txt2NerfBase): + def __init__(self, cfg): + super(DreamFusion, self).__init__(cfg) + + self.guidance_scale = cfg.guidance_scale + + self.iters = cfg.iters + self.latent_iter_ratio = cfg.latent_iter_ratio + self.albedo_iter_ratio = cfg.albedo_iter_ratio + self.min_ambient_ratio = cfg.min_ambient_ratio + self.textureless_ratio = cfg.textureless_ratio + + # Lambdas + self.lambda_sds = cfg.loss.lambda_sds + self.lambda_opacity = cfg.loss.lambda_opacity + self.lambda_entropy = cfg.loss.lambda_entropy + self.lambda_orientation = cfg.loss.lambda_orientation + self.lambda_2d_normal_smooth = cfg.loss.lambda_2d_normal_smooth + self.lambda_3d_normal_smooth = cfg.loss.lambda_3d_normal_smooth + self.lambda_mesh_normal = cfg.loss.lambda_mesh_normal + self.lambda_mesh_laplacian = cfg.loss.lambda_mesh_laplacian + + if self.lambda_mesh_normal > 0: + self.normal_consistency_loss_fn = NormalConsistencyLoss() + if self.lambda_mesh_laplacian > 0: + self.laplacian_smooth_loss_fn = LaplacianSmoothLoss() + + # Video + self.test_images = [] + self.test_depths = [] + + def training_step(self, batch, batch_idx): + # experiment iterations ratio + # i.e. what proportion of this experiment have we completed (in terms of iterations) so far? + exp_iter_ratio = self.global_step / self.iters + + # TODO(ahmadki): move to database + if exp_iter_ratio < self.latent_iter_ratio: + ambient_ratio = 1.0 + shading_type = ShadingEnum.NORMAL + as_latent = True + else: + if exp_iter_ratio <= self.albedo_iter_ratio: + ambient_ratio = 1.0 + shading_type = None + else: + # random shading + ambient_ratio = self.min_ambient_ratio + (1.0 - self.min_ambient_ratio) * random.random() + rand = random.random() + if rand >= (1.0 - self.textureless_ratio): + shading_type = ShadingEnum.TEXTURELESS + else: + shading_type = ShadingEnum.LAMBERTIAN + + as_latent = False + + return_normal_image = bool(self.lambda_2d_normal_smooth) + return_normal_perturb = bool(self.lambda_3d_normal_smooth) + return_vertices = bool(self.lambda_mesh_laplacian) + return_faces = bool(self.lambda_mesh_normal) or bool(self.lambda_mesh_laplacian) + return_faces_normals = bool(self.lambda_mesh_normal) + outputs = self( + rays_o=batch['rays_o'], # [B, H, W, 3] + rays_d=batch['rays_d'], # [B, H, W, 3] + mvp=batch['mvp'], # [B, 4, 4] + perturb=True, + ambient_ratio=ambient_ratio, + shading_type=shading_type, + binarize=False, + return_normal_image=return_normal_image, + return_normal_perturb=return_normal_perturb, + return_vertices=return_vertices, + return_faces=return_faces, + return_faces_normals=return_faces_normals, + ) + + if as_latent: + pred_rgb = ( + torch.cat([outputs['image'], outputs['opacity']], dim=-1).permute(0, 3, 1, 2).contiguous() + ) # [B, 4, H, W] + else: + pred_rgb = outputs['image'].permute(0, 3, 1, 2).contiguous() # [B, 3, H, W] + + # TODO(ahmadki): move into guidance + azimuth = batch['azimuth'] + text_z = [self.text_z['uncond']] * azimuth.shape[0] + for b in range(azimuth.shape[0]): + if azimuth[b] >= -90 and azimuth[b] < 90: + if azimuth[b] >= 0: + r = 1 - azimuth[b] / 90 + else: + r = 1 + azimuth[b] / 90 + start_z = self.text_z['front'] + end_z = self.text_z['side'] + else: + if azimuth[b] >= 0: + r = 1 - (azimuth[b] - 90) / 90 + else: + r = 1 + (azimuth[b] + 90) / 90 + start_z = self.text_z['side'] + end_z = self.text_z['back'] + pos_z = r * start_z + (1 - r) * end_z + text_z.append(pos_z) + text_z = torch.cat(text_z, dim=0) + + loss_dict = {} + + # SDS loss + guidance_loss = self.guidance.train_step( + text_z, pred_rgb, as_latent=as_latent, guidance_scale=self.guidance_scale + ) + loss_dict['lambda_sds'] = guidance_loss * self.lambda_sds + + # opacity loss + if self.lambda_opacity > 0 and 'opacity' in outputs: + loss_opacity = (outputs['opacity'] ** 2).mean() + loss_dict['loss_opacity'] = self.lambda_opacity * loss_opacity + + # entropy loss + if self.lambda_entropy > 0 and 'weights' in outputs: + alphas = outputs['weights'].clamp(1e-5, 1 - 1e-5) + loss_entropy = (-alphas * torch.log2(alphas) - (1 - alphas) * torch.log2(1 - alphas)).mean() + lambda_entropy = self.lambda_entropy * min(1, 2 * self.global_step / self.iters) + loss_dict['loss_entropy'] = lambda_entropy * loss_entropy + + if self.lambda_2d_normal_smooth > 0 and 'normal_image' in outputs: + pred_normal = outputs['normal_image'] + loss_smooth = (pred_normal[:, 1:, :, :] - pred_normal[:, :-1, :, :]).square().mean() + ( + pred_normal[:, :, 1:, :] - pred_normal[:, :, :-1, :] + ).square().mean() + loss_dict['loss_smooth'] = self.lambda_2d_normal_smooth * loss_smooth + + # orientation loss + if self.lambda_orientation > 0 and all(key in outputs for key in ['weights', 'normals', 'dirs']): + loss_orientation = ( + outputs['weights'].detach() * (outputs['normals'] * outputs['dirs']).sum(-1).clamp(min=0) ** 2 + ) + loss_orientation = loss_orientation.mean() + loss_dict['loss_orientation'] = self.lambda_orientation * loss_orientation + + if self.lambda_3d_normal_smooth > 0 and all(key in outputs for key in ['normals', 'normal_perturb']): + loss_normal_perturb = (outputs['normal_perturb'] - outputs['normals']).abs().mean() + loss_dict['loss_normal_smooth'] = self.lambda_3d_normal_smooth * loss_normal_perturb + + if self.lambda_mesh_normal > 0 and all(key in outputs for key in ['face_normals', 'faces']): + normal_consistency_loss = self.normal_consistency_loss_fn( + face_normals=outputs['face_normals'], t_pos_idx=outputs['faces'] + ) + loss_dict['normal_consistency_loss'] = self.lambda_mesh_normal * normal_consistency_loss + + if self.lambda_mesh_laplacian > 0 and all(key in outputs for key in ['verts', 'faces']): + laplacian_loss = self.laplacian_smooth_loss_fn(verts=outputs['verts'], faces=outputs['faces']) + loss_dict['laplacian_loss'] = self.lambda_mesh_laplacian * laplacian_loss + + loss = sum(loss_dict.values()) + + self.log_dict(loss_dict, prog_bar=False, rank_zero_only=True) + self.log('loss', loss, prog_bar=True, rank_zero_only=True) + + # TODO(ahmadki): LearningRateMonitor + lr = self._optimizer.param_groups[0]['lr'] + self.log('lr', lr, prog_bar=True, rank_zero_only=True) + + self.log('global_step', self.global_step + 1, prog_bar=True, rank_zero_only=True) + + return loss + + def validation_step(self, batch, batch_idx): + # save image + images, depths = self._shared_predict(batch) + + save_path = os.path.join(self.trainer.log_dir, 'validation') + os.makedirs(save_path, exist_ok=True) + for i, (image, depth) in enumerate(zip(images, depths)): + # Save image + cv2.imwrite( + os.path.join( + save_path, + f'{self.current_epoch:04d}_{self.global_step:04d}_{self.global_rank:04d}_{batch_idx:04d}_{i:04d}_rgb.png', + ), + cv2.cvtColor(image, cv2.COLOR_RGB2BGR), + ) + # Save depth + cv2.imwrite( + os.path.join( + save_path, + f'{self.current_epoch:04d}_{self.global_step:04d}_{self.global_rank:04d}_{batch_idx:04d}_{i:04d}_depth.png', + ), + depth, + ) + + def test_step(self, batch, batch_idx): + # save image + images, depths = self._shared_predict(batch) + self.test_images.append(images) + self.test_depths.append(depths) + + def on_test_epoch_end(self): + save_path = os.path.join(self.trainer.log_dir, 'test') + os.makedirs(save_path, exist_ok=True) + + images = np.concatenate(self.test_images, axis=0) + imageio.mimwrite( + os.path.join(os.path.join(save_path, f'{self.current_epoch:04d}_{self.global_step:04d}_rgb.mp4')), + images, + fps=25, + quality=8, + macro_block_size=1, + ) + + depths = np.concatenate(self.test_depths, axis=0) + imageio.mimwrite( + os.path.join(os.path.join(save_path, f'{self.current_epoch:04d}_{self.global_step:04d}_depth.mp4')), + depths, + fps=25, + quality=8, + macro_block_size=1, + ) + + self.test_images.clear() + self.test_depths.clear() + + def predict_step(self, batch, batch_idx): + return self._shared_predict(self, batch) + + def forward( + self, + rays_o, + rays_d, + mvp, + perturb, + ambient_ratio, + shading_type, + binarize, + return_normal_image, + return_normal_perturb, + return_vertices, + return_faces, + return_faces_normals, + ): + outputs = self.renderer( + rays_o=rays_o, + rays_d=rays_d, + mvp=mvp, + perturb=perturb, + ambient_ratio=ambient_ratio, + shading_type=shading_type, + binarize=binarize, + return_normal_image=return_normal_image, + return_normal_perturb=return_normal_perturb, + return_vertices=return_vertices, + return_faces=return_faces, + return_faces_normals=return_faces_normals, + ) + return outputs + + def _shared_predict(self, data): + outputs = self( + rays_o=data['rays_o'], # [B, H, W, 3] + rays_d=data['rays_d'], # [B, H, W, 3] + mvp=data['mvp'], + perturb=False, + ambient_ratio=data['ambient_ratio'] if 'ambient_ratio' in data else 1.0, # TODO(ahmadki): move to dataset + shading_type=data['shading_type'] if 'shading_type' in data else None, # TODO(ahmadki): move to dataset + binarize=False, + return_normal_image=False, + return_normal_perturb=False, + return_vertices=False, + return_faces=False, + return_faces_normals=False, + ) + + images_np = outputs['image'].detach().cpu().numpy() + images_np = (images_np * 255).astype(np.uint8) + + depths_np = outputs['depth'].detach().cpu().numpy() + depths_np = (depths_np - depths_np.min()) / (np.ptp(depths_np) + 1e-6) + depths_np = (depths_np * 255).astype(np.uint8) + + return images_np, depths_np + + # TODO(ahmadki): rework + def setup_optimization(self): + cfg = self._cfg.optim + optimizer_args = dict(cfg) + optimizer_args.pop('name', None) + + optimizer = optim.get_optimizer(cfg.name) + + optimizer = optimizer(params=self.parameters(), **optimizer_args) + + self._optimizer = optimizer + + def configure_optimizers(self): + self.setup_optimization() + return self._optimizer diff --git a/nemo/collections/multimodal/models/nerf/txt2nerf_base.py b/nemo/collections/multimodal/models/nerf/txt2nerf_base.py new file mode 100644 index 000000000000..dbd6601da138 --- /dev/null +++ b/nemo/collections/multimodal/models/nerf/txt2nerf_base.py @@ -0,0 +1,93 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.multimodal.models.nerf.base import NerfModelBase + + +class Txt2NerfBase(NerfModelBase): + def __init__(self, cfg): + super().__init__(cfg) + self.prompt = cfg.prompt + self.negative_prompt = cfg.negative_prompt + self.front_prompt = cfg.front_prompt + self.side_prompt = cfg.side_prompt + self.back_prompt = cfg.back_prompt + + self.nerf_cfg = cfg.nerf + self.renderer_cfg = cfg.renderer + self.guidance_cfg = cfg.guidance + + nerf = self.from_config_dict(cfg.nerf) + material = self.from_config_dict(cfg.material) + background = self.from_config_dict(cfg.background) + self.renderer = self.build_renderer(cfg.renderer, nerf, material, background) + self.guidance = None + + def build_renderer(self, cfg, nerf, material, background): + renderer = self.from_config_dict(cfg) + renderer.nerf = nerf + renderer.material = material + renderer.background = background + return renderer + + def build_guidance(self, cfg): + self.guidance = self.from_config_dict(cfg) + self.guidance.eval() + for p in self.guidance.parameters(): + p.requires_grad = False + + def prepare_embeddings(self): + # TODO(ahmadki): add top view ? + self.text_z = { + "default": self.guidance.get_text_embeds([self.prompt]), + "uncond": self.guidance.get_text_embeds([self.negative_prompt]), + "front": self.guidance.get_text_embeds([f"{self.prompt}{self.front_prompt}"]), + "side": self.guidance.get_text_embeds([f"{self.prompt}{self.side_prompt}"]), + "back": self.guidance.get_text_embeds([f"{self.prompt}{self.back_prompt}"]), + } + + def on_fit_start(self) -> None: + self.build_guidance(self.guidance_cfg) + self.prepare_embeddings() + + def on_train_batch_start(self, batch, batch_idx, unused=0): + if self.is_module_updatable(self.guidance): + self.guidance.update_step(epoch=self.current_epoch, global_step=self.global_step) + + if self.is_module_updatable(self.renderer.nerf): + self.renderer.nerf.update_step(epoch=self.current_epoch, global_step=self.global_step) + + if self.is_module_updatable(self.renderer.material): + self.renderer.material.update_step(epoch=self.current_epoch, global_step=self.global_step) + + if self.is_module_updatable(self.renderer.background): + self.renderer.background.update_step(epoch=self.current_epoch, global_step=self.global_step) + + if self.is_module_updatable(self.renderer): + self.renderer.update_step(epoch=self.current_epoch, global_step=self.global_step) + + dataset = self.trainer.train_dataloader.dataset + if self.is_module_updatable(dataset): + dataset.update_step(epoch=self.current_epoch, global_step=self.global_step) + + def mesh(self, resolution, batch_size=128, density_thresh=None): + return self.nerf.mesh(resolution=resolution, batch_size=batch_size, density_thresh=density_thresh) + + def on_save_checkpoint(self, checkpoint): + # remove guidance from checkpoint. + # We can still laod the model without guidance checkpoints because the module is not initalized + # at __init__ time. + keys_to_remove = [key for key in checkpoint['state_dict'].keys() if key.startswith('guidance.')] + for key in keys_to_remove: + del checkpoint['state_dict'][key] diff --git a/nemo/collections/multimodal/models/text_to_image/__init__.py b/nemo/collections/multimodal/models/text_to_image/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/models/text_to_image/controlnet/__init__.py b/nemo/collections/multimodal/models/text_to_image/controlnet/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/controlnet/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py b/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py new file mode 100644 index 000000000000..36329c3b7d0f --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py @@ -0,0 +1,1023 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import einops +import torch +import torch.nn as nn +from einops import rearrange, repeat +from omegaconf import DictConfig +from pytorch_lightning import Trainer +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from torch._inductor import config as inductor_config + +from nemo.collections.multimodal.data.controlnet.controlnet_dataset import build_train_valid_datasets +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import LatentDiffusion +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.samplers.ddim import DDIMSampler +from nemo.collections.multimodal.modules.stable_diffusion.attention import SpatialTransformer +from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel import ( + AttentionBlock, + Downsample, + ResBlock, + TimestepEmbedSequential, + UNetModel, +) +from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import ( + conv_nd, + linear, + timestep_embedding, + zero_module, +) +from nemo.collections.multimodal.parts.stable_diffusion.utils import exists, log_txt_as_img +from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel +from nemo.collections.nlp.modules.common.megatron.module import Float16Module +from nemo.utils import logging + +try: + from apex import amp + from apex.transformer.enums import AttnMaskType + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + + HAVE_APEX = True +except (ImportError, ModuleNotFoundError): + HAVE_APEX = False + +try: + from megatron.core import parallel_state + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + +try: + from torchvision.utils import make_grid + + TORCHVISION_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + TORCHVISION_AVAILABLE = False + + +class ControlledUnetModel(UNetModel): + ''' + Modified Unet class that combines the output of controlling copy and frozen copy during forward pass. + ''' + + def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs): + ''' + :param x: latents of diffusion process + :param timesteps: diffusion step + :param context: text embedding guiding the denoising process + :param control: output from controlling copy of each corresponding layer + :param only_mid_control: whether to add the output of controlling copy from middle block only + ''' + hs = [] + with torch.no_grad(): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + h = x.type(emb.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + + if control is not None: + h += control.pop() + + for i, module in enumerate(self.output_blocks): + if only_mid_control or control is None: + h = torch.cat([h, hs.pop()], dim=1) + else: + h = torch.cat([h, hs.pop() + control.pop()], dim=1) + h = module(h, emb, context) + + h = h.type(x.dtype) + return self.out(h) + + +class ControlLDM(LatentDiffusion): + def __init__(self, cfg, model_parallel_config): + super().__init__(cfg=cfg, model_parallel_config=model_parallel_config) + self.control_model = ControlLDM.from_config_dict(cfg.control_stage_config) + self.control_key = cfg.control_key + self.only_mid_control = cfg.only_mid_control + self.control_scales = [1.0] * 13 + self.sd_locked = cfg.sd_locked + self.channels_last = cfg.channels_last + + if cfg.get("inductor", False): + # TorchInductor with CUDA graph can lead to OOM + inductor_config.triton.cudagraphs = cfg.get("inductor_cudagraphs", False) + torch._dynamo.config.dynamic_shapes = False + torch._dynamo.config.automatic_dynamic_shapes = False + self.control_model = torch.compile(self.control_model) + + if self.channels_last: + self.control_model = self.control_model.to(memory_format=torch.channels_last) + + @torch.no_grad() + def get_input(self, batch, k, bs=None, *args, **kwargs): + x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs) + control = batch[self.control_key] + if bs is not None: + control = control[:bs] + control = control.to(torch.cuda.current_device()) + if self.channels_last: + control = control.permute(0, 3, 1, 2).to(non_blocking=True) + else: + control = einops.rearrange(control, 'b h w c -> b c h w') + control = control.to(memory_format=torch.contiguous_format).float() + return x, dict(c_crossattn=c, c_concat=control) + + def apply_model(self, x_noisy, t, cond, *args, **kwargs): + assert isinstance(cond, dict) + diffusion_model = self.model.diffusion_model + + # cond_txt = torch.cat(cond['c_crossattn'], 1) ## Has removed this first dim in the get_input function, same for below hint input + cond_txt = cond['c_crossattn'] + + if cond['c_concat'] is None: + eps = diffusion_model( + x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control + ) + else: + control = self.control_model(x=x_noisy, hint=cond['c_concat'], timesteps=t, context=cond_txt) + control = [c * scale for c, scale in zip(control, self.control_scales)] + eps = diffusion_model( + x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control + ) + return eps + + @torch.no_grad() + def get_unconditional_conditioning(self, N): + return self.get_learned_conditioning([""] * N) + + @torch.no_grad() + def log_images( + self, + batch, + N=4, + n_row=2, + sample=False, + ddim_steps=50, + ddim_eta=0.0, + return_keys=None, + quantize_denoised=True, + inpaint=True, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=False, + unconditional_guidance_scale=9.0, + unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs, + ): + use_ddim = ddim_steps is not None + + log = dict() + batch = next(batch) + batch['images'] = batch['images'].to(torch.cuda.current_device()) + batch['hint'] = batch['hint'].to(torch.cuda.current_device()) + N = batch['images'].shape[0] + z, c = self.get_input(batch, self.first_stage_key, bs=N) + c_cat, c = c["c_concat"][:N], c["c_crossattn"][:N] + N = min(z.shape[0], N) + n_row = min(z.shape[0], n_row) + log["reconstruction"] = self.decode_first_stage(z) + log["control"] = c_cat * 2.0 - 1.0 + log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + assert TORCHVISION_AVAILABLE, "Torchvision imports failed but they are required." + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + samples, z_denoise_row = self.sample_log( + cond={"c_concat": c_cat, "c_crossattn": c}, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + ) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_cross = self.get_unconditional_conditioning(N) + uc_cat = c_cat # torch.zeros_like(c_cat) + uc_full = {"c_concat": uc_cat, "c_crossattn": uc_cross} + samples_cfg, _ = self.sample_log( + cond={"c_concat": c_cat, "c_crossattn": c}, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_full, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + return log + + @torch.no_grad() + def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): + ddim_sampler = DDIMSampler(self) + c, h, w = cond["c_concat"][0].shape + shape = (self.channels, h // 8, w // 8) + samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs) + return samples, intermediates + + def parameters(self): + params = list(self.control_model.parameters()) + if not self.sd_locked: + params += list(self.model.diffusion_model.output_blocks.parameters()) + params += list(self.model.diffusion_model.out.parameters()) + return params + + def low_vram_shift(self, is_diffusing): + if is_diffusing: + self.model = self.model.cuda() + self.control_model = self.control_model.cuda() + self.first_stage_model = self.first_stage_model.cpu() + self.cond_stage_model = self.cond_stage_model.cpu() + else: + self.model = self.model.cpu() + self.control_model = self.control_model.cpu() + self.first_stage_model = self.first_stage_model.cuda() + self.cond_stage_model = self.cond_stage_model.cuda() + + +class ControlNet(nn.Module): + def __init__( + self, + image_size, + in_channels, + model_channels, + hint_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, ###TODO MMY these are new + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + use_flash_attention=False, + from_pretrained_unet=None, + from_NeMo=True, + ): + super().__init__() + if use_spatial_transformer: + assert ( + context_dim is not None + ), 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert ( + use_spatial_transformer + ), 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.dims = dims + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all( + map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))) + ) + print( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." + ) + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] + ) + self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)]) + + self.input_hint_block = TimestepEmbedSequential( + conv_nd(dims, hint_channels, 16, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 16, 16, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 16, 32, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 32, 32, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 32, 96, 3, padding=1, stride=2), + nn.SiLU(), + conv_nd(dims, 96, 96, 3, padding=1), + nn.SiLU(), + conv_nd(dims, 96, 256, 3, padding=1, stride=2), + nn.SiLU(), + zero_module(conv_nd(dims, 256, model_channels, 3, padding=1)), + ) + + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + use_flash_attention=use_flash_attention, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self.zero_convs.append(self.make_zero_conv(ch)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ) + ch = out_ch + input_block_chans.append(ch) + self.zero_convs.append(self.make_zero_conv(ch)) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + use_flash_attention=use_flash_attention, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self.middle_block_out = self.make_zero_conv(ch) + self._feature_size += ch + + if from_pretrained_unet is not None: + self.load_from_unet(from_pretrained_unet=from_pretrained_unet, from_NeMo=from_NeMo) + + def load_from_unet(self, from_pretrained_unet, from_NeMo=True): + if not from_NeMo: + print('loading from other source of unet is experimental! Carefully check if keys are loaded correctly.') + else: + print("Loading unet blocks from sd") + + state_dict = torch.load(from_pretrained_unet, map_location='cpu') + state_dict = state_dict['state_dict'] + model_state_dict = self.state_dict() + + re_state_dict = {} + for key_, value_ in state_dict.items(): + if key_.startswith('model.model.diffusion_model'): + re_state_dict[key_.replace('model.model.diffusion_model.', '')] = value_ + if key_.startswith('model.diffusion_model'): + re_state_dict[key_.replace('model.diffusion_model.', '')] = value_ + if key_.startswith('model.model._orig_mod.diffusion_model'): + re_state_dict[key_.replace('model.model._orig_mod.diffusion_model.', '')] = value_ + if key_.startswith('model._orig_mod.diffusion_model'): + re_state_dict[key_.replace('model._orig_mod.diffusion_model.', '')] = value_ + + expected_keys = list(model_state_dict.keys()) + loaded_keys = list(re_state_dict.keys()) + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + if ( + 'input_blocks.1.0.in_layers.2.weight' in loaded_keys + and 'input_blocks.1.0.in_layers.1.weight' in expected_keys + ): + # GroupNormOpt fuses activation function to one layer, thus the indexing of weights are shifted for following + for key_ in missing_keys: + if key_.startswith('input_blocks') or key_.startswith('middle_block.'): + s = key_.split('.') + idx = int(s[-2]) + new_key_ = ".".join(s[:-2] + [str(int(idx + 1))] + [s[-1]]) + re_state_dict[key_] = re_state_dict[new_key_] + + loaded_keys = list(re_state_dict.keys()) + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + self.load_state_dict(re_state_dict, strict=False) + + if len(missing_keys) > 42: + print( + 'warning: only input hint blocks and zero conv layers are randomly initialized. This message indicates some unet blocks are not loaded correctly.' + ) + print(f'There is {len(missing_keys)} total missing keys') + print("Missing:", missing_keys) + print("Unexpected:", unexpected_keys) + else: + print("sd blocks loaded successfully") + + def make_zero_conv(self, channels): + return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))) + + def forward(self, x, hint, timesteps, context, **kwargs): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + guided_hint = self.input_hint_block(hint, emb, context) + outs = [] + + h = x.type(self.dtype) + for module, zero_conv in zip(self.input_blocks, self.zero_convs): + if guided_hint is not None: + h = module(h, emb, context) + h += guided_hint + guided_hint = None + else: + h = module(h, emb, context) + outs.append(zero_conv(h, emb, context)) + + h = self.middle_block(h, emb, context) + outs.append(self.middle_block_out(h, emb, context)) + + return outs + + +class MegatronControlNet(MegatronBaseModel): + def __init__(self, cfg: DictConfig, trainer: Trainer): + if not HAVE_APEX: + raise ImportError( + "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + ) + + if not HAVE_MEGATRON_CORE: + raise ImportError( + "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + ) + + # this prevents base constructor from initializing tokenizer + self.tokenizer = None + super().__init__(cfg, trainer=trainer) + + self._validate_trainer() + + # megatron_amp_O2 is not yet supported in diffusion models + self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False) + + self.model = self.model_provider_func() + + self.conditioning_keys = [] + + if self.trainer.precision in ['bf16', 'bf16-mixed']: + self.autocast_dtype = torch.bfloat16 + elif self.trainer.precision in [32, '32', '32-true']: + self.autocast_dtype = torch.float + elif self.trainer.precision in [16, '16', '16-mixed']: + self.autocast_dtype = torch.half + else: + raise ValueError('precision must be in ["32-true", "16-mixed", "bf16-mixed"]') + + def get_module_list(self): + if isinstance(self.model, list): + return [model.module if isinstance(model, Float16Module) else model for model in self.model] + elif isinstance(self.model, Float16Module): + return [self.model.module] + else: + return [self.model] + + def model_provider_func(self, pre_process=True, post_process=True): + """Model depends on pipeline paralellism.""" + model = ControlLDM(cfg=self.cfg, model_parallel_config=self.model_parallel_config) + return model + + def forward(self, x, c, *args, **kwargs): + output_tensor = self.model(x, c, *args, **kwargs) + return output_tensor + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0): + if self.cfg.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0: + assert self.cfg.scale_factor == 1.0, 'rather not use custom rescaling and std-rescaling simultaneously' + batch[self.cfg.first_stage_key] = batch[self.cfg.first_stage_key].cuda(non_blocking=True) + self.model.on_train_batch_start(batch, batch_idx) + + def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): + tensor_shape = None # Placeholder + + # handle asynchronous grad reduction + no_sync_func = None + if not forward_only and self.with_distributed_adam: + no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + + # pipeline schedules will get these from self.model.config + for module in self.get_module_list(): + module.config.no_sync_func = no_sync_func + + # run forward and backwards passes for an entire global batch + # we do this inside training_step to support pipeline parallelism + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(), + data_iterator=dataloader_iter, + model=self.model, + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=None, + micro_batch_size=self.cfg.micro_batch_size, + ) + + # losses_reduced_per_micro_batch is a list of dictionaries + # [{"loss": 0.1}, {"loss": 0.2}, ...] which are from gradient accumulation steps + # only the last stages of the pipeline return losses + loss_dict = {} + if losses_reduced_per_micro_batch: + if (not forward_only) or self.cfg.data.get('validation_drop_last', True): + # average loss across micro batches + for key in losses_reduced_per_micro_batch[0]: + loss_tensors_list = [loss_reduced[key] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensor = torch.stack(loss_tensors_list) + loss_dict[key] = loss_tensor.mean() + loss_mean = loss_dict["train/loss"] + else: + raise NotImplementedError("Losses of micro batches sizes must be uniform!") + else: + if forward_only: + loss_mean = [] + else: + loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + + return loss_mean, loss_dict + + def training_step(self, dataloader_iter, batch_idx): + """ + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + Batch should be a list of microbatches and those microbatches should on CPU. + Microbatches are then moved to GPU during the pipeline. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + """ + # we zero grads here because we also call backward in the apex fwd/bwd functions + self._optimizer.zero_grad() + + loss_mean, loss_dict = self.fwd_bwd_step(dataloader_iter, batch_idx, False) + + if self.cfg.get('tensor_model_parallel_size', 1) > 1 and self.cfg.get('sequence_parallel', False): + self.allreduce_sequence_parallel_gradients() + + if self.with_distributed_adam: + # gradients are reduced internally in distributed optimizer + pass + elif self.megatron_amp_O2: + # # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously) + # if self.cfg.get('pipeline_model_parallel_size', 1) > 1 or self.cfg.get('sequence_parallel', False): + # # main grads are stored in the MainParamsOptimizer wrapper + # self._optimizer.allreduce_main_grads() + self._optimizer.allreduce_main_grads() + else: + # async grad allreduce is not currently implemented for O1/autocasting mixed precision training + # so we all-reduce gradients after the pipeline + self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf) + + if self.cfg.precision == [16, '16', '16-mixed']: + loss_scale = self.trainer.precision_plugin.scaler._scale + if loss_scale is not None: + self.log('loss_scale', loss_scale, batch_size=1) + + self.log_dict(loss_dict, prog_bar=False, logger=True, on_step=True, rank_zero_only=True, batch_size=1) + self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1) + lr = self._optimizer.param_groups[0]['lr'] + self.log('lr', lr, prog_bar=True, rank_zero_only=True, batch_size=1) + self.log('global_step', self.trainer.global_step + 1, prog_bar=True, rank_zero_only=True, batch_size=1) + self.log( + 'consumed_samples', + self.compute_consumed_samples(self.trainer.global_step + 1 - self.init_global_step), + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + return loss_mean + + def backward(self, *args, **kwargs): + """ LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from apex. + No need to call it here. + """ + pass + + def optimizer_zero_grad(self, *args, **kwargs): + """ LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. + """ + pass + + def _append_sequence_parallel_module_grads(self, module, grads): + """ Helper method for allreduce_sequence_parallel_gradients""" + + for param in module.parameters(): + sequence_parallel_param = getattr(param, 'sequence_parallel', False) + if sequence_parallel_param and param.requires_grad: + if self.megatron_amp_O2: + grad = param.main_grad + else: + grad = param.grad + grads.append(grad.data) + + def get_forward_output_and_loss_func(self): + def process_batch(batch): + """ Prepares the global batch for apex fwd/bwd functions. + Global batch is a list of micro batches. + """ + # noise_map, condition + batch[self.cfg.first_stage_key] = batch[self.cfg.first_stage_key].cuda(non_blocking=True) + if isinstance(batch[self.cfg.cond_stage_key], torch.Tensor): + # in the case of precached text embeddings, cond_stage is also a tensor + batch[self.cfg.cond_stage_key] = batch[self.cfg.cond_stage_key].cuda(non_blocking=True) + + # SD has more dedicated structure for encoding, so we enable autocasting here as well + with torch.cuda.amp.autocast( + self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype, + ): + x, c = self.model.get_input(batch, self.cfg.first_stage_key) + + if not isinstance(c, dict): + return [x, c] + + if len(self.conditioning_keys) == 0: + self.conditioning_keys = list(c.keys()) + c_list = [c[key] for key in self.conditioning_keys] + return [x, *c_list] + + def fwd_output_and_loss_func(dataloader_iter, model): + batch = next(dataloader_iter) + batch = process_batch(batch) + batch = [x.cuda(non_blocking=True) for x in batch] + if len(self.conditioning_keys) == 0: + x, c = batch + else: + x = batch[0] + c = {} + for idx, key in enumerate(self.conditioning_keys): + c[key] = batch[1 + idx] + loss, loss_dict = model(x, c) + + def dummy(output_tensor): + return loss, loss_dict + + # output_tensor, and a function to convert output_tensor to loss + loss_dict + return loss, dummy + + return fwd_output_and_loss_func + + def get_forward_output_only_func(self): + def fwd_output_only_func(batch, model): + raise NotImplementedError + + return fwd_output_only_func + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + tensor_shape = None # Placeholder + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(), + data_iterator=dataloader_iter, + model=[self.model], + num_microbatches=get_num_microbatches(), + forward_only=True, + tensor_shape=None, # required by pipeline parallelism + dtype=self.autocast_dtype, + sequence_parallel=self.cfg.get('sequence_parallel', False), + enable_autocast=True, + ) + # only the last stages of the pipeline return losses + val_loss_dict = {} + if losses_reduced_per_micro_batch: + # average loss across micro batches + for key in losses_reduced_per_micro_batch[0]: + loss_tensors_list = [loss_reduced[key] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensor = torch.stack(loss_tensors_list) + val_loss_dict[key] = loss_tensor.mean() + + self.log_dict(val_loss_dict, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def setup(self, stage=None): + """ PTL hook that is executed after DDP spawns. + We setup datasets here as megatron datasets require DDP to instantiate. + See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. + Args: + stage (str, optional): Can be 'fit', 'validate', 'test' or 'predict'. Defaults to None. + """ + self.model.rng.manual_seed(self.cfg.seed + 100 * parallel_state.get_data_parallel_rank()) + + # log number of parameters + if isinstance(self.model, list): + num_parameters_on_device = sum( + [sum([p.nelement() for p in model_module.parameters()]) for model_module in self.model] + ) + else: + num_parameters_on_device = sum([p.nelement() for p in self.model.parameters()]) + + # to be summed across data parallel group + total_num_parameters = torch.tensor(num_parameters_on_device).cuda(non_blocking=True) + + torch.distributed.all_reduce(total_num_parameters, group=parallel_state.get_model_parallel_group()) + + logging.info( + f'Pipeline model parallel rank: {parallel_state.get_pipeline_model_parallel_rank()}, ' + f'Tensor model parallel rank: {parallel_state.get_tensor_model_parallel_rank()}, ' + f'Number of model parameters on device: {num_parameters_on_device:.2e}. ' + f'Total number of model parameters: {total_num_parameters:.2e}.' + ) + + resume_checkpoint_path = self.trainer.ckpt_path + if resume_checkpoint_path: + init_consumed_samples = self._extract_consumed_samples_from_ckpt(resume_checkpoint_path) + else: + init_consumed_samples = 0 + self.init_consumed_samples = init_consumed_samples + self.init_global_step = self.trainer.global_step + + # allowing restored models to optionally setup datasets + self.build_train_valid_test_datasets() + + # Batch size need to be provided for webdatset + self._num_micro_batches = get_num_microbatches() + self._micro_batch_size = self.cfg.micro_batch_size + + self.setup_training_data(self.cfg.data) + self.setup_validation_data(self.cfg.data) + self.setup_test_data(self.cfg.data) + + def build_train_valid_test_datasets(self): + logging.info('Building datasets for Stable Diffusion...') + if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float): + raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.") + + if self.cfg.first_stage_key.endswith("encoded"): + self._train_ds, self._validation_ds = build_train_valid_precached_datasets( + model_cfg=self.cfg, consumed_samples=self.compute_consumed_samples(0), + ) + else: + self._train_ds, self._validation_ds = build_train_valid_datasets( + model_cfg=self.cfg, consumed_samples=self.compute_consumed_samples(0) + ) + self._test_ds = None + + if self._train_ds is not None: + logging.info(f'Length of train dataset: {len(self._train_ds)}') + if self._validation_ds is not None: + logging.info(f'Length of val dataset: {len(self._validation_ds)}') + if self._test_ds is not None: + logging.info(f'Length of test dataset: {len(self._test_ds)}') + logging.info(f'Finished building datasets for LatentDiffusion.') + return self._train_ds, self._validation_ds, self._test_ds + + def setup_training_data(self, cfg): + if hasattr(self, '_train_ds') and self._train_ds is not None: + consumed_samples = self.compute_consumed_samples(0) + logging.info( + f'Setting up train dataloader with len(len(self._train_ds)): {len(self._train_ds)} and consumed samples: {consumed_samples}' + ) + self._train_dl = torch.utils.data.DataLoader( + self._train_ds, + batch_size=self._micro_batch_size, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=True, + persistent_workers=True, + ) + + def setup_validation_data(self, cfg): + if hasattr(self, '_validation_ds') and self._validation_ds is not None: + consumed_samples = 0 + logging.info( + f'Setting up validation dataloader with len(len(self._validation_ds)): {len(self._validation_ds)} and consumed samples: {consumed_samples}' + ) + self._validation_dl = torch.utils.data.DataLoader( + self._validation_ds, + batch_size=self._micro_batch_size, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=True, + ) + + def setup_test_data(self, cfg): + if hasattr(self, '_test_ds') and self._test_ds is not None: + consumed_samples = 0 + logging.info( + f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}' + ) + self._test_dl = torch.utils.data.DataLoader( + self._test_ds, batch_size=self._micro_batch_size, num_workers=cfg.num_workers, pin_memory=True, + ) + + def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: + """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + When using pipeline parallelism, we need the global batch to remain on the CPU, + since the memory overhead will be too high when using a large number of microbatches. + Microbatches are transferred from CPU to GPU inside the pipeline. + """ + return batch + + def _validate_trainer(self): + """ Certain trainer configurations can break training. + Here we try to catch them and raise an error. + """ + if self.trainer.accumulate_grad_batches > 1: + raise ValueError( + f'Gradient accumulation is done within training_step. trainer.accumulate_grad_batches must equal 1' + ) + + @classmethod + def list_available_models(cls): + return None + + def log_images(self, *args, **kwargs): + return self.model.log_images(*args, **kwargs) + + def parameters(self): + if isinstance(self.model, list): + return itertools.chain.from_iterable(module.parameters() for module in self.model) + else: + return self.model.parameters() diff --git a/nemo/collections/multimodal/models/text_to_image/controlnet/util.py b/nemo/collections/multimodal/models/text_to_image/controlnet/util.py new file mode 100644 index 000000000000..3d9a7d16b1c3 --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/controlnet/util.py @@ -0,0 +1,102 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +import torch +import torchvision +from PIL import Image +from pytorch_lightning import Callback +from pytorch_lightning.utilities.rank_zero import rank_zero_only + + +class ImageLogger(Callback): + def __init__( + self, + batch_frequency=2000, + max_images=4, + clamp=True, + increase_log_steps=True, + rescale=True, + disabled=False, + log_on_batch_idx=False, + log_first_step=False, + log_images_kwargs=None, + ): + super().__init__() + self.rescale = rescale + self.batch_freq = batch_frequency + self.max_images = max_images + if not increase_log_steps: + self.log_steps = [self.batch_freq] + self.clamp = clamp + self.disabled = disabled + self.log_on_batch_idx = log_on_batch_idx + self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} + self.log_first_step = log_first_step + + @rank_zero_only + def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): + root = os.path.join(save_dir, "image_log", split) + for k in images: + grid = torchvision.utils.make_grid(images[k], nrow=4) + if self.rescale: + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) + grid = grid.numpy() + grid = (grid * 255).astype(np.uint8) + filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx) + path = os.path.join(root, filename) + os.makedirs(os.path.split(path)[0], exist_ok=True) + Image.fromarray(grid).save(path) + + def log_img(self, pl_module, batch, batch_idx, split="train"): + check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step + if ( + self.check_frequency(check_idx) + and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0 + and callable(pl_module.log_images) + and self.max_images > 0 + ): + logger = type(pl_module.logger) + + is_train = pl_module.training + if is_train: + pl_module.eval() + + with torch.no_grad(): + images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) + + for k in images: + N = min(images[k].shape[0], self.max_images) + images[k] = images[k][:N] + if isinstance(images[k], torch.Tensor): + images[k] = images[k].detach().cpu() + if self.clamp: + images[k] = torch.clamp(images[k], -1.0, 1.0) + + self.log_local( + pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, batch_idx + ) + + if is_train: + pl_module.train() + + def check_frequency(self, check_idx): + return check_idx % self.batch_freq == 0 + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not self.disabled: + self.log_img(pl_module, batch, batch_idx, split="train") diff --git a/nemo/collections/multimodal/models/text_to_image/dreambooth/__init__.py b/nemo/collections/multimodal/models/text_to_image/dreambooth/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/dreambooth/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/models/text_to_image/dreambooth/dreambooth.py b/nemo/collections/multimodal/models/text_to_image/dreambooth/dreambooth.py new file mode 100644 index 000000000000..492347f08524 --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/dreambooth/dreambooth.py @@ -0,0 +1,639 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +from typing import Any, Optional + +import torch +from omegaconf import DictConfig +from pytorch_lightning import Trainer +from torch._inductor import config as inductor_config + +from nemo.collections.multimodal.data.dreambooth.dreambooth_dataset import DreamBoothDataset +from nemo.collections.multimodal.modules.stable_diffusion.distributions.distributions import ( + DiagonalGaussianDistribution, +) +from nemo.collections.multimodal.parts.utils import randn_like +from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import MegatronPretrainingRandomSampler +from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel +from nemo.collections.nlp.modules.common.megatron.module import Float16Module +from nemo.collections.nlp.parts.utils_funcs import get_last_rank +from nemo.core.classes.common import Serialization +from nemo.utils import logging + +try: + from apex import amp + from apex.transformer.enums import AttnMaskType + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + + HAVE_APEX = True +except (ImportError, ModuleNotFoundError): + HAVE_APEX = False + +try: + from megatron.core import parallel_state + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def _collate_fn(examples, with_prior_preservation=False): + if with_prior_preservation: + prompts = [[example["instance_prompt"], example["reg_prompt"]] for example in examples] + images = [example["instance_images"] for example in examples] + [example["reg_images"] for example in examples] + else: + prompts = [[example["instance_prompt"]] for example in examples] + images = [example["instance_images"] for example in examples] + + images = torch.stack(images) + images = images.to(memory_format=torch.contiguous_format).float() + + return prompts, images + + +class DreamBooth(torch.nn.Module, Serialization): + def __init__(self, cfg, model_parallel_config): + super().__init__() + self.cfg = cfg + self.config = model_parallel_config + self.with_prior_preservation = self.cfg.with_prior_preservation + self.num_reg_images = self.cfg.data.num_reg_images + self.prior_loss_weight = self.cfg.prior_loss_weight + self.num_images_per_prompt = self.cfg.data.num_images_per_prompt + + self.train_text_encoder = self.cfg.train_text_encoder + self.instantiate_text_encoder(self.cfg.cond_stage_config) + + self.inductor = self.cfg.inductor + self.inductor_cudagraphs = self.cfg.inductor_cudagraphs + + self.instantiate_vae(self.cfg.first_stage_config) + self.instantiate_unet(self.cfg.unet_config) + + self.scale_factor = self.cfg.scale_factor + self.num_timesteps = self.cfg.noise_scheduler.timesteps + self.parameterization = self.cfg.noise_scheduler.parameterization + self.get_noise_scheduler(self.cfg.noise_scheduler) + + self.model_type = None + self.rng = torch.Generator(device=torch.cuda.current_device(),) + + self.use_cached_latents = self.cfg.use_cached_latents + + if self.cfg.channels_last: + self.unet = self.unet.to(memory_format=torch.channels_last) + + def instantiate_unet(self, cfg): + self.unet = DreamBooth.from_config_dict(cfg) + self.unet.train() + if self.inductor: + # TorchInductor with CUDA graph can lead to OOM + inductor_config.triton.cudagraphs = self.inductor_cudagraphs + torch._dynamo.config.dynamic_shapes = False + torch._dynamo.config.automatic_dynamic_shapes = False + self.unet = torch.compile(self.unet) + + def instantiate_vae(self, cfg): + model = DreamBooth.from_config_dict(cfg) + self.vae = model.eval() + self.vae.train = disabled_train + for param in self.vae.parameters(): + param.requires_grad = False + + def instantiate_text_encoder(self, cfg): + model = DreamBooth.from_config_dict(cfg) + if self.train_text_encoder: + self.text_encoder = model.train() + for param in self.text_encoder.parameters(): + param.requires_grad = True + else: + self.text_encoder = model.eval() + self.text_encoder.train = disabled_train + for param in self.text_encoder.parameters(): + param.requires_grad = False + + def get_noise_scheduler(self, cfg): + model = DreamBooth.from_config_dict(cfg) + self.noise_scheduler = model.eval() + + def forward(self, batch): + + x, cond = batch + if self.use_cached_latents: + x = DiagonalGaussianDistribution(x) + latents = x.sample().detach() * self.scale_factor + else: + latents = self.vae.encode(x).sample().detach() + latents = latents * self.scale_factor + + noise = randn_like(latents, generator=self.rng) + t = torch.randint(0, self.num_timesteps, (latents.shape[0],), generator=self.rng, device=latents.device).long() + x_noisy = self.noise_scheduler(x_start=latents, t=t, noise=noise) + + # cond = self.text_encoder([t[0] for t in batch["prompts"]]) + # if self.with_prior_preservation: + # cond_prior = self.text_encoder([t[1] for t in batch["prompts"]]) + # cond = torch.cat([cond, cond_prior], dim=0) + + model_output = self.unet(x_noisy, t, cond) + + if self.parameterization == "x0": + target = latents + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + if self.with_prior_preservation: + model_pred, model_pred_prior = torch.chunk(model_output, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean") + prior_loss = torch.nn.functional.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + loss = loss + prior_loss * self.prior_loss_weight + + else: + loss = torch.nn.functional.mse_loss(target.float(), model_output.float(), reduction="mean") + return loss + + def parameters(self): + params = list(self.unet.parameters()) + if self.train_text_encoder: + # print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.text_encoder.parameters()) + return params + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + pass + + +class MegatronDreamBooth(MegatronBaseModel): + def __init__(self, cfg: DictConfig, trainer: Trainer): + if not HAVE_APEX: + raise ImportError( + "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + ) + if not HAVE_MEGATRON_CORE: + raise ImportError( + "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + ) + + # this prevents base constructor from initializing tokenizer + self.tokenizer = None + super().__init__(cfg, trainer=trainer) + + self._validate_trainer() + + # megatron_amp_O2 is not yet supported in diffusion models + self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False) + self.model = self.model_provider_func() + + if self.trainer.precision in ['bf16', 'bf16-mixed']: + self.autocast_dtype = torch.bfloat16 + elif self.trainer.precision in [32, '32', '32-true']: + self.autocast_dtype = torch.float + elif self.trainer.precision in [16, '16', '16-mixed']: + self.autocast_dtype = torch.half + else: + raise ValueError('precision must be in ["32-true", "16-mixed", "bf16-mixed"]') + + def get_module_list(self): + if isinstance(self.model, list): + return [model.module if isinstance(model, Float16Module) else model for model in self.model] + elif isinstance(self.model, Float16Module): + return [self.model.module] + else: + return [self.model] + + def model_provider_func(self, pre_process=True, post_process=True): + """Model depends on pipeline paralellism.""" + model = DreamBooth(cfg=self.cfg, model_parallel_config=self.model_parallel_config) + return model + + def forward(self, batch): + output_tensor = self.model(batch) + return output_tensor + + def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): + tensor_shape = None # Placeholder + + # handle asynchronous grad reduction + no_sync_func = None + if not forward_only and self.with_distributed_adam: + no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + + # pipeline schedules will get these from self.model.config + for module in self.get_module_list(): + module.config.no_sync_func = no_sync_func + + # run forward and backwards passes for an entire global batch + # we do this inside training_step to support pipeline parallelism + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(), + data_iterator=dataloader_iter, + model=self.model, + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=None, + micro_batch_size=self.cfg.micro_batch_size, + ) + + # losses_reduced_per_micro_batch is a list of dictionaries + # [{"loss": 0.1}, {"loss": 0.2}, ...] which are from gradient accumulation steps + # only the last stages of the pipeline return losses + loss_dict = {} + if losses_reduced_per_micro_batch: + if (not forward_only) or self.cfg.data.get('validation_drop_last', True): + # average loss across micro batches + prefix = 'train' + for key in losses_reduced_per_micro_batch[0]: + loss_tensors_list = [loss_reduced[key] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensor = torch.stack(loss_tensors_list) + loss_dict[f'{prefix}/{key}'] = loss_tensor.mean() + loss_mean = loss_dict["train/loss"] + else: + raise NotImplementedError("Losses of micro batches sizes must be uniform!") + else: + if forward_only: + loss_mean = [] + else: + loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + + return loss_mean, loss_dict + + def training_step(self, dataloader_iter, batch_idx): + """ + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + Batch should be a list of microbatches and those microbatches should on CPU. + Microbatches are then moved to GPU during the pipeline. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + """ + + # we zero grads here because we also call backward in the apex fwd/bwd functions + self._optimizer.zero_grad() + + loss_mean, loss_dict = self.fwd_bwd_step(dataloader_iter, batch_idx, False) + + torch.distributed.broadcast(loss_mean, get_last_rank()) + + # when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced + if self.cfg.get('tensor_model_parallel_size', 1) > 1 and self.cfg.get('sequence_parallel', False): + self.allreduce_sequence_parallel_gradients() + + if self.with_distributed_adam: + # gradients are reduced internally in distributed optimizer + pass + elif self.megatron_amp_O2: + # # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously) + # if self.cfg.get('pipeline_model_parallel_size', 1) > 1 or self.cfg.get('sequence_parallel', False): + # # main grads are stored in the MainParamsOptimizer wrapper + # self._optimizer.allreduce_main_grads() + self._optimizer.allreduce_main_grads() + elif not self.cfg.get('ddp_overlap', True): + # async grad allreduce is not currently implemented for O1/autocasting mixed precision training + # so we all-reduce gradients after the pipeline + self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf) + + if self.cfg.precision in [16, '16', '16-mixed']: + loss_scale = self.trainer.precision_plugin.scaler._scale + if loss_scale is not None: + self.log('loss_scale', loss_scale, prog_bar=True, batch_size=1) + + self.log_dict(loss_dict, prog_bar=False, logger=True, on_step=True, rank_zero_only=True, batch_size=1) + self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1) + lr = self._optimizer.param_groups[0]['lr'] + self.log('lr', lr, prog_bar=True, rank_zero_only=True, batch_size=1) + self.log('global_step', self.trainer.global_step + 1, prog_bar=True, rank_zero_only=True, batch_size=1) + self.log( + 'consumed_samples', + self.compute_consumed_samples(self.trainer.global_step + 1 - self.init_global_step), + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + return loss_mean + + def validation_step(self, dataloader_iter, batch_idx): + loss, val_loss_dict = self.fwd_bwd_step(dataloader_iter, batch_idx, True) + + self.log_dict(val_loss_dict, prog_bar=False, logger=True, on_step=False, on_epoch=True, batch_size=1) + + return loss + + def backward(self, *args, **kwargs): + """ LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from apex. + No need to call it here. + """ + pass + + def optimizer_zero_grad(self, *args, **kwargs): + """ LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. + """ + pass + + def _append_sequence_parallel_module_grads(self, module, grads): + """ Helper method for allreduce_sequence_parallel_gradients""" + + for param in module.parameters(): + sequence_parallel_param = getattr(param, 'sequence_parallel', False) + if sequence_parallel_param and param.requires_grad: + if self.megatron_amp_O2: + grad = param.main_grad + else: + grad = param.grad + grads.append(grad.data) + + def get_forward_output_and_loss_func(self): + def process_batch(batch): + # noise_map, condition + prompts, images = batch + # DB has more dedicated structure for encoding, so we enable autocasting here as well + with torch.cuda.amp.autocast( + self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype, + ): + images = images.cuda(non_blocking=True) + + cond = self.model.text_encoder([t[0] for t in prompts]) + if self.cfg.with_prior_preservation: + cond_prior = self.model.text_encoder([t[1] for t in prompts]) + cond = torch.cat([cond, cond_prior], dim=0) + + return images, cond + + def fwd_output_and_loss_func(dataloader_iter, model): + batch = next(dataloader_iter) + batch = process_batch(batch) + batch = [x.cuda(non_blocking=True) for x in batch] + loss = model(batch) + + def dummy(output_tensor): + return loss, {'loss': loss} + + return loss, dummy + + return fwd_output_and_loss_func + + def get_forward_output_only_func(self): + def fwd_output_only_func(batch, model): + raise NotImplementedError + + return fwd_output_only_func + + def setup(self, stage=None): + """ PTL hook that is executed after DDP spawns. + We setup datasets here as megatron datasets require DDP to instantiate. + See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. + Args: + stage (str, optional): Can be 'fit', 'validate', 'test' or 'predict'. Defaults to None. + """ + self.model.rng.manual_seed(self.cfg.seed + 100 * parallel_state.get_data_parallel_rank()) + + # log number of parameters + if isinstance(self.model, list): + num_parameters_on_device = sum( + [sum([p.nelement() for p in model_module.parameters()]) for model_module in self.model] + ) + else: + num_parameters_on_device = sum([p.nelement() for p in self.model.parameters()]) + + # to be summed across data parallel group + total_num_parameters = torch.tensor(num_parameters_on_device).cuda(non_blocking=True) + + torch.distributed.all_reduce(total_num_parameters, group=parallel_state.get_model_parallel_group()) + + logging.info( + f'Pipeline model parallel rank: {parallel_state.get_pipeline_model_parallel_rank()}, ' + f'Tensor model parallel rank: {parallel_state.get_tensor_model_parallel_rank()}, ' + f'Number of model parameters on device: {num_parameters_on_device:.2e}. ' + f'Total number of model parameters: {total_num_parameters:.2e}.' + ) + + resume_checkpoint_path = self.trainer.ckpt_path + if resume_checkpoint_path: + init_consumed_samples = self._extract_consumed_samples_from_ckpt(resume_checkpoint_path) + else: + init_consumed_samples = 0 + self.init_consumed_samples = init_consumed_samples + self.init_global_step = self.trainer.global_step + + # Batch size need to be provided for webdatset + self._num_micro_batches = get_num_microbatches() + self._micro_batch_size = self.cfg.micro_batch_size + + self.setup_training_data(self.cfg.data) + + def setup_training_data(self, cfg): + if self.cfg.with_prior_preservation: + if cfg.regularization_dir is None: + raise ValueError("Regularization images must be provided to train with prior preservation loss") + if cfg.regularization_prompt is None: + raise ValueError("Regularization prompts must be provided to train with prior preservation loss") + + self.train_dataset = DreamBoothDataset( + instance_data_root=cfg.instance_dir, + instance_prompt=cfg.instance_prompt, + with_prior_preservation=self.cfg.with_prior_preservation, + reg_data_root=cfg.regularization_dir if self.cfg.with_prior_preservation else None, + reg_prompt=cfg.regularization_prompt if self.cfg.with_prior_preservation else None, + size=cfg.resolution, + center_crop=cfg.center_crop, + load_cache_latents=self.model.use_cached_latents, + cached_instance_data_root=self.cfg.data.get("cached_instance_dir", None), + cached_reg_data_root=self.cfg.data.get("cached_reg_dir", None) + if self.cfg.with_prior_preservation + else None, + vae=self.model.vae, + text_encoder=self.model.text_encoder, + ) + + batch_sampler = MegatronPretrainingRandomSampler( + total_samples=len(self.train_dataset), + consumed_samples=self.compute_consumed_samples(0), + micro_batch_size=self.cfg.micro_batch_size, + global_batch_size=self.cfg.global_batch_size, + data_parallel_rank=parallel_state.get_data_parallel_rank(), + data_parallel_size=parallel_state.get_data_parallel_world_size(), + drop_last=False, + ) + + self._train_dl = torch.utils.data.DataLoader( + self.train_dataset, + batch_sampler=batch_sampler, + collate_fn=partial(_collate_fn, with_prior_preservation=self.cfg.with_prior_preservation), + num_workers=cfg.num_workers, + pin_memory=True, + persistent_workers=True, + ) + + def setup_validation_data(self, cfg): + pass + + def setup_test_data(self, cfg): + pass + + def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: + """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + When using pipeline parallelism, we need the global batch to remain on the CPU, + since the memory overhead will be too high when using a large number of microbatches. + Microbatches are transferred from CPU to GPU inside the pipeline. + """ + return batch + + def _validate_trainer(self): + """ Certain trainer configurations can break training. + Here we try to catch them and raise an error. + """ + if self.trainer.accumulate_grad_batches > 1: + raise ValueError( + f'Gradient accumulation is done within training_step. trainer.accumulate_grad_batches must equal 1' + ) + + @classmethod + def list_available_models(cls): + return None + + def parameters(self): + if isinstance(self.model, list): + return itertools.chain.from_iterable(module.parameters() for module in self.model) + else: + return self.model.parameters() + + @classmethod + def load_from_checkpoint( + cls, + checkpoint_path: str, + map_location: Any = None, + hparams_file: Optional[str] = None, + strict: bool = True, + **kwargs, + ): + """ + Loads ModelPT from checkpoint, with some maintenance of restoration. + For documentation, please refer to LightningModule.load_from_checkpoin() documentation. + """ + checkpoint = None + try: + cls._set_model_restore_state(is_being_restored=True) + # TODO: replace with proper PTL API + with pl_legacy_patch(): + if map_location is not None: + checkpoint = pl_load(checkpoint_path, map_location=map_location) + else: + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + + if hparams_file is not None: + extension = hparams_file.split(".")[-1] + if extension.lower() == "csv": + hparams = load_hparams_from_tags_csv(hparams_file) + elif extension.lower() in ("yml", "yaml"): + hparams = load_hparams_from_yaml(hparams_file) + else: + raise ValueError(".csv, .yml or .yaml is required for `hparams_file`") + + hparams["on_gpu"] = False + + # overwrite hparams by the given file + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams + + # for past checkpoint need to add the new key + if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint: + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {} + # override the hparams with values that were passed in + cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].get('cfg', checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]) + # TODO: can we do this without overriding? + config_kwargs = kwargs.copy() + if 'trainer' in config_kwargs: + config_kwargs.pop('trainer') + cfg.update(config_kwargs) + + # Disable individual unet/vae weights loading otherwise the model will look for these partial ckpts and raise error + if cfg: + if cfg.get('unet_config') and cfg.get('unet_config').get('from_pretrained'): + cfg.unet_config.from_pretrained = None + if cfg.get('first_stage_config') and cfg.get('first_stage_config').get('from_pretrained'): + cfg.first_stage_config.from_pretrained = None + ## Now when we covert ckpt to nemo, let's always get rid of those _orig_mod + if cfg.get('inductor'): + cfg.inductor = False + ## Append some dummy configs that DB didn't support + if not cfg.get('channels_last'): + cfg.channels_last = True + if not cfg.get('capture_cudagraph_iters'): + cfg.capture_cudagraph_iters = -1 + + # compatibility for stable diffusion old checkpoint tweaks + first_key = list(checkpoint['state_dict'].keys())[0] + if first_key == "betas": + # insert "model." into for megatron wrapper + new_state_dict = {} + for key in checkpoint['state_dict'].keys(): + new_key = "model." + key + new_state_dict[new_key] = checkpoint['state_dict'][key] + checkpoint['state_dict'] = new_state_dict + elif ( + first_key == 'model.text_encoder.transformer.text_model.embeddings.position_ids' + or first_key == 'model.text_encoder.model.language_model.embedding.position_embeddings' + ): + # remap state keys from dreambooth when using HF clip + new_state_dict = {} + for key in checkpoint['state_dict'].keys(): + new_key = key.replace('._orig_mod', "") + new_key = new_key.replace('unet', 'model.diffusion_model') + new_key = new_key.replace('vae', 'first_stage_model') + new_key = new_key.replace('text_encoder', 'cond_stage_model') + new_key = new_key.replace('.noise_scheduler', '') + new_state_dict[new_key] = checkpoint['state_dict'][key] + checkpoint['state_dict'] = new_state_dict + + # compatibility for inductor in inference + if not cfg.get('inductor', False): + new_state_dict = {} + for key in checkpoint['state_dict'].keys(): + new_key = key.replace('._orig_mod', '', 1) + new_state_dict[new_key] = checkpoint['state_dict'][key] + checkpoint['state_dict'] = new_state_dict + + if cfg.get('megatron_amp_O2', False): + new_state_dict = {} + for key in checkpoint['state_dict'].keys(): + new_key = key.replace('model.', 'model.module.', 1) + new_state_dict[new_key] = checkpoint['state_dict'][key] + checkpoint['state_dict'] = new_state_dict + + if 'cfg' in kwargs: + model = ptl_load_state(cls, checkpoint, strict=strict, **kwargs) + else: + model = ptl_load_state(cls, checkpoint, strict=strict, cfg=cfg, **kwargs) + # cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg + + checkpoint = model + + finally: + cls._set_model_restore_state(is_being_restored=False) + return checkpoint diff --git a/nemo/collections/multimodal/models/text_to_image/dreambooth/util.py b/nemo/collections/multimodal/models/text_to_image/dreambooth/util.py new file mode 100644 index 000000000000..8e31120d47b3 --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/dreambooth/util.py @@ -0,0 +1,167 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +import numpy as np +import torch +import torch.nn as nn + +from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import ( + extract_into_tensor, + make_beta_schedule, +) +from nemo.collections.multimodal.parts.stable_diffusion.utils import default, exists +from nemo.core.classes.common import Serialization + + +class DiffusionWrapper(torch.nn.Module, Serialization): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + if isinstance(diff_model_config, nn.Module): + self.diffusion_model = diff_model_config + else: + self.diffusion_model = DiffusionWrapper.from_config_dict(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + + def forward(self, x_noisy, t, cond, return_ids=False): + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + x_recon = self.apply_step(x_noisy, t, **cond) + return x_recon + + def apply_step(self, x, t, c_concat: list = None, c_crossattn: list = None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class sd_noise_scheduler(nn.Module): + def __init__( + self, + parameterization='eps', + v_posterior=0, + given_betas=None, + beta_schedule='linear', + timesteps=1000, + linear_start=0.00085, + linear_end=0.012, + cosine_s=8e-3, + ): + super().__init__() + self.parameterization = parameterization + self.v_posterior = v_posterior + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule( + beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1.0 - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1.0 - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1.0 / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1.0 / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1.0 - alphas_cumprod_prev) / ( + 1.0 - alphas_cumprod + ) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer( + 'posterior_mean_coef1', to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)) + ) + self.register_buffer( + 'posterior_mean_coef2', to_torch((1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)) + ) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod) + ) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + def forward(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) diff --git a/nemo/collections/multimodal/models/text_to_image/imagen/__init__.py b/nemo/collections/multimodal/models/text_to_image/imagen/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/imagen/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py b/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py new file mode 100644 index 000000000000..90487eac61dc --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py @@ -0,0 +1,598 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools +from datetime import datetime +from functools import partial +from typing import Any + +import torch +from omegaconf import DictConfig, open_dict +from pytorch_lightning import Trainer + +from nemo.collections.multimodal.data.imagen.imagen_dataset import build_train_valid_datasets +from nemo.collections.multimodal.models.text_to_image.imagen.precond import ContinousDDPMPrecond, EDMPrecond +from nemo.collections.multimodal.modules.imagen.diffusionmodules.nets import EfficientUNetModel, UNetModel +from nemo.collections.multimodal.modules.imagen.encoder.t5encoder import T5Encoder +from nemo.collections.multimodal.modules.imagen.sampler.sampler import DDPMSampler, EDMSampler +from nemo.collections.multimodal.parts.imagen.utils import random_dropout +from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel +from nemo.collections.nlp.modules.common.megatron.module import Float16Module +from nemo.collections.nlp.parts.utils_funcs import get_last_rank +from nemo.core.classes.common import Serialization +from nemo.utils import logging + +try: + from apex import amp + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + + HAVE_APEX = True +except (ImportError, ModuleNotFoundError): + HAVE_APEX = False + +try: + from megatron.core import parallel_state + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + HAVE_MEGATRON_CORE = False + +try: + from apex.contrib.group_norm import GroupNorm + + OPT_GROUP_NORM = True +except Exception: + print('Fused optimized group norm has not been installed.') + OPT_GROUP_NORM = False + +DUMMY_TENSOR = torch.tensor([1.0]) + + +class Imagen(torch.nn.Module, Serialization): + def __init__(self, cfg, model_parallel_config): + super().__init__() + self.cfg = cfg + self.config = model_parallel_config + # Make sure the initialization on different GPUs are the same + self.unet_type = cfg.get('unet_type', 'base') + self.noise_cond_aug = cfg.get('noise_cond_aug', False) + if self.unet_type == 'base': + logging.info('Initializing UNet.') + unet = UNetModel(**cfg.unet, text_embed_dim=cfg.conditioning.embed_dim) + elif self.unet_type == 'sr': + logging.info('Initializing Efficient-UNet.') + unet = EfficientUNetModel( + **cfg.unet, text_embed_dim=cfg.conditioning.embed_dim, noise_cond_aug=self.noise_cond_aug + ) + elif self.unet_type == 'sr-unet': + logging.info('Initializing UNet for SR model.') + unet = UNetModel(**cfg.unet, text_embed_dim=cfg.conditioning.embed_dim, noise_cond_aug=self.noise_cond_aug) + else: + raise NotImplemented(f'{self.unet_type} UNet is not implemented.') + + self.channels_last = cfg.get('channels_last', False) + if self.channels_last: + assert OPT_GROUP_NORM, 'Training in channels last format requires optmized group norm implementation.' + logging.info('Training in torch channels last format.') + unet = unet.to(memory_format=torch.channels_last) + + # Preconditioning + self.preconditioning_type = cfg.get('preconditioning_type', 'DDPM') + if self.preconditioning_type == 'DDPM': + logging.info('Preconditioned with Continous DDPM') + self.model = ContinousDDPMPrecond(unet=unet, **cfg.preconditioning, noise_cond_aug=self.noise_cond_aug) + self.sampler = DDPMSampler(unet_type=self.unet_type, denoiser=self.model.scheduler) + elif self.preconditioning_type == 'EDM': + logging.info('Preconditioned with EDM') + self.model = EDMPrecond(unet=unet, **cfg.preconditioning, noise_cond_aug=self.noise_cond_aug) + self.sampler = EDMSampler(unet_type=self.unet_type) + else: + raise NotImplemented(f'{self.preconditioning_type} preconditioning is not implemented.') + + self.rng = None + self.conditioning = cfg.conditioning + self.text_drop_rate = cfg.conditioning.drop_rate + self.model_type = None + self.image_size = cfg.unet.image_size + + def setup_rng(self): + # We need to set different rng seed for different GPUs/ different runs; + # otherwise, the noise map and time will be exactly the same. + self.rng = torch.Generator(device=torch.cuda.current_device()) + self.rng_seed = int(datetime.now().timestamp()) + self.cfg.seed + parallel_state.get_data_parallel_rank() + logging.info(f'RNG seed set as {self.rng_seed} for rank {parallel_state.get_data_parallel_rank()}') + self.rng.manual_seed(self.rng_seed) + self.model.set_rng(self.rng) + + @property + def unet(self): + return self.model.unet + + def get_text_encoder(self, encoder_path=None): + # TODO Assume using T5 for all + return T5Encoder(max_seq_len=self.conditioning.token_length, encoder_path=encoder_path) + + def forward(self, x_start, text_embed, text_mask, x_lowres=None): + if self.unet_type == 'base': + assert x_lowres[0].item() == DUMMY_TENSOR.item(), 'Base model should have no low-resolution conditioning' + x_lowres = None + else: + assert x_lowres[0].dim() not in [0, 1], 'SR model should have low-resolution conditioning' + + if self.channels_last: + x_start = x_start.to(memory_format=torch.channels_last) + if x_lowres is not None: + x_lowres = x_lowres.to(memory_format=torch.channels_last) + + # Apply random dropout to text embedding + text_embed = random_dropout(text_embed, drop_rate=self.text_drop_rate) + # UNet Forward Pass + low_res_cond = {'x_low_res': x_lowres} if x_lowres is not None else {} + # UNet Forward Pass and compute loss + loss = self.model.compute_loss( + x0=x_start, + text_embed=text_embed, + text_mask=text_mask, + time=None, # Randomly Sample + noise=None, # Randomly Sample + **low_res_cond, + ) + return loss, {'train/loss': loss} + + @torch.no_grad() + def sample_image( + self, + noise_map, + text_encoding, + text_mask, + x_low_res=None, + cond_scale=1.0, + sampling_steps=None, + thresholding_method='dynamic', + ): + return self.sampler( + self.model, noise_map, text_encoding, text_mask, x_low_res, cond_scale, sampling_steps, thresholding_method + ) + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + # only required for pipeline parallelism + pass + + +class MegatronImagen(MegatronBaseModel): + def __init__(self, cfg: DictConfig, trainer: Trainer): + if not HAVE_APEX: + raise ImportError( + "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + ) + + with open_dict(cfg): + cfg.hidden_size = cfg.unet.embed_dim + # this prevents base constructor from initializing tokenizer + self.tokenizer = None + super().__init__(cfg, trainer=trainer) + + self._validate_trainer() + # megatron_amp_O2 is not yet supported in diffusion models + self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False) + + self.model = self.model_provider_func() + + if self.trainer.precision in ['bf16', 'bf16-mixed']: + self.autocast_dtype = torch.bfloat16 + elif self.trainer.precision in [32, '32', '32-true']: + self.autocast_dtype = torch.float + elif self.trainer.precision in [16, '16', '16-mixed']: + self.autocast_dtype = torch.half + else: + raise ValueError('precision must be in ["32-true", "16-mixed", "bf16-mixed"]') + + self.online_encoding = cfg.conditioning.get("online_encoding", False) + self.text_encoder_path = cfg.conditioning.get("encoder_path", None) + + def get_module_list(self): + if isinstance(self.model, list): + return [model.module if isinstance(model, Float16Module) else model for model in self.model] + elif isinstance(self.model, Float16Module): + return [self.model.module] + else: + return [self.model] + + def model_provider_func(self, pre_process=True, post_process=True): + """Model depends on pipeline paralellism.""" + model = Imagen(cfg=self.cfg, model_parallel_config=self.model_parallel_config) + return model + + def get_forward_output_and_loss_func(self): + def process_batch(batch): + """ Prepares the batch for megatron fwd/bwd functions. + Global batch is a list of micro batches. + """ + # Base model and SR models have slightly different batch input: + # Base model would only require images (64x64), + # while SR models (both SR256 and SR1024) require low-res image (64x64) and + # actual (cropped) image (256x256) + if self.cfg.unet_type == 'base': + x_start = batch['images'] + # Pass in DUMMY_TENSOR because megatron requires each input to be + # tensor (not None) with same batch size (first dim) + x_lowres = DUMMY_TENSOR.repeat(x_start.shape[0]) + elif self.cfg.unet_type == 'sr' or self.cfg.unet_type == 'sr-unet': + x_start = batch['images_256'] + x_lowres = batch['images_64'] + else: + raise NotImplemented(f'Unknown UNet type: {self.cfg.unet_type}') + + if self.cfg.conditioning.get("online_encoding", False): + input_text = batch["raw_text"] + # Encode the text embeddings using text encoder. + with torch.no_grad(): + text_embed, text_mask = self.text_encoder.encode(input_text) + else: + text_conditioning_key = self.cfg.conditioning.out_key + text_embed = batch[f'{text_conditioning_key}_embeddings'] + text_mask = batch[f'{text_conditioning_key}_mask'] + return [x_start, text_embed, text_mask, x_lowres] + + def fwd_output_and_loss_func(dataloader_iter, model): + batch = next(dataloader_iter) + batch = process_batch(batch) + batch = [x.cuda(non_blocking=True) for x in batch] + loss, loss_dict = model(*batch) + + def dummy(output_tensor): + return loss, loss_dict + + # output_tensor, and a function to convert output_tensor to loss + loss_dict + return loss, dummy + + return fwd_output_and_loss_func + + def get_forward_output_only_func(self): + def fwd_output_only_func(batch, model): + raise NotImplementedError + + return fwd_output_only_func + + def build_train_valid_test_datasets(self): + logging.info('Building datasets for Imagen...') + if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float): + raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.") + self._train_ds, self._validation_ds = build_train_valid_datasets( + model_cfg=self.cfg, consumed_samples=self.compute_consumed_samples(0) + ) + # We do not have test dataset + self._test_ds = None + + if self._train_ds is not None: + logging.info(f'Length of train dataset: {len(self._train_ds)}') + if self._validation_ds is not None: + logging.info(f'Length of val dataset: {len(self._validation_ds)}') + if self._test_ds is not None: + logging.info(f'Length of test dataset: {len(self._test_ds)}') + logging.info(f'Finished building datasets for LatentDiffusion.') + return self._train_ds, self._validation_ds, self._test_ds + + def setup_training_data(self, cfg): + if hasattr(self, '_train_ds') and self._train_ds is not None: + consumed_samples = self.compute_consumed_samples(0) + logging.info( + f'Setting up train dataloader with len(len(self._train_ds)): {len(self._train_ds)} and consumed samples: {consumed_samples}' + ) + self._train_dl = torch.utils.data.DataLoader( + self._train_ds, + batch_size=self._micro_batch_size, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=True, + persistent_workers=True, + ) + + def setup_validation_data(self, cfg): + if hasattr(self, '_validation_ds') and self._validation_ds is not None: + consumed_samples = 0 + logging.info( + f'Setting up validation dataloader with len(len(self._validation_ds)): {len(self._validation_ds)} and consumed samples: {consumed_samples}' + ) + self._validation_dl = torch.utils.data.DataLoader( + self._validation_ds, + batch_size=self._micro_batch_size, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=True, + ) + + def setup_test_data(self, cfg): + if hasattr(self, '_test_ds') and self._test_ds is not None: + consumed_samples = 0 + logging.info( + f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}' + ) + self._test_dl = torch.utils.data.DataLoader( + self._test_ds, batch_size=self._micro_batch_size, num_workers=cfg.num_workers, pin_memory=True, + ) + + def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): + tensor_shape = None + + # handle asynchronous grad reduction + no_sync_func = None + if not forward_only and self.with_distributed_adam: + no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + + # pipeline schedules will get these from self.model.config + for module in self.get_module_list(): + module.config.no_sync_func = no_sync_func + + # run forward and backwards passes for an entire global batch + # we do this inside training_step to support pipeline parallelism + fwd_bwd_function = get_forward_backward_func() + + # TODO @akhattar: add num_micro_batches_with_partial_activation_checkpoints when ready + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(), + data_iterator=dataloader_iter, + model=self.model, + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=None, + micro_batch_size=self.cfg.micro_batch_size, + ) + + # losses_reduced_per_micro_batch is a list of dictionaries + # [{"loss": 0.1}, {"loss": 0.2}, ...] which are from gradient accumulation steps + # only the last stages of the pipeline return losses + loss_dict = {} + if losses_reduced_per_micro_batch: + if (not forward_only) or self.cfg.data.get('validation_drop_last', True): + # average loss across micro batches + for key in losses_reduced_per_micro_batch[0]: + loss_tensors_list = [loss_reduced[key] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensor = torch.stack(loss_tensors_list) + loss_dict[key] = loss_tensor.mean() + loss_mean = loss_dict["train/loss"] + else: + # Get the total loss since micro batches sizes are not uniform + raise NotImplementedError("Losses of micro batches sizes must be uniform!") + else: + # we're not on the last pipeline stage so no losses + if forward_only: + loss_mean = [] + else: + loss_mean = torch.tensor(0.0).cuda() + + return loss_mean, loss_dict + + def training_step(self, dataloader_iter, batch_idx): + """ + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + Batch should be a list of microbatches and those microbatches should on CPU. + Microbatches are then moved to GPU during the pipeline. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + """ + + # we zero grads here because we also call backward in the megatron-core fwd/bwd functions + self._optimizer.zero_grad() + + loss_mean, loss_dict = self.fwd_bwd_step(dataloader_iter, batch_idx, False) + + torch.distributed.broadcast(loss_mean, get_last_rank()) + + # when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced + if self.cfg.get('tensor_model_parallel_size', 1) > 1 and self.cfg.get('sequence_parallel', False): + self.allreduce_sequence_parallel_gradients() + + if self.with_distributed_adam: + # synchronize asynchronous grad reductions + # note: not necessary, but reduces performance degradation + # from multiple simultaneous NCCL calls + self._optimizer._finish_bucket_grad_sync() + elif self.megatron_amp_O2: + # # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously) + # if self.cfg.get('pipeline_model_parallel_size', 1) > 1 or self.cfg.get('sequence_parallel', False): + # # main grads are stored in the MainParamsOptimizer wrapper + # self._optimizer.allreduce_main_grads() + self._optimizer.allreduce_main_grads() + elif not self.cfg.get('ddp_overlap', True): + # async grad allreduce is not currently implemented for O1/autocasting mixed precision training + # so we all-reduce gradients after the pipeline + self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf) + + if self.cfg.precision in [16, '16', '16-mixed']: + loss_scale = self.trainer.precision_plugin.scaler._scale + if loss_scale is not None: + self.log('loss_scale', loss_scale, batch_size=1) + + self.log_dict(loss_dict, prog_bar=False, logger=True, on_step=True, rank_zero_only=True, batch_size=1) + self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1) + lr = self._optimizer.param_groups[0]['lr'] + self.log('lr', lr, prog_bar=True, rank_zero_only=True, batch_size=1) + self.log('global_step', self.trainer.global_step + 1, prog_bar=True, rank_zero_only=True, batch_size=1) + self.log( + 'consumed_samples', + self.compute_consumed_samples(self.trainer.global_step + 1 - self.init_global_step), + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + return loss_mean + + def backward(self, *args, **kwargs): + """ LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from apex. + No need to call it here. + """ + pass + + def optimizer_zero_grad(self, *args, **kwargs): + """ LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. + """ + pass + + def _append_sequence_parallel_module_grads(self, module, grads): + """ Helper method for allreduce_sequence_parallel_gradients""" + + for param in module.parameters(): + sequence_parallel_param = getattr(param, 'sequence_parallel', False) + if sequence_parallel_param and param.requires_grad: + if self.megatron_amp_O2: + grad = param.main_grad + else: + grad = param.grad + grads.append(grad.data) + + def validation_step(self, dataloader_iter, batch_idx): + """ + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. """ + + loss, val_loss_dict = self.fwd_bwd_step(dataloader_iter, batch_idx, True) + + self.log_dict(val_loss_dict, prog_bar=False, logger=True, on_step=False, on_epoch=True, batch_size=1) + return loss + + def setup(self, stage=None): + """ PTL hook that is executed after DDP spawns. + We setup datasets here as megatron datasets require DDP to instantiate. + See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. + Args: + stage (str, optional): Can be 'fit', 'validate', 'test' or 'predict'. Defaults to None. + """ + + # log number of parameters + if isinstance(self.model, list): + num_parameters_on_device = sum( + [sum([p.nelement() for p in model_module.parameters()]) for model_module in self.model] + ) + else: + num_parameters_on_device = sum([p.nelement() for p in self.model.parameters()]) + + # to be summed across data parallel group + total_num_parameters = torch.tensor(num_parameters_on_device).cuda(non_blocking=True) + + torch.distributed.all_reduce(total_num_parameters, group=parallel_state.get_model_parallel_group()) + + logging.info( + f'Pipeline model parallel rank: {parallel_state.get_pipeline_model_parallel_rank()}, ' + f'Tensor model parallel rank: {parallel_state.get_tensor_model_parallel_rank()}, ' + f'Number of model parameters on device: {num_parameters_on_device:.2e}. ' + f'Total number of model parameters: {total_num_parameters:.2e}.' + ) + + resume_checkpoint_path = self.trainer.ckpt_path + if resume_checkpoint_path: + init_consumed_samples = self._extract_consumed_samples_from_ckpt(resume_checkpoint_path) + else: + init_consumed_samples = 0 + self.init_consumed_samples = init_consumed_samples + self.init_global_step = self.trainer.global_step + + # allowing restored models to optionally setup datasets + self.build_train_valid_test_datasets() + + # Batch size need to be provided for webdatset + self._num_micro_batches = get_num_microbatches() + self._micro_batch_size = self.cfg.micro_batch_size + + self.setup_training_data(self.cfg.data) + self.setup_validation_data(self.cfg.data) + self.setup_test_data(self.cfg.data) + # Setup RNG seed in model + self.model.setup_rng() + + def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: + """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + When using pipeline parallelism, we need the global batch to remain on the CPU, + since the memory overhead will be too high when using a large number of microbatches. + Microbatches are transferred from CPU to GPU inside the pipeline. + """ + return batch + + def _validate_trainer(self): + """ Certain trainer configurations can break training. + Here we try to catch them and raise an error. + """ + if self.trainer.accumulate_grad_batches > 1: + raise ValueError( + f'Gradient accumulation is done within training_step. trainer.accumulate_grad_batches must equal 1' + ) + + @classmethod + def list_available_models(cls): + return None + + def parameters(self): + if isinstance(self.model, list): + return itertools.chain.from_iterable(module.parameters() for module in self.model) + else: + return self.model.parameters() + + def on_save_checkpoint(self, checkpoint) -> None: + if self.online_encoding: + # Removing the weights relating to Text encoder when saving the checkpoints + frozen_weights_keys = [k for k in checkpoint['state_dict'].keys() if k.startswith("text_encoder")] + for k in frozen_weights_keys: + del checkpoint['state_dict'][k] + + def on_load_checkpoint(self, checkpoint) -> None: + # make sure inductor naming is consistent with checkpoint's + inductor_enabled = self.cfg.get('inductor', False) + state_dict = checkpoint['state_dict'] + inductor_checkpoint = False + for k, v, in state_dict.items(): + if '_orig_mod' in k: + inductor_checkpoint = True + break + + if inductor_enabled and not inductor_checkpoint: + # ckpt needs to be converted to inductor-format weights (add .orig_mod) + logging.info('Add .orig_mod to all weight keys.') + new_state_dict = {} + for k, v in state_dict.items(): + idx = k.find('._orig_mod') + new_key = k[:idx] + k[idx + len('._orig_mod') :] + new_state_dict[new_key] = v + checkpoint['state_dict'] = new_state_dict + elif not inductor_enabled and inductor_checkpoint: + # ckpt needs to be converted to non-inductor-format weights (remove .orig_mod) + logging.info('Remove .orig_mod to all weight keys.') + new_state_dict = {} + for k, v in state_dict.items(): + new_key = k.replace("._orig_mod", "") + new_state_dict[new_key] = v + checkpoint['state_dict'] = new_state_dict + super().on_load_checkpoint(checkpoint) + + def on_fit_start(self) -> None: + if self.online_encoding: + # if encoding text online, set up text_encoder here (after loading checkpoints) instead of in __init__. + # This is because text encoder weights are not saved, so the encoder must be loaded after other weights + # are loaded. + logging.info( + f'Setting up pretrained text encoder: {self.text_encoder_path or "download or use cached t5-11b"}' + ) + self.text_encoder = self.model.get_text_encoder(encoder_path=self.text_encoder_path).to( + torch.cuda.current_device() + ) + self.text_encoder.eval() + for param in self.text_encoder.parameters(): + param.requires_grad = False diff --git a/nemo/collections/multimodal/models/text_to_image/imagen/imagen_pipeline.py b/nemo/collections/multimodal/models/text_to_image/imagen/imagen_pipeline.py new file mode 100644 index 000000000000..43660c9000a1 --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/imagen/imagen_pipeline.py @@ -0,0 +1,356 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import time +from dataclasses import dataclass, field +from typing import Callable, List, Optional, Union + +import torch +from omegaconf.omegaconf import OmegaConf +from pytorch_lightning import Trainer +from torch.cuda.amp import autocast + +from nemo.collections.multimodal.models.text_to_image.imagen.imagen import Imagen, MegatronImagen +from nemo.collections.multimodal.parts.utils import numpy_to_pil, setup_trainer_and_models_for_inference + + +@dataclass +class ImagenCustomizedModelConfig: + base_ckpt: Optional[str] = None + base_cfg: Optional[str] = None + sr256_ckpt: Optional[str] = None + sr256_cfg: Optional[str] = None + sr1024_ckpt: Optional[str] = None + sr1024_cfg: Optional[str] = None + + +@dataclass +class ImagenSamplingConfig: + step: Optional[int] = None + cfg: Optional[float] = 1 + + +@dataclass +class ImagenPipelineConfig: + model_name: Optional[str] = None + run_ema_model: Optional[bool] = True + customized_model: Optional[ImagenCustomizedModelConfig] = None + num_images_per_promt: Optional[int] = 8 + texts: Optional[List[str]] = field(default_factory=lambda: []) + output_path: Optional[str] = 'output/imagen_inference' + record_time: Optional[bool] = False + encoder_path: Optional[str] = None + target_resolution: Optional[int] = 256 + inference_precision: Optional[str] = '32' + thresholding_method: Optional[str] = 'dynamic' + samplings: Optional[List[ImagenSamplingConfig]] = field(default_factory=lambda: list()) + part: Optional[int] = 0 + + +class ImagenPipeline(Callable): + def __init__(self, models: List[Imagen], text_encoder, cfg, device): + self.models = [model.to(device) for model in models] + self.text_encoder = text_encoder.to(device) + self.cfg = cfg + self.device = device + + def _load_model(model_ckpt: str, model_cfg: str, eval_mode: bool = True, trainer: Trainer = None): + assert model_ckpt is not None, 'model ckpt cannot be None' + if model_ckpt.endswith('.nemo'): + model_cfg = MegatronImagen.restore_from(restore_path=model_ckpt, trainer=trainer, return_config=True) + model_cfg.unet.flash_attention = False + model_cfg.micro_batch_size = 1 + model_cfg.global_batch_size = 1 + model = MegatronImagen.restore_from( + restore_path=model_ckpt, override_config_path=model_cfg, trainer=trainer, + ) + elif model_ckpt.endswith('.ckpt'): + model_cfg = OmegaConf.load(model_cfg) + model_cfg.model.unet.flash_attention = False + model_cfg.model.micro_batch_size = 1 + model_cfg.model.global_batch_size = 1 + model = MegatronImagen(cfg=model_cfg.model, trainer=trainer) + checkpoint = torch.load(model_ckpt, map_location=lambda storage, loc: storage) + + # Change weight keys if training using TorchInductor + state_dict = checkpoint['state_dict'] + del_keys = [] + for k, v in state_dict.items(): + if '._orig_mod' in k: + del_keys.append(k) + if len(del_keys) != 0: + print('ckpt was saved with TorchInductor. Renaming weights..') + for k in del_keys: + new_k = k.replace("._orig_mod", "") + state_dict[new_k] = state_dict[k] + del state_dict[k] + model.load_state_dict(state_dict, strict=True) + else: + raise Exception('Invalid ckpt type. Should be either .nemo or .ckpt with cfg') + + model = model.model # We do not need Megatron Instance for inference + model.model.set_inference_mode(True) # Used for adding the least noise for EDM inference for SR model. + if eval_mode: + model.unet.cuda().eval() + return model + + @staticmethod + def _load_customized_model(cfg: ImagenPipelineConfig, trainer=None, megatron_loading=False, megatron_cfg=None): + if megatron_loading: + assert megatron_cfg + + def model_cfg_modifier(model_cfg): + model_cfg.inductor = False + model_cfg.unet.flash_attention = False + model_cfg.micro_batch_size = megatron_cfg.fid.ncaptions_per_batch + model_cfg.global_batch_size = model_cfg.micro_batch_size * megatron_cfg.fid.ntasks_per_node + + trainer, megatron_models = setup_trainer_and_models_for_inference( + MegatronImagen, cfg=megatron_cfg, model_cfg_modifier=model_cfg_modifier + ) + models = [mm.model for mm in megatron_models] + for model in models: + model.cuda().eval() + model.model.set_inference_mode(True) + return models + customized_models = cfg.customized_model + models = [] + print('Load base model.') + model = ImagenPipeline._load_model( + model_ckpt=customized_models.base_ckpt, model_cfg=customized_models.base_cfg, trainer=trainer, + ) + models.append(model) + + if cfg.target_resolution >= 256: + print('Load SR256 model.') + model = ImagenPipeline._load_model( + model_ckpt=customized_models.sr256_ckpt, model_cfg=customized_models.sr256_cfg, trainer=trainer + ) + models.append(model) + + if cfg.target_resolution >= 1024: + print('Load SR1024 model.') + model = ImagenPipeline._load_model( + model_ckpt=customized_models.sr1024_ckpt, model_cfg=customized_models.sr1024_cfg, trainer=trainer + ) + models.append(model) + return models + + @classmethod + def from_pretrained( + cls, cfg: ImagenPipelineConfig, trainer=None, device='cuda', megatron_loading=False, megatron_cfg=None + ): + target_resolution = cfg.target_resolution + assert target_resolution in [64, 256, 1024] + + # Set encoder_path which will be used when inst the model + if cfg.encoder_path is not None: + os.environ['ENCODER_PATH'] = cfg.encoder_path + + assert cfg.model_name is None, 'No predefined model for now' + assert cfg.customized_model is not None, 'Need to provide customized models for inference' + models = ImagenPipeline._load_customized_model(cfg, trainer, megatron_loading, megatron_cfg) + assert len(models) >= 1, 'Need to load at least one model' + if cfg.inference_precision == '16': + print('Running Inference in FP16.') + print('Converting all difussion models to FP16..') + for model in models: + model.half() + + print('Loading text encoder') + text_encoder = models[0].get_text_encoder(encoder_path=cfg.encoder_path) + if cfg.inference_precision == '16': + print('Converting text encoders to FP16..') + text_encoder.half() + return ImagenPipeline(models=models, text_encoder=text_encoder, cfg=cfg, device=device) + + @torch.no_grad() + def get_text_encodings(self, input_text, repeat=1): + # Repeat the inputs so that we generate multiple samples per query + if isinstance(input_text, str): + inp_text_batch = [input_text] + else: + inp_text_batch = input_text + # Encode the text embeddings using text encoder. + text_encodings, text_mask = self.text_encoder.encode(inp_text_batch, device=self.device) + if repeat != 1: + assert len(inp_text_batch) == 1, 'Repeat should only be applied if we feed single text to encoder.' + text_encodings = text_encodings.repeat(repeat, 1, 1) + text_mask = text_mask.repeat(repeat, 1) + return text_encodings, text_mask + + @torch.no_grad() + def __call__( + self, + prompts: Union[str, List[str]] = None, + inference_steps: Union[int, List[int]] = None, + classifier_free_guidance: Union[float, List[float]] = None, + num_images_per_promt: Optional[int] = 0, + thresholding_method: bool = None, + output_type: Optional[str] = 'pil', + seed: Union[int, List[int]] = 2000, + single_batch_mode: bool = False, + output_res: Optional[int] = None, + low_res_input: Optional[torch.Tensor] = None, + ): + if prompts is None: + prompts = OmegaConf.to_object(self.cfg.texts) + if num_images_per_promt == 0: + num_images_per_promt = self.cfg.num_images_per_promt + if thresholding_method is None: + thresholding_method = self.cfg.thresholding_method + device = self.device + inference_precision = self.cfg.inference_precision + assert inference_precision in ['16', '32', 'AMP'], "Inference Precision should be one of ['16', '32', 'AMP']" + print(f'Running inference in {inference_precision} mode.') + amp_enabled = inference_precision == 'AMP' + + # Based on output_res and low_res_input, determine which models to run + if output_res is not None or low_res_input is not None: + models = [] + if output_res is not None: + for model in self.models: + models.append(model) + if model.image_size == output_res: + break + else: + models = self.models + if low_res_input is not None: + print(f'Low-res input shape: {low_res_input.shape}') + low_res_dim = low_res_input.shape[-1] + num_images_per_promt = low_res_input.shape[0] + for idx, model in enumerate(models): + if model.image_size == low_res_dim: + models = models[idx + 1 :] + break + print(f'Running inference on {len(models)} models.') + else: + models = self.models + + if classifier_free_guidance is None: + cfgs = [each.cfg for each in self.cfg.samplings] + cfgs = cfgs[: len(models)] + else: + cfgs = classifier_free_guidance + if isinstance(cfgs, int) or isinstance(cfgs, float): + cfgs = [cfgs] * len(models) + + if inference_steps is None: + steps = [each.step for each in self.cfg.samplings] + steps = steps[: len(models)] + else: + steps = inference_steps + if isinstance(steps, int): + steps = [steps] * len(models) + + assert len(steps) == len(cfgs) == len(models) + + output = [] + all_res_output = [[] for _ in range(len(models))] + if single_batch_mode: + num_images_per_promt = len(prompts) + + throughputs = {'text-encoding': []} + for idx in range(len(models)): + throughputs[f'stage-{idx+1}'] = [] + for prompt in prompts: + if single_batch_mode: + text_input = prompts + else: + text_input = prompt.strip('\n') + print('Input caption: {}'.format(text_input)) + tic = time.perf_counter() + text_encodings, text_mask = self.get_text_encodings( + text_input, repeat=num_images_per_promt if not single_batch_mode else 1 + ) + throughputs['text-encoding'].append(time.perf_counter() - tic) + + # Set seed + noise_maps = [] + if isinstance(seed, int): + # Single seed for the batch + torch.random.manual_seed(seed) + # Generate noise maps + for model in models: + noise_map = torch.randn( + (num_images_per_promt, 3, model.unet.image_size, model.unet.image_size), device=device + ) + noise_map = noise_map.half() if inference_precision == '16' else noise_map + noise_maps.append(noise_map) + elif isinstance(seed, list): + assert len(seed) == num_images_per_promt + for model in models: + noise_map_batch = [] + for single_seed in seed: + torch.random.manual_seed(single_seed) + noise_map_single = torch.randn( + (1, 3, model.unet.image_size, model.unet.image_size), device=device + ) + noise_map_batch.append(noise_map_single) + noise_map_batch = torch.cat(noise_map_batch, dim=0) + noise_map_batch = noise_map_batch.half() if inference_precision == '16' else noise_map_batch + noise_maps.append(noise_map_batch) + else: + raise RuntimeError('Seed type incorrect.') + + x_low_res = low_res_input + all_res = [] + for idx, (model, noise_map, cfg, step) in enumerate(zip(models, noise_maps, cfgs, steps)): + tic = time.perf_counter() + with autocast(enabled=amp_enabled): + generated_images = model.sample_image( + noise_map=noise_map, + text_encoding=text_encodings, + text_mask=text_mask, + x_low_res=x_low_res, + cond_scale=cfg, + sampling_steps=step, + thresholding_method=thresholding_method, + ) + x_low_res = generated_images + all_res.append(generated_images) + throughputs[f'stage-{idx+1}'].append(time.perf_counter() - tic) + # recenter from [-1, 1] to [0, 1] + assert generated_images is not None + generated_images = ((generated_images + 1) / 2).clamp_(0, 1) + all_res = [((each + 1) / 2).clamp_(0, 1) for each in all_res] + output.append(generated_images) + for idx, each in enumerate(all_res): + all_res_output[idx].append(each) + if single_batch_mode: + break + + if output_type == 'torch': + return torch.cat(output, dim=0), [torch.cat(each, dim=0) for each in all_res_output] + output_new = [] + for x_samples_image in output: + # Convert to numpy + x_samples_image = x_samples_image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == 'pil': + x_samples_image = numpy_to_pil(x_samples_image) + output_new.append(x_samples_image) + + all_res_output_new = [[] for each in range(len(models))] + for idx, res_output in enumerate(all_res_output): + for x_samples_image in res_output: + # Convert to numpy + x_samples_image = x_samples_image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == 'pil': + x_samples_image = numpy_to_pil(x_samples_image) + all_res_output_new[idx].append(x_samples_image) + + for item in throughputs: + throughputs[item] = sum(throughputs[item]) / len(throughputs[item]) + + return output_new, all_res_output_new, throughputs diff --git a/nemo/collections/multimodal/models/text_to_image/imagen/precond.py b/nemo/collections/multimodal/models/text_to_image/imagen/precond.py new file mode 100644 index 000000000000..fc3b3ed7d18d --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/imagen/precond.py @@ -0,0 +1,174 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn.functional as F + +from nemo.collections.multimodal.modules.imagen.sampler.batch_ops import batch_mul +from nemo.collections.multimodal.modules.imagen.sampler.continuous_ddpm import GaussianDiffusionContinuousTimes +from nemo.collections.multimodal.parts.utils import randn_like + + +class PrecondModel(torch.nn.Module): + def __init__(self, unet, loss_type): + super().__init__() + self.unet = unet + self.rng = None + self.inference = False + if loss_type == 'l1': + self.loss_fn = F.l1_loss + elif loss_type == 'l2': + self.loss_fn = F.mse_loss + elif loss_type == 'huber': + self.loss_fn = F.smooth_l1_loss + else: + raise NotImplementedError(f'{loss_type} loss is not supported') + + def set_inference_mode(self, value): + self.inference = value + + def forward(self, **model_kwargs): + return self.unet(**model_kwargs) + + def forward_with_cond_scale(self, *args, text_embed=None, cond_scale=1.0, **kwargs): + logits = self.forward(*args, text_embed=text_embed, **kwargs) + if cond_scale == 1.0: + return logits + null_logits = self.forward(*args, text_embed=torch.zeros_like(text_embed), **kwargs) + return null_logits + (logits - null_logits) * cond_scale + + def set_rng(self, generator): + self.rng = generator + + +class ContinousDDPMPrecond(PrecondModel): + def __init__( + self, + unet, + loss_type='l2', + pred_objective='noise', + noise_schedule='cosine', + timesteps=1000, + noise_cond_aug=False, + ): + super().__init__(unet, loss_type) + self.scheduler = GaussianDiffusionContinuousTimes(noise_schedule=noise_schedule, timesteps=timesteps) + self.pred_objective = pred_objective + assert noise_cond_aug == False, 'noise cond aug currently not supported for DDPM' + + def sample_time(self, batch_size, device=None): + return self.scheduler.sample_random_times(batch_size=batch_size, device=device) + + def get_xt(self, x0, t=None, epsilon=None): + if epsilon is None: + epsilon = randn_like(x0, generator=self.rng) + if t is None: + t = self.sample_time(batch_size=x0.shape[0], device=x0.device) + x_noisy, log_snr, alpha, sigma = self.scheduler.q_sample(x_start=x0, t=t, noise=epsilon,) + return x_noisy, t, epsilon + + def forward(self, x, time, text_embed, text_mask, **model_kwargs): + # Convert time to FP32 for calculating time embedding due to FP16 overflow + time = time.float() + time = self.scheduler.get_condition(time) + time = time.type_as(x) + + return self.unet(x=x, time=time, text_embed=text_embed, text_mask=text_mask, **model_kwargs) + + def compute_loss(self, x0, text_embed, text_mask, time=None, noise=None, **model_kwargs): + x_noisy, time, noise = self.get_xt(x0=x0, t=time, epsilon=noise) + pred = self.forward(x_noisy, time, text_embed, text_mask, **model_kwargs) + # Determine target + if self.pred_objective == 'noise': + target = noise + elif self.pred_objective == 'x_start': + target = x0 + else: + raise ValueError(f'unknown objective {self.pred_objective}') + return self.loss_fn(pred, target) + + def set_rng(self, generator): + self.scheduler.rng = generator + self.rng = generator + + +class EDMPrecond(PrecondModel): + def __init__( + self, + unet, # Underlying model. + loss_type='l2', + sigma_data=0.5, # Expected standard deviation of the training data. + p_mean=-1.2, + p_std=1.2, + noise_cond_aug=False, + ): + super().__init__(unet, loss_type) + self.sigma_data = sigma_data + self.p_mean = p_mean + self.p_std = p_std + self.noise_cond_aug = noise_cond_aug + + def forward(self, x, time, text_embed, text_mask, **model_kwargs): + bs = x.shape[0] + assert time.ndim <= 1, 'time should be in shape of either [bs] or scalar' + sigma = time + c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt() + c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt() + c_noise = sigma.log() / 4 + + if c_noise.ndim < 1: + c_noise = c_noise.repeat(bs,) + + if self.noise_cond_aug: + # Applying noise conditioning augmentation + assert 'x_low_res' in model_kwargs, 'x_low_res does not exist when attemping to apply noise augmentation' + x_low_res = model_kwargs['x_low_res'] + if self.inference: + batch_size = x_low_res.shape[0] + time_low_res = torch.ones(batch_size, device=x_low_res.device) * 0.002 + x_low_res_noisy, time_low_res = self.get_xt(x0=x_low_res, t=time_low_res, epsilon=None) + else: + x_low_res_noisy, time_low_res = self.get_xt(x0=x_low_res, t=None, epsilon=None) + c_in_noise = 1 / (self.sigma_data ** 2 + time_low_res ** 2).sqrt() + c_noise_noise = time_low_res.log() / 4 + model_kwargs['x_low_res'] = batch_mul(c_in_noise, x_low_res_noisy) + model_kwargs['time_low_res'] = c_noise_noise + + F_x = self.unet(batch_mul(c_in, x), c_noise, text_embed, text_mask, **model_kwargs) + D_x = batch_mul(c_skip, x) + batch_mul(c_out, F_x) + return D_x + + def sample_time(self, batch_size, device=None): + return (torch.randn(batch_size, device=device, generator=self.rng) * self.p_std + self.p_mean).exp() + + def get_xt(self, x0, t=None, epsilon=None): + if epsilon is None: + epsilon = randn_like(x0, generator=self.rng) + assert epsilon.shape == x0.shape + if t is None: + t = self.sample_time(batch_size=x0.shape[0], device=x0.device) + sigma = t + noise = batch_mul(epsilon, sigma) + return x0 + noise, sigma + + def compute_loss(self, x0, text_embed, text_mask, time=None, noise=None, **model_kwargs): + x_noisy, time = self.get_xt(x0=x0, t=None, epsilon=noise) + pred = self.forward(x_noisy, time, text_embed, text_mask, **model_kwargs) + sigma = time + weight = ((sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2).sqrt() + target = x0 + return self.loss_fn(batch_mul(weight, target), batch_mul(weight, pred),) + + def set_rng(self, generator): + self.rng = generator diff --git a/nemo/collections/multimodal/models/text_to_image/instruct_pix2pix/__init__.py b/nemo/collections/multimodal/models/text_to_image/instruct_pix2pix/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/instruct_pix2pix/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/models/text_to_image/instruct_pix2pix/ldm/__init__.py b/nemo/collections/multimodal/models/text_to_image/instruct_pix2pix/ldm/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/instruct_pix2pix/ldm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/models/text_to_image/instruct_pix2pix/ldm/ddpm_edit.py b/nemo/collections/multimodal/models/text_to_image/instruct_pix2pix/ldm/ddpm_edit.py new file mode 100644 index 000000000000..901745f09421 --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/instruct_pix2pix/ldm/ddpm_edit.py @@ -0,0 +1,262 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +https://github.com/timothybrooks/instruct-pix2pix/blob/2afcb7e45bd350765f21a58a0c135871e9dc5a78/stable_diffusion/ldm/models/diffusion/ddpm_edit.py +""" + +import torch +from einops import rearrange + +from nemo.collections.multimodal.data.instruct_pix2pix.edit_dataset import EditDataset +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import ( + LatentDiffusion, + MegatronLatentDiffusion, +) +from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import ( + MegatronPretrainingRandomSampler, + MegatronPretrainingSampler, +) +from nemo.utils import logging + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + +class LatentDiffusionEdit(LatentDiffusion): + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + pl_sd = torch.load(path, map_location="cpu") + if "state_dict" in list(pl_sd.keys()): + pl_sd = pl_sd["state_dict"] + sd = {} + + first_key = list(pl_sd.keys())[0] + # State keys of model trained with TorchDynamo changed from + # "model.xxx" to "model._orig_mod.xxx" + for k, v in pl_sd.items(): + new_k = k.replace("._orig_mod", "") + # compatibility for stable diffusion old checkpoint + # remove megatron wrapper prefix + if first_key == "model.betas": + new_k = new_k.lstrip("model.") + sd[new_k] = v + keys = list(sd.keys()) + + # Our model adds additional channels to the first layer to condition on an input image. + # For the first layer, copy existing channel weights and initialize new channel weights to zero. + input_keys = [ + "model.diffusion_model.input_blocks.0.0.weight", + ] + + self_sd = self.state_dict() + for input_key in input_keys: + if input_key not in sd or input_key not in self_sd: + continue + + input_weight = self_sd[input_key] + if input_weight.size() != sd[input_key].size(): + print(f"Manual init: {input_key}") + input_weight.zero_() + input_weight[:, :4, :, :].copy_(sd[input_key]) + ignore_keys.append(input_key) + + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = ( + self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) + ) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + @torch.no_grad() + def get_input( + self, + batch, + k, + return_first_stage_outputs=False, + force_c_encode=False, + cond_key=None, + return_original_cond=False, + bs=None, + uncond=0.05, + ): + x = batch[k] + if bs is not None: + x = x[:bs] + + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + cond_key = cond_key or self.cond_stage_key + xc = batch[cond_key] + if bs is not None: + xc["c_crossattn"] = xc["c_crossattn"][:bs] + xc["c_concat"] = xc["c_concat"][:bs] + cond = {} + + # To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%. + random = torch.rand(x.size(0), device=x.device) + prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1") + input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1") + + null_prompt = self.get_learned_conditioning([""]) + cond["c_crossattn"] = torch.where( + prompt_mask, null_prompt, self.get_learned_conditioning(xc["c_crossattn"]).detach() + ) + cond["c_concat"] = input_mask * self.encode_first_stage((xc["c_concat"].to(x.device))).mode().detach() + + out = [z, cond] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + +class MegatronLatentDiffusionEdit(MegatronLatentDiffusion): + def model_provider_func(self, pre_process=True, post_process=True): + """Model depends on pipeline paralellism.""" + model = LatentDiffusionEdit(cfg=self.cfg, model_parallel_config=self.model_parallel_config) + return model + + def setup(self, stage=None): + """ PTL hook that is executed after DDP spawns. + We setup datasets here as megatron datasets require DDP to instantiate. + See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. + Args: + stage (str, optional): Can be 'fit', 'validate', 'test' or 'predict'. Defaults to None. + """ + self.model.rng.manual_seed(self.cfg.seed + 100 * parallel_state.get_data_parallel_rank()) + + # log number of parameters + if isinstance(self.model, list): + num_parameters_on_device = sum( + [sum([p.nelement() for p in model_module.parameters()]) for model_module in self.model] + ) + else: + num_parameters_on_device = sum([p.nelement() for p in self.model.parameters()]) + + # to be summed across data parallel group + total_num_parameters = torch.tensor(num_parameters_on_device).cuda(non_blocking=True) + + torch.distributed.all_reduce(total_num_parameters, group=parallel_state.get_model_parallel_group()) + + logging.info( + f'Pipeline model parallel rank: {parallel_state.get_pipeline_model_parallel_rank()}, ' + f'Tensor model parallel rank: {parallel_state.get_tensor_model_parallel_rank()}, ' + f'Number of model parameters on device: {num_parameters_on_device:.2e}. ' + f'Total number of model parameters: {total_num_parameters:.2e}.' + ) + + resume_checkpoint_path = self.trainer.ckpt_path + if resume_checkpoint_path: + init_consumed_samples = self._extract_consumed_samples_from_ckpt(resume_checkpoint_path) + else: + init_consumed_samples = 0 + self.init_consumed_samples = init_consumed_samples + self.init_global_step = self.trainer.global_step + + self.build_train_valid_test_datasets() + self.setup_training_data(self.cfg.data) + self.setup_validation_data(self.cfg.data) + self.setup_test_data(self.cfg.data) + + def build_train_valid_test_datasets(self): + # TODO (yuya): set up splits ratio and other params + if self.cfg.data.data_path is not None: + self._train_ds = EditDataset(path=self.cfg.data.data_path, split="train", flip_prob=0.5) + self._validation_ds = EditDataset(path=self.cfg.data.data_path, split="val") + self._test_ds = EditDataset(path=self.cfg.data.data_path, split="test") + + def setup_training_data(self, cfg): + if hasattr(self, '_train_ds') and self._train_ds is not None: + consumed_samples = self.compute_consumed_samples(0) + logging.info( + f'Setting up train dataloader with len(len(self._train_ds)): {len(self._train_ds)} and consumed samples: {consumed_samples}' + ) + self._train_dl = self.build_pretraining_data_loader(self._train_ds, consumed_samples) + + def setup_validation_data(self, cfg): + if hasattr(self, '_validation_ds') and self._validation_ds is not None: + consumed_samples = 0 + logging.info( + f'Setting up validation dataloader with len(len(self._validation_ds)): {len(self._validation_ds)} and consumed samples: {consumed_samples}' + ) + drop_last = True + if not self.cfg.get('validation_drop_last', True): + logging.info(f'Drop last in validation dataset is set to False') + drop_last = False + self._validation_dl = self.build_pretraining_data_loader(self._validation_ds, consumed_samples, drop_last) + + def setup_test_data(self, cfg): + if hasattr(self, '_test_ds') and self._test_ds is not None: + consumed_samples = 0 + logging.info( + f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}' + ) + drop_last = True + if not self.cfg.get('validation_drop_last', True): + logging.info(f'Drop last in validation dataset is set to False') + drop_last = False + self._test_dl = self.build_pretraining_data_loader(self._test_ds, consumed_samples, drop_last) + + def build_pretraining_data_loader(self, dataset, consumed_samples, drop_last=True): + """Build dataloader given an input dataset.""" + + if dataset is None: + return None + logging.info(f'Building dataloader with consumed samples: {consumed_samples}') + # Megatron sampler + if hasattr(self._cfg.data, 'dataloader_type') and self._cfg.data.dataloader_type is not None: + # TODO (yuya): fix this + if self._cfg.data.dataloader_type == 'single': + batch_sampler = MegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=self._cfg.micro_batch_size, + global_batch_size=self._cfg.global_batch_size, + data_parallel_rank=parallel_state.get_data_parallel_rank(), + data_parallel_size=parallel_state.get_data_parallel_world_size(), + drop_last=drop_last, + ) + elif self._cfg.data.dataloader_type == 'cyclic': + batch_sampler = MegatronPretrainingRandomSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=self._cfg.micro_batch_size, + global_batch_size=self._cfg.global_batch_size, + data_parallel_rank=parallel_state.get_data_parallel_rank(), + data_parallel_size=parallel_state.get_data_parallel_world_size(), + drop_last=drop_last, + ) + else: + raise Exception(f'{self._cfg.dataloader_type} dataloader type is not supported.') + else: + raise ValueError('cfg.data.dataloader_type not found. Must be "single" or "cyclic"') + + # Torch dataloader. + return torch.utils.data.DataLoader( + dataset, batch_sampler=batch_sampler, num_workers=self._cfg.data.num_workers, pin_memory=True, + ) diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/__init__.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_model.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_model.py new file mode 100644 index 000000000000..45bd2e5afeea --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_model.py @@ -0,0 +1,80 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re +from abc import ABC, abstractclassmethod +from typing import Any, Optional + +import torch + +from nemo.core.classes import ModelPT +from nemo.utils import logging + + +class DiffusionModel(ModelPT, ABC): + @abstractclassmethod + def get_conditioning(self, c: Any) -> Any: + """ + Encode conditioning c. + For txt2img use-case, the input conditioning would be the plain text, + and output would be the encoded embedding for the corresponding text; + For img2img use-case, the input conditioning would be the raw image, + and output would be the corresponding image embedding + + Args: + c: conditioning + + Returns: + encoded conditioning + """ + pass + + @abstractclassmethod + def apply_model(self, x_t: torch.Tensor, t: torch.Tensor, c: Optional[torch.Tensor]) -> torch.Tensor: + """ + Apply Diffusion model. + If c is not given, the model acts as an unconditional diffusion model. + For diffusion model that applies on the pixel space, x_t should be in the pixel space; + for diffusion model that applies on the latent space, x_t is in latent space. + + Args: + x_t: noisy input x at timestamp t + t: timestamp + c: conditioning + + Returns: + Predicted result that has the same shape as x_t + """ + + def on_train_start(self) -> None: + super().on_train_start() + self.init_global_step = self.trainer.global_step + + def _extract_consumed_samples_from_ckpt(self, ckpt_path): + try: + init_consumed_samples = int(float(re.findall(r"consumed_samples\=([0-9]+.[0-9]+)", ckpt_path)[0])) + except (ValueError, TypeError, IndexError): + logging.warning("Cannot parse the checkpoint file to get the consumed samples. assume it is zero.") + init_consumed_samples = 0 + + return init_consumed_samples + + def compute_consumed_samples(self, steps_since_resume=0): + consumed_samples = ( + self.init_consumed_samples + + steps_since_resume + * self.trainer.world_size + * self.cfg.micro_batch_size + * self.trainer.accumulate_grad_batches + ) + return int(consumed_samples) diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/__init__.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py new file mode 100644 index 000000000000..d551edaf1bd2 --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py @@ -0,0 +1,614 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from contextlib import contextmanager + +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.model import Decoder, Encoder +from nemo.collections.multimodal.modules.stable_diffusion.distributions.distributions import ( + DiagonalGaussianDistribution, +) +from nemo.collections.multimodal.parts.stable_diffusion.utils import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__( + self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_, _, ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice(np.arange(lower_size, upper_size + 16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss( + qloss, + x, + xrec, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + predicted_indices=ind, + ) + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss( + qloss, x, xrec, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split="train" + ) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss( + qloss, + x, + xrec, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + suffix, + predicted_indices=ind, + ) + + discloss, log_dict_disc = self.loss( + qloss, + x, + xrec, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + suffix, + predicted_indices=ind, + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log( + f"val{suffix}/rec_loss", rec_loss, prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True + ) + self.log( + f"val{suffix}/aeloss", aeloss, prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True + ) + if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor * self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quantize.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()), + lr=lr_g, + betas=(0.5, 0.9), + ) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + {'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}, + {'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: + xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 + return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(pl.LightningModule): + def __init__( + self, + ddconfig, + embed_dim, + lossconfig=None, # TODO make it configurable + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + from_pretrained: str = None, + capture_cudagraph_iters=-1, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = torch.nn.Identity() # instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + if from_pretrained is not None: + state_dict = torch.load(from_pretrained) + self._load_pretrained_model(state_dict) + + # CUDA graph captured sub-modules + self.capture_cudagraph_iters = capture_cudagraph_iters + self.stream = torch.cuda.Stream() + self.encoder_iterations = self.decoder_iterations = 0 + self.encoder_graph = torch.cuda.CUDAGraph() # eval + self.decoder_graph = torch.cuda.CUDAGraph() # eval + self.graphed_encoder = self.graphed_decoder = None # train + self.static_x = self.static_moments = None + self.static_z = self.static_dec = None + + def _state_key_mapping(self, state_dict: dict): + import re + + res_dict = {} + key_list = state_dict.keys() + key_str = " ".join(key_list) + up_block_pattern = re.compile('upsamplers') + p1 = re.compile('mid.block_[0-9]') + p2 = re.compile('decoder.up.[0-9]') + up_blocks_count = int(len(re.findall(up_block_pattern, key_str)) / 2 + 1) + for key_, val_ in state_dict.items(): + key_ = ( + key_.replace("up_blocks", "up") + .replace("down_blocks", "down") + .replace('resnets', 'block') + .replace('mid_block', 'mid') + .replace("mid.block.", "mid.block_") + .replace('mid.attentions.0.key', 'mid.attn_1.k') + .replace('mid.attentions.0.query', 'mid.attn_1.q') + .replace('mid.attentions.0.value', 'mid.attn_1.v') + .replace('mid.attentions.0.group_norm', 'mid.attn_1.norm') + .replace('mid.attentions.0.proj_attn', 'mid.attn_1.proj_out') + .replace('upsamplers.0', 'upsample') + .replace('downsamplers.0', 'downsample') + .replace('conv_shortcut', 'nin_shortcut') + .replace('conv_norm_out', 'norm_out') + ) + + mid_list = re.findall(p1, key_) + if len(mid_list) != 0: + mid_str = mid_list[0] + mid_id = int(mid_str[-1]) + 1 + key_ = key_.replace(mid_str, mid_str[:-1] + str(mid_id)) + + up_list = re.findall(p2, key_) + if len(up_list) != 0: + up_str = up_list[0] + up_id = up_blocks_count - 1 - int(up_str[-1]) + key_ = key_.replace(up_str, up_str[:-1] + str(up_id)) + res_dict[key_] = val_ + return res_dict + + def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False): + state_dict = self._state_key_mapping(state_dict) + model_state_dict = self.state_dict() + loaded_keys = [k for k in state_dict.keys()] + expected_keys = list(model_state_dict.keys()) + original_loaded_keys = loaded_keys + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + def _find_mismatched_keys( + state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict['encoder.mid.attn_1.q.weight'].shape == torch.Size([512, 512]): + for key in [ + 'encoder.mid.attn_1.q.weight', + 'decoder.mid.attn_1.q.weight', + 'encoder.mid.attn_1.v.weight', + 'decoder.mid.attn_1.v.weight', + 'encoder.mid.attn_1.k.weight', + 'decoder.mid.attn_1.k.weight', + 'encoder.mid.attn_1.proj_out.weight', + 'decoder.mid.attn_1.proj_out.weight', + ]: + state_dict[key] = state_dict[key].unsqueeze(2).unsqueeze(3) + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, model_state_dict, original_loaded_keys, ignore_mismatched_sizes, + ) + error_msgs = self._load_state_dict_into_model(state_dict) + return missing_keys, unexpected_keys, mismatched_keys, error_msgs + + def _load_state_dict_into_model(self, state_dict): + # Convert old format to new format if needed from a PyTorch state_dict + # copy state_dict so _load_from_state_dict can modify it + state_dict = state_dict.copy() + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: torch.nn.Module, prefix=""): + args = (state_dict, prefix, {}, True, [], [], error_msgs) + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(self) + + return error_msgs + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + ''' + Encode input image in pixel space to latent representation. + ''' + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + ''' + Decode latent representation back to pixel space. + ''' + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss( + inputs, reconstructions, posterior, 0, self.global_step, last_layer=self.get_last_layer(), split="val" + ) + + discloss, log_dict_disc = self.loss( + inputs, reconstructions, posterior, 1, self.global_step, last_layer=self.get_last_layer(), split="val" + ) + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()), + lr=lr, + betas=(0.5, 0.9), + ) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py new file mode 100644 index 000000000000..89063f2490cc --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py @@ -0,0 +1,2163 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools +from functools import partial +from typing import Any, Optional + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn as nn +from einops import rearrange, repeat +from lightning_fabric.utilities.cloud_io import _load as pl_load +from omegaconf import DictConfig, open_dict +from pytorch_lightning import Trainer +from pytorch_lightning.core.saving import _load_state as ptl_load_state +from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml +from pytorch_lightning.utilities.migration import pl_legacy_patch +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from torch._inductor import config as inductor_config +from torchvision.utils import make_grid +from tqdm import tqdm + +from nemo.collections.multimodal.data.stable_diffusion.stable_diffusion_dataset import ( + build_train_valid_datasets, + build_train_valid_precached_datasets, +) +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder import ( + AutoencoderKL, + IdentityFirstStage, + VQModelInterface, +) +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.samplers.ddim import DDIMSampler +from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import ( + extract_into_tensor, + make_beta_schedule, + noise_like, +) +from nemo.collections.multimodal.modules.stable_diffusion.distributions.distributions import ( + DiagonalGaussianDistribution, + normal_kl, +) +from nemo.collections.multimodal.parts.stable_diffusion.utils import ( + count_params, + default, + exists, + isimage, + ismap, + log_txt_as_img, + mean_flat, +) +from nemo.collections.multimodal.parts.utils import randn_like +from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel +from nemo.collections.nlp.modules.common.megatron.module import Float16Module +from nemo.collections.nlp.parts.utils_funcs import get_last_rank +from nemo.core.classes.common import Serialization +from nemo.utils import logging + +try: + from apex import amp + from apex.transformer.enums import AttnMaskType + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + + HAVE_APEX = True +except (ImportError, ModuleNotFoundError): + HAVE_APEX = False + +try: + from megatron.core import parallel_state + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + +__conditioning_keys__ = {'concat': 'c_concat', 'crossattn': 'c_crossattn', 'adm': 'y'} + + +def random_dropout(embeddings, drop_rate): + r""" + Function to perform random dropout for embeddings. + When we drop embeddings, we zero them out. + Args: + embeddings (tensor): Input embeddings + drop_rate (float): Rate of dropping the embedding. + """ + nsamples = embeddings.shape[0] + zero_flag = torch.ones(nsamples, 1, 1, device=torch.cuda.current_device()).to(embeddings.dtype) * (1 - drop_rate) + zero_flag = torch.bernoulli(zero_flag).cuda(non_blocking=True) + embeddings = embeddings * zero_flag + return embeddings + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(torch.nn.Module): + def __init__(self, cfg): + super().__init__() + assert cfg.parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"' + self.parameterization = cfg.parameterization + logging.info(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = cfg.clip_denoised + self.log_every_t = cfg.log_every_t + self.first_stage_key = cfg.first_stage_key + self.image_size = cfg.image_size # try conv? + self.channels = cfg.channels + self.channels_last = cfg.get("channels_last", False) + self.use_positional_encodings = cfg.use_positional_encodings + self.model = DiffusionWrapper( + cfg.unet_config, + cfg.conditioning_key, + cfg.inductor, + cfg.inductor_cudagraphs, + cfg.get("capture_cudagraph_iters", -1), + ) + self.model_type = None + count_params(self.model, verbose=True) + + self.v_posterior = cfg.v_posterior + self.original_elbo_weight = cfg.original_elbo_weight + self.l_simple_weight = cfg.l_simple_weight + + self.register_schedule( + given_betas=cfg.given_betas, + beta_schedule=cfg.beta_schedule, + timesteps=cfg.timesteps, + linear_start=cfg.linear_start, + linear_end=cfg.linear_end, + cosine_s=cfg.cosine_s, + ) + + self.loss_type = cfg.loss_type + + self.learn_logvar = cfg.learn_logvar + self.logvar = torch.full(fill_value=cfg.logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + self.rng = torch.Generator(device=torch.cuda.current_device(),) + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule( + beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1.0 - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1.0 - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1.0 / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1.0 / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1.0 - alphas_cumprod_prev) / ( + 1.0 - alphas_cumprod + ) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer( + 'posterior_mean_coef1', to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)) + ) + self.register_buffer( + 'posterior_mean_coef2', to_torch((1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)) + ) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod) + ) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod)) + elif self.parameterization == "v": + lvlb_weights = torch.ones_like( + self.betas ** 2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + ) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + pl_sd = torch.load(path, map_location="cpu") + if "state_dict" in list(pl_sd.keys()): + pl_sd = pl_sd["state_dict"] + + sd = {} + first_key = list(pl_sd.keys())[0] + # State keys of model trained with TorchDynamo changed from + # "model.xxx" to "model._orig_mod.xxx" + for k, v in pl_sd.items(): + new_k = k.replace("._orig_mod", "") + # compatibility for stable diffusion old checkpoint + # remove megatron wrapper prefix + if first_key == "model.betas": + new_k = new_k.lstrip("model.") + sd[new_k] = v + + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + logging.info("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = ( + self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) + ) + logging.info(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + logging.info(f"Missing Keys: {missing}") + if len(unexpected) > 0: + logging.info(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def predict_start_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def predict_eps_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1.0, 1.0) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, generator=self.rng, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample( + img, torch.full((b,), i, device=device, dtype=torch.long), clip_denoised=self.clip_denoised + ) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop( + (batch_size, channels, image_size, image_size), return_intermediates=return_intermediates + ) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: randn_like(x_start, generator=self.rng)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: randn_like(x_start, generator=self.rng)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), generator=self.rng, device=x.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + if self.channels_last: + x = x.permute(0, 3, 1, 2).to(non_blocking=True) + else: + x = rearrange(x, "b h w c -> b c h w") + x = x.to(memory_format=torch.contiguous_format, non_blocking=True) + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.long() + noise = randn_like(x_start, generator=self.rng) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + +class LatentDiffusion(DDPM, Serialization): + """main class""" + + def __init__(self, cfg, model_parallel_config): + self.config = model_parallel_config + self.num_timesteps_cond = default(cfg.num_timesteps_cond, 1) + self.scale_by_std = cfg.scale_by_std + assert self.num_timesteps_cond <= cfg.timesteps + # for backwards compatibility after implementation of DiffusionWrapper + if cfg.conditioning_key is None: + conditioning_key = 'concat' if cfg.concat_mode else 'crossattn' + else: + conditioning_key = cfg.conditioning_key + if cfg.cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = cfg.ckpt_path + ignore_keys = cfg.ignore_keys + cfg.conditioning_key = conditioning_key + super().__init__(cfg=cfg) + self.precision = cfg.precision + self.concat_mode = cfg.concat_mode + self.cond_stage_trainable = cfg.cond_stage_trainable + self.cond_stage_key = cfg.cond_stage_key + + self.num_downs = 0 + if "ddconfig" in cfg.first_stage_config and "ch_mult" in cfg.first_stage_config.ddconfig: + self.num_downs = len(cfg.first_stage_config.ddconfig.ch_mult) - 1 + if not cfg.scale_by_std: + self.scale_factor = cfg.scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(cfg.scale_factor)) + self.instantiate_first_stage(cfg.first_stage_config) + self.instantiate_cond_stage(cfg.cond_stage_config) + self.cond_stage_forward = cfg.cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + self.text_embedding_dropout_rate = cfg.text_embedding_dropout_rate + self.fused_opt = cfg.fused_opt + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + if self.channels_last: + self.first_stage_model = self.first_stage_model.to(memory_format=torch.channels_last) + self.model = self.model.to(memory_format=torch.channels_last) + + def make_cond_schedule(self,): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[: self.num_timesteps_cond] = ids + + def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0): + # only for very first batch + # set rescale weight to 1./std of encodings + logging.info("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1.0 / z.flatten().std()) + logging.info(f"setting self.scale_factor to {self.scale_factor}") + logging.info("### USING STD-RESCALING ###") + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = LatentDiffusion.from_config_dict(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + logging.info("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + logging.info(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = LatentDiffusion.from_config_dict(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = LatentDiffusion.from_config_dict(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd, force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip( + weighting, self.split_input_params["clip_min_weight"], self.split_input_params["clip_max_weight"], + ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip( + L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"], + ) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict( + kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, + padding=0, + stride=(stride[0] * uf, stride[1] * uf), + ) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict( + kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, + padding=0, + stride=(stride[0] // df, stride[1] // df), + ) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input( + self, + batch, + k, + return_first_stage_outputs=False, + force_c_encode=False, + cond_key=None, + return_original_cond=False, + bs=None, + ): + if self.first_stage_key.endswith('encoded'): + gaussian_parameters = batch[self.first_stage_key] + encoder_posterior = DiagonalGaussianDistribution(gaussian_parameters) + else: + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['captions', 'coordinates_bbox', 'txt'] or cond_key.endswith("encoded"): + xc = batch[cond_key] + elif cond_key == 'class_label': + xc = batch + else: + xc = super().get_input(batch, cond_key) + else: + xc = x + if (not self.cond_stage_trainable or force_c_encode) and (not cond_key.endswith('encoded')): + if isinstance(xc, dict) or isinstance(xc, list): + # import pudb; pudb.set_trace() + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc) + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + if self.text_embedding_dropout_rate > 0: + assert self.text_embedding_dropout_rate < 1.0 + c = random_dropout(c, drop_rate=self.text_embedding_dropout_rate) + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': pos_x, 'pos_y': pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1.0 / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + logging.info("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + logging.info("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [ + self.first_stage_model.decode( + z[:, :, :, :, i], force_not_quantize=predict_cids or force_not_quantize + ) + for i in range(z.shape[-1]) + ] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + # same as above but without decorator + def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1.0 / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + logging.info("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + logging.info("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [ + self.first_stage_model.decode( + z[:, :, :, :, i], force_not_quantize=predict_cids or force_not_quantize + ) + for i in range(z.shape[-1]) + ] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + logging.info("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + logging.info("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), generator=self.rng, device=x.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t] + c = self.q_sample(x_start=c, t=tc, noise=randn_like(c.float(), generator=self.rng)) + return self.p_losses(x, c, t, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + for key in cond: + if not isinstance(cond[key], list): + cond[key] = [cond[key]] + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + + if ( + self.cond_stage_key in ["image", "LR_image", "segmentation", 'bbox_img'] + and self.model.conditioning_key + ): # todo check for completeness + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert len(c) == 1 # todo extend to list with more than one elem + c = c[0] # get element + + c = unfold(c) + c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key == 'coordinates_bbox': + assert ( + 'original_image_size' in self.split_input_params + ), 'BoudingBoxRescaling is missing original_image_size' + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params['original_image_size'] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** (num_downs) + + # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [ + ( + rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, + rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h, + ) + for patch_nr in range(z.shape[-1]) + ] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [ + (x_tl, y_tl, rescale_latent * ks[0] / full_img_w, rescale_latent * ks[1] / full_img_h) + for x_tl, y_tl in tl_patch_coordinates + ] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [ + torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None] for bbox in patch_limits + ] # list of length l with tensors of shape (1, 2) + logging.info(patch_limits_tknzd[0].shape) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), 'cond must be dict to be fed into model' + cut_cond = cond['c_crossattn'][0][..., :-2] + logging.info(cut_cond.shape) + + adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') + logging.info(adapted_cond.shape) + adapted_cond = self.get_learned_conditioning(adapted_cond) + logging.info(adapted_cond.shape) + adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + logging.info(adapted_cond.shape) + + cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + else: + cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient + + # apply model by loop over crops + output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + assert not isinstance( + output_list[0], tuple + ) # todo cant deal with multiple model outputs check this never happens + + o = torch.stack(output_list, axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: randn_like(x_start, generator=self.rng)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError() + + if (self.precision in ['bf16', 'bf16-mixed']) or (self.precision in [16, '16', '16-mixed']): + model_output = model_output.type(torch.float32) + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + self.logvar = self.logvar.cuda(non_blocking=True) + logvar_t = self.logvar[t].cuda(non_blocking=True) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += self.original_elbo_weight * loss_vlb + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance( + self, + x, + c, + t, + clip_denoised: bool, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, + ): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1.0, 1.0) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample( + self, + x, + c, + t, + clip_denoised=False, + repeat_noise=False, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + ): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance( + x=x, + c=c, + t=t, + clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising( + self, + cond, + shape, + verbose=True, + callback=None, + quantize_denoised=False, + img_callback=None, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + batch_size=None, + x_T=None, + start_T=None, + log_every_t=None, + ): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, generator=self.rng, device=torch.cuda.current_device()) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = { + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=torch.cuda.current_device(), dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=randn_like(cond, generator=self.rng)) + + img, x0_partial = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + return_x0=True, + temperature=temperature[i], + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1.0 - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: + callback(i) + if img_callback: + img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop( + self, + cond, + shape, + return_intermediates=False, + x_T=None, + verbose=True, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + start_T=None, + log_every_t=None, + ): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, generator=self.rng, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=randn_like(cond, generator=self.rng)) + + img = self.p_sample(img, cond, ts, clip_denoised=self.clip_denoised, quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1.0 - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: + callback(i) + if img_callback: + img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample( + self, + cond, + batch_size=16, + return_intermediates=False, + x_T=None, + verbose=True, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + shape=None, + **kwargs, + ): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = { + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop( + cond, + shape, + return_intermediates=return_intermediates, + x_T=x_T, + verbose=verbose, + timesteps=timesteps, + quantize_denoised=quantize_denoised, + mask=mask, + x0=x0, + ) + + @torch.no_grad() + def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): + + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs) + + return samples, intermediates + + @torch.no_grad() + def log_images( + self, + batch, + N=8, + n_row=4, + sample=True, + ddim_steps=200, + ddim_eta=1.0, + return_keys=None, + quantize_denoised=True, + inpaint=True, + plot_denoise_rows=False, + plot_progressive_rows=True, + plot_diffusion_rows=True, + **kwargs, + ): + + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N, + ) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.long() + noise = randn_like(z_start, generator=self.rng) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, z_denoise_row = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta + ) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if ( + quantize_denoised + and not isinstance(self.first_stage_model, AutoencoderKL) + and not isinstance(self.first_stage_model, IdentityFirstStage) + ): + # also display when quantizing x0 while sampling + with self.ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log( + cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + quantize_denoised=True, + ) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples) + log["samples_x0_quantized"] = x_samples + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w) + # zeros will be filled in + mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0 + mask = mask[:, None, ...] + with self.ema_scope("Plotting Inpaint"): + samples, _ = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, ddim_steps=ddim_steps, x0=z[:N], mask=mask + ) + x_samples = self.decode_first_stage(samples) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + with self.ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log( + cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, ddim_steps=ddim_steps, x0=z[:N], mask=mask + ) + x_samples = self.decode_first_stage(samples) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with self.ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising( + c, shape=(self.channels, self.image_size, self.image_size), batch_size=N + ) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def parameters(self): + params = list(self.model.parameters()) + if self.cond_stage_trainable: + logging.info(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + logging.info('Diffusion model optimizing logvar') + params.append(self.logvar) + return params + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1, generator=self.rng).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 + return x + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + # only required for pipeline parallelism + pass + + +class MegatronLatentDiffusion(MegatronBaseModel): + """Megatron LatentDiffusion Model.""" + + def __init__(self, cfg: DictConfig, trainer: Trainer): + if not HAVE_APEX: + raise ImportError( + "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + ) + if not HAVE_MEGATRON_CORE: + raise ImportError( + "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." + ) + + # this prevents base constructor from initializing tokenizer + self.tokenizer = None + super().__init__(cfg, trainer=trainer) + + self._validate_trainer() + + # megatron_amp_O2 is not yet supported in diffusion models + self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False) + + self.model = self.model_provider_func() + + self.conditioning_keys = [] + + if self.trainer.precision in ['bf16', 'bf16-mixed']: + self.autocast_dtype = torch.bfloat16 + elif self.trainer.precision in [32, '32', '32-true']: + self.autocast_dtype = torch.float + elif self.trainer.precision in [16, '16', '16-mixed']: + self.autocast_dtype = torch.half + else: + raise ValueError('precision must be in ["32-true", "16-mixed", "bf16-mixed"]') + + def get_module_list(self): + if isinstance(self.model, list): + return [model.module if isinstance(model, Float16Module) else model for model in self.model] + elif isinstance(self.model, Float16Module): + return [self.model.module] + else: + return [self.model] + + def model_provider_func(self, pre_process=True, post_process=True): + """Model depends on pipeline paralellism.""" + model = LatentDiffusion(cfg=self.cfg, model_parallel_config=self.model_parallel_config) + return model + + def forward(self, x, c, *args, **kwargs): + output_tensor = self.model(x, c, *args, **kwargs) + return output_tensor + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0): + if self.cfg.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0: + assert self.cfg.scale_factor == 1.0, 'rather not use custom rescaling and std-rescaling simultaneously' + batch[self.cfg.first_stage_key] = batch[self.cfg.first_stage_key].cuda(non_blocking=True) + self.model.on_train_batch_start(batch, batch_idx) + + def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): + tensor_shape = None # Placeholder + + # handle asynchronous grad reduction + no_sync_func = None + if not forward_only and self.with_distributed_adam: + no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + + # pipeline schedules will get these from self.model.config + for module in self.get_module_list(): + module.config.no_sync_func = no_sync_func + + # run forward and backwards passes for an entire global batch + # we do this inside training_step to support pipeline parallelism + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(), + data_iterator=dataloader_iter, + model=self.model, + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=None, + micro_batch_size=self.cfg.micro_batch_size, + ) + + # losses_reduced_per_micro_batch is a list of dictionaries + # [{"loss": 0.1}, {"loss": 0.2}, ...] which are from gradient accumulation steps + # only the last stages of the pipeline return losses + loss_dict = {} + if losses_reduced_per_micro_batch: + if (not forward_only) or self.cfg.data.get('validation_drop_last', True): + # average loss across micro batches + for key in losses_reduced_per_micro_batch[0]: + loss_tensors_list = [loss_reduced[key] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensor = torch.stack(loss_tensors_list) + loss_dict[key] = loss_tensor.mean() + loss_mean = loss_dict["val/loss"] if forward_only else loss_dict["train/loss"] + else: + raise NotImplementedError("Losses of micro batches sizes must be uniform!") + else: + if forward_only: + loss_mean = [] + else: + loss_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + + return loss_mean, loss_dict + + def training_step(self, dataloader_iter, batch_idx): + """ + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + Batch should be a list of microbatches and those microbatches should on CPU. + Microbatches are then moved to GPU during the pipeline. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + """ + + # we zero grads here because we also call backward in the megatron-core fwd/bwd functions + self._optimizer.zero_grad() + + loss_mean, loss_dict = self.fwd_bwd_step(dataloader_iter, batch_idx, False) + + torch.distributed.broadcast(loss_mean, get_last_rank()) + + # when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced + if self.cfg.get('tensor_model_parallel_size', 1) > 1 and self.cfg.get('sequence_parallel', False): + self.allreduce_sequence_parallel_gradients() + + if self.with_distributed_adam: + # gradients are reduced internally in distributed optimizer + pass + elif self.megatron_amp_O2: + # # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously) + # if self.cfg.get('pipeline_model_parallel_size', 1) > 1 or self.cfg.get('sequence_parallel', False): + # # main grads are stored in the MainParamsOptimizer wrapper + # self._optimizer.allreduce_main_grads() + self._optimizer.allreduce_main_grads() + elif not self.cfg.get('ddp_overlap', True): + # async grad allreduce is not currently implemented for O1/autocasting mixed precision training + # so we all-reduce gradients after the pipeline + self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf) + + if self.cfg.precision in [16, '16', '16-mixed']: + loss_scale = self.trainer.precision_plugin.scaler._scale + if loss_scale is not None: + self.log('loss_scale', loss_scale, batch_size=1) + + self.log_dict(loss_dict, prog_bar=False, logger=True, on_step=True, rank_zero_only=True, batch_size=1) + self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1) + lr = self._optimizer.param_groups[0]['lr'] + self.log('lr', lr, prog_bar=True, rank_zero_only=True, batch_size=1) + self.log('global_step', self.trainer.global_step + 1, prog_bar=True, rank_zero_only=True, batch_size=1) + self.log( + 'consumed_samples', + self.compute_consumed_samples(self.trainer.global_step + 1 - self.init_global_step), + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + return loss_mean + + def backward(self, *args, **kwargs): + """ LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from apex. + No need to call it here. + """ + pass + + def optimizer_zero_grad(self, *args, **kwargs): + """ LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. + """ + pass + + def _append_sequence_parallel_module_grads(self, module, grads): + """ Helper method for allreduce_sequence_parallel_gradients""" + + for param in module.parameters(): + sequence_parallel_param = getattr(param, 'sequence_parallel', False) + if sequence_parallel_param and param.requires_grad: + if self.megatron_amp_O2: + grad = param.main_grad + else: + grad = param.grad + grads.append(grad.data) + + def get_forward_output_and_loss_func(self): + def process_batch(batch): + """ Prepares the global batch for apex fwd/bwd functions. + Global batch is a list of micro batches. + """ + # noise_map, condition + batch[self.cfg.first_stage_key] = batch[self.cfg.first_stage_key].cuda(non_blocking=True) + if isinstance(batch[self.cfg.cond_stage_key], torch.Tensor): + # in the case of precached text embeddings, cond_stage is also a tensor + batch[self.cfg.cond_stage_key] = batch[self.cfg.cond_stage_key].cuda(non_blocking=True) + + # SD has more dedicated structure for encoding, so we enable autocasting here as well + with torch.cuda.amp.autocast( + self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype, + ): + x, c = self.model.get_input(batch, self.cfg.first_stage_key) + + if not isinstance(c, dict): + return [x, c] + + if len(self.conditioning_keys) == 0: + self.conditioning_keys = list(c.keys()) + c_list = [c[key] for key in self.conditioning_keys] + return [x, *c_list] + + def fwd_output_and_loss_func(dataloader_iter, model): + batch = next(dataloader_iter) + batch = process_batch(batch) + batch = [x.cuda(non_blocking=True) for x in batch] + if len(self.conditioning_keys) == 0: + x, c = batch + else: + x = batch[0] + c = {} + for idx, key in enumerate(self.conditioning_keys): + c[key] = batch[1 + idx] + loss, loss_dict = model(x, c) + + def dummy(output_tensor): + return loss, loss_dict + + # output_tensor, and a function to convert output_tensor to loss + loss_dict + return loss, dummy + + return fwd_output_and_loss_func + + def get_forward_output_only_func(self): + def fwd_output_only_func(batch, model): + raise NotImplementedError + + return fwd_output_only_func + + def validation_step(self, dataloader_iter, batch_idx): + loss, val_loss_dict = self.fwd_bwd_step(dataloader_iter, batch_idx, True) + + self.log_dict(val_loss_dict, prog_bar=False, logger=True, on_step=False, on_epoch=True, batch_size=1) + + return loss + + def setup(self, stage=None): + """ PTL hook that is executed after DDP spawns. + We setup datasets here as megatron datasets require DDP to instantiate. + See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. + Args: + stage (str, optional): Can be 'fit', 'validate', 'test' or 'predict'. Defaults to None. + """ + self.model.rng.manual_seed(self.cfg.seed + 100 * parallel_state.get_data_parallel_rank()) + + # log number of parameters + if isinstance(self.model, list): + num_parameters_on_device = sum( + [sum([p.nelement() for p in model_module.parameters()]) for model_module in self.model] + ) + else: + num_parameters_on_device = sum([p.nelement() for p in self.model.parameters()]) + + # to be summed across data parallel group + total_num_parameters = torch.tensor(num_parameters_on_device).cuda(non_blocking=True) + + torch.distributed.all_reduce(total_num_parameters, group=parallel_state.get_model_parallel_group()) + + logging.info( + f'Pipeline model parallel rank: {parallel_state.get_pipeline_model_parallel_rank()}, ' + f'Tensor model parallel rank: {parallel_state.get_tensor_model_parallel_rank()}, ' + f'Number of model parameters on device: {num_parameters_on_device:.2e}. ' + f'Total number of model parameters: {total_num_parameters:.2e}.' + ) + + resume_checkpoint_path = self.trainer.ckpt_path + if resume_checkpoint_path: + init_consumed_samples = self._extract_consumed_samples_from_ckpt(resume_checkpoint_path) + else: + init_consumed_samples = 0 + self.init_consumed_samples = init_consumed_samples + self.init_global_step = self.trainer.global_step + + # allowing restored models to optionally setup datasets + self.build_train_valid_test_datasets() + + # Batch size need to be provided for webdatset + self._num_micro_batches = get_num_microbatches() + self._micro_batch_size = self.cfg.micro_batch_size + + self.setup_training_data(self.cfg.data) + self.setup_validation_data(self.cfg.data) + self.setup_test_data(self.cfg.data) + + def build_train_valid_test_datasets(self): + logging.info('Building datasets for Stable Diffusion...') + if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float): + raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.") + + if self.cfg.first_stage_key.endswith("encoded"): + self._train_ds, self._validation_ds = build_train_valid_precached_datasets( + model_cfg=self.cfg, consumed_samples=self.compute_consumed_samples(0), + ) + else: + self._train_ds, self._validation_ds = build_train_valid_datasets( + model_cfg=self.cfg, consumed_samples=self.compute_consumed_samples(0) + ) + self._test_ds = None + + if self._train_ds is not None: + logging.info(f'Length of train dataset: {len(self._train_ds)}') + if self._validation_ds is not None: + logging.info(f'Length of val dataset: {len(self._validation_ds)}') + if self._test_ds is not None: + logging.info(f'Length of test dataset: {len(self._test_ds)}') + logging.info(f'Finished building datasets for LatentDiffusion.') + return self._train_ds, self._validation_ds, self._test_ds + + def setup_training_data(self, cfg): + if hasattr(self, '_train_ds') and self._train_ds is not None: + consumed_samples = self.compute_consumed_samples(0) + logging.info( + f'Setting up train dataloader with len(len(self._train_ds)): {len(self._train_ds)} and consumed samples: {consumed_samples}' + ) + self._train_dl = torch.utils.data.DataLoader( + self._train_ds, + batch_size=self._micro_batch_size, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=True, + persistent_workers=True, + ) + + def setup_validation_data(self, cfg): + if hasattr(self, '_validation_ds') and self._validation_ds is not None: + consumed_samples = 0 + logging.info( + f'Setting up validation dataloader with len(len(self._validation_ds)): {len(self._validation_ds)} and consumed samples: {consumed_samples}' + ) + self._validation_dl = torch.utils.data.DataLoader( + self._validation_ds, + batch_size=self._micro_batch_size, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=True, + ) + + def setup_test_data(self, cfg): + if hasattr(self, '_test_ds') and self._test_ds is not None: + consumed_samples = 0 + logging.info( + f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}' + ) + self._test_dl = torch.utils.data.DataLoader( + self._test_ds, batch_size=self._micro_batch_size, num_workers=cfg.num_workers, pin_memory=True, + ) + + def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: + """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + When using pipeline parallelism, we need the global batch to remain on the CPU, + since the memory overhead will be too high when using a large number of microbatches. + Microbatches are transferred from CPU to GPU inside the pipeline. + """ + return batch + + def _validate_trainer(self): + """ Certain trainer configurations can break training. + Here we try to catch them and raise an error. + """ + if self.trainer.accumulate_grad_batches > 1: + raise ValueError( + f'Gradient accumulation is done within training_step. trainer.accumulate_grad_batches must equal 1' + ) + + @classmethod + def list_available_models(cls): + return None + + def parameters(self): + if isinstance(self.model, list): + return itertools.chain.from_iterable(module.parameters() for module in self.model) + else: + return self.model.parameters() + + def save_to(self, save_path: str): + # Replace .nemo path in config for NeMo CLIP + cfg = self._cfg + if cfg.get('cond_stage_config').get('restore_from_path'): + with open_dict(cfg): + cfg.cond_stage_config.restore_from_path = None + cfg.cond_stage_config.cfg = self.model.cond_stage_model.cfg + self._cfg = cfg + super().save_to(save_path) + + @classmethod + def load_from_checkpoint( + cls, + checkpoint_path: str, + map_location: Any = None, + hparams_file: Optional[str] = None, + strict: bool = True, + **kwargs, + ): + """ + Loads ModelPT from checkpoint, with some maintenance of restoration. + For documentation, please refer to LightningModule.load_from_checkpoin() documentation. + """ + checkpoint = None + try: + cls._set_model_restore_state(is_being_restored=True) + # TODO: replace with proper PTL API + with pl_legacy_patch(): + if map_location is not None: + checkpoint = pl_load(checkpoint_path, map_location=map_location) + else: + checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + + if hparams_file is not None: + extension = hparams_file.split(".")[-1] + if extension.lower() == "csv": + hparams = load_hparams_from_tags_csv(hparams_file) + elif extension.lower() in ("yml", "yaml"): + hparams = load_hparams_from_yaml(hparams_file) + else: + raise ValueError(".csv, .yml or .yaml is required for `hparams_file`") + + hparams["on_gpu"] = False + + # overwrite hparams by the given file + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams + + # for past checkpoint need to add the new key + if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint: + checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {} + # override the hparams with values that were passed in + cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].get('cfg', checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]) + # TODO: can we do this without overriding? + config_kwargs = kwargs.copy() + if 'trainer' in config_kwargs: + config_kwargs.pop('trainer') + cfg.update(config_kwargs) + + # Disable individual unet/vae weights loading otherwise the model will look for these partial ckpts and raise error + if cfg: + if cfg.get('unet_config') and cfg.get('unet_config').get('from_pretrained'): + cfg.unet_config.from_pretrained = None + if cfg.get('first_stage_config') and cfg.get('first_stage_config').get('from_pretrained'): + cfg.first_stage_config.from_pretrained = None + ## Now when we covert ckpt to nemo, let's always get rid of those _orig_mod + if cfg.get('inductor'): + cfg.inductor = False + ## Append some dummy configs that DB didn't support + if not cfg.get('channels_last'): + cfg.channels_last = True + if not cfg.get('capture_cudagraph_iters'): + cfg.capture_cudagraph_iters = -1 + + # compatibility for stable diffusion old checkpoint tweaks + first_key = list(checkpoint['state_dict'].keys())[0] + if first_key == "betas": + # insert "model." into for megatron wrapper + new_state_dict = {} + for key in checkpoint['state_dict'].keys(): + new_key = "model." + key + new_state_dict[new_key] = checkpoint['state_dict'][key] + checkpoint['state_dict'] = new_state_dict + elif ( + first_key == 'model.text_encoder.transformer.text_model.embeddings.position_ids' + or first_key == 'model.text_encoder.model.language_model.embedding.position_embeddings' + ): + # remap state keys from dreambooth when using HF clip + new_state_dict = {} + for key in checkpoint['state_dict'].keys(): + new_key = key.replace('._orig_mod', "") + new_key = new_key.replace('unet', 'model.diffusion_model') + new_key = new_key.replace('vae', 'first_stage_model') + new_key = new_key.replace('text_encoder', 'cond_stage_model') + new_key = new_key.replace('.noise_scheduler', '') + new_state_dict[new_key] = checkpoint['state_dict'][key] + checkpoint['state_dict'] = new_state_dict + + # compatibility for inductor in inference + if not cfg.get('inductor', False): + new_state_dict = {} + for key in checkpoint['state_dict'].keys(): + new_key = key.replace('._orig_mod', '', 1) + new_state_dict[new_key] = checkpoint['state_dict'][key] + checkpoint['state_dict'] = new_state_dict + + if cfg.get('megatron_amp_O2', False): + new_state_dict = {} + for key in checkpoint['state_dict'].keys(): + new_key = key.replace('model.', 'model.module.', 1) + new_state_dict[new_key] = checkpoint['state_dict'][key] + checkpoint['state_dict'] = new_state_dict + + if 'cfg' in kwargs: + model = ptl_load_state(cls, checkpoint, strict=strict, **kwargs) + else: + model = ptl_load_state(cls, checkpoint, strict=strict, cfg=cfg, **kwargs) + # cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg + + checkpoint = model + + finally: + cls._set_model_restore_state(is_being_restored=False) + return checkpoint + + +class DiffusionWrapper(pl.LightningModule, Serialization): + def __init__( + self, + diff_model_config, + conditioning_key, + inductor: bool = False, + inductor_cudagraphs: bool = False, + capture_cudagraph_iters: int = -1, + ): + super().__init__() + self.diffusion_model = DiffusionWrapper.from_config_dict(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + + # Fusing VAE and CLIP doesn't give benefit + if inductor: + # TorchInductor with CUDA graph can lead to OOM + torch._dynamo.config.dynamic_shapes = False + torch._dynamo.config.automatic_dynamic_shapes = False + inductor_config.triton.cudagraphs = inductor_cudagraphs + self.diffusion_model = torch.compile(self.diffusion_model) + # CUDA graph + self.capture_cudagraph_iters = capture_cudagraph_iters + self.iterations = 0 + self.graphed_diffusion_model = None + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + if self.iterations == self.capture_cudagraph_iters: + logging.info("Capturing CUDA graph for module: %s", self.diffusion_model.__class__.__name__) + self.graphed_diffusion_model = torch.cuda.make_graphed_callables(self.diffusion_model, (x, t, cc)) + + if 0 <= self.capture_cudagraph_iters <= self.iterations: + out = self.graphed_diffusion_model(x, t, cc) + else: + out = self.diffusion_model(x, t, context=cc) + self.iterations += 1 + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm_config.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm_config.py new file mode 100644 index 000000000000..2f2acb40ed43 --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm_config.py @@ -0,0 +1,144 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from typing import Any, List, Optional + +from nemo.core.config import modelPT as model_cfg + + +@dataclass +class LDMUnetConfig: + cls: Optional[str] = 'nemo.collections.multimodal.modules.diffusionmodules.openaimodel.UNetModel' + image_size: Optional[int] = 32 # unused + in_channels: Optional[int] = 4 + out_channels: Optional[int] = 4 + model_channels: Optional[int] = 320 + attention_resolutions: Optional[List[int]] = field(default_factory=lambda: [4, 2, 1]) + num_res_blocks: Optional[int] = 2 + channel_mult: Optional[List[int]] = field(default_factory=lambda: [1, 2, 4, 4]) + num_heads: Optional[int] = 8 + use_spatial_transformer: Optional[bool] = True + transformer_depth: Optional[int] = 1 + context_dim: Optional[int] = 768 + use_checkpoint: Optional[bool] = True + legacy: Optional[bool] = False + use_flash_attention: Optional[bool] = False + + +@dataclass +class SchedulerConfig: + cls: Optional[str] = 'nemo.collections.multimodal.parts.lr_scheduler.LambdaLinearScheduler' + warm_up_steps: Optional[List[int]] = field(default_factory=lambda: [10000]) + cycle_lengths: Optional[List[int]] = field( + default_factory=lambda: [10000000000000] + ) # incredibly large number to prevent corner cases + f_start: Optional[List[float]] = field(default_factory=lambda: [1.0e-6]) + f_max: Optional[List[float]] = field(default_factory=lambda: [1.0]) + f_min: Optional[List[float]] = field(default_factory=lambda: [1.0]) + + +@dataclass +class CLIPEmbedderConfig: + cls: Optional[str] = 'nemo.collections.multimodal.modules.encoders.modules.FrozenCLIPEmbedder' + version: Optional[str] = 'openai/clip-vit-large-patch14' + device: Optional[str] = 'cuda' + max_length: Optional[int] = 77 + + +@dataclass +class LDMEncoderConfig: + double_z: Optional[bool] = True + z_channels: Optional[int] = 4 + resolution: Optional[int] = 256 + in_channels: Optional[int] = 3 + out_ch: Optional[int] = 3 + ch: Optional[int] = 128 + ch_mult: Optional[List[int]] = field(default_factory=lambda: [1, 2, 4, 4]) + num_res_blocks: Optional[int] = 2 + attn_resolutions: Optional[List[int]] = field(default_factory=lambda: []) + dropout: Optional[float] = 0.0 + + +@dataclass +class LDMFirstStageConfig: # Autoencoder + cls: Optional[str] = 'nemo.collections.multimodal.models.ldm.autoencoder.AutoencoderKL' + embed_dim: Optional[int] = 4 + monitor: Optional[str] = 'val/rec_loss' + ddconfig: Optional[LDMEncoderConfig] = LDMEncoderConfig() + + +@dataclass +class DDPMDiffusionModelConfig(model_cfg.ModelConfig): + unet_config: Optional[LDMUnetConfig] = LDMUnetConfig() + timesteps: Optional[int] = 1000 + beta_schedule: Optional[str] = 'linear' + loss_type: Optional[str] = 'l2' + ckpt_path: Optional[str] = None + ignore_keys: Optional[List[str]] = field(default_factory=list) + load_only_unet: Optional[bool] = False + monitor: Optional[str] = 'val/loss' + use_ema: Optional[bool] = True + first_stage_key: Optional[str] = 'image' + image_size: Optional[int] = 256 + channels: Optional[int] = 3 + log_every_t: Optional[int] = 100 + clip_denoised: Optional[bool] = True + linear_start: Optional[float] = 1e-4 + linear_end: Optional[float] = 2e-2 + cosine_s: Optional[float] = 8e-3 + given_betas: Optional[float] = None + original_elbo_weight: Optional[float] = 0.0 + v_posterior: Optional[ + float + ] = 0.0 # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight: Optional[float] = 1.0 + conditioning_key: Optional[str] = None + parameterization: Optional[str] = 'eps' # all assuming fixed variance schedules + scheduler_config: Optional[Any] = None + use_positional_encodings: Optional[bool] = False + learn_logvar: Optional[bool] = False + logvar_init: Optional[float] = 0.0 + learning_rate: Optional[float] = 1.0e-04 + + +@dataclass +class LatentDiffusionModelConfig(DDPMDiffusionModelConfig): + # Overrite Default values + linear_start: Optional[float] = 0.00085 + linear_end: Optional[float] = 0.0120 + num_timesteps_cond: Optional[int] = 1 + log_every_t: Optional[int] = 200 + timesteps: Optional[int] = 1000 + first_stage_key: Optional[str] = 'jpg' + cond_stage_key: Optional[str] = 'txt' + image_size: Optional[int] = 64 + channels: Optional[int] = 4 + cond_stage_trainable: Optional[bool] = False + conditioning_key: Optional[str] = 'crossattn' + monitor: Optional[str] = 'val/loss_simple_ema' + scale_factor: Optional[float] = 0.18215 + use_ema: Optional[bool] = False # TODO + unet_config: Optional[LDMUnetConfig] = LDMUnetConfig() + first_stage_config: Optional[LDMFirstStageConfig] = LDMFirstStageConfig() + scheduler_config: Optional[SchedulerConfig] = SchedulerConfig() + # New attributes in additon to DDPMDiffusionModel + concat_mode: Optional[bool] = True + trainable: Optional[bool] = False + cond_stage_config: Optional[CLIPEmbedderConfig] = CLIPEmbedderConfig() + cond_stage_forward: Optional[Any] = None + scale_by_std: Optional[bool] = False + text_embedding_dropout_rate: Optional[float] = 0 + fused_opt: Optional[bool] = False + inductor: Optional[bool] = False + inductor_cudagraphs: Optional[bool] = False diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/__init__.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/__init__.py new file mode 100644 index 000000000000..70256058631d --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from enum import Enum + +Sampler = Enum('Sampler', ['PLMS', 'DDIM', 'DPM', 'PARA_DDIM']) diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/base_sampler.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/base_sampler.py new file mode 100644 index 000000000000..b890d863428b --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/base_sampler.py @@ -0,0 +1,339 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod + +import numpy as np +import torch +from tqdm import tqdm + +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.samplers import Sampler +from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import ( + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, +) + + +class AbstractBaseSampler(ABC): + def __init__(self, model, sampler, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + assert isinstance(sampler, Sampler), "Sampler should be of ENUM type Sampler" + self.sampler = sampler + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(torch.cuda.current_device()) + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)) + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer("log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))) + self.register_buffer("sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev, ddim_variance = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_variance", ddim_variance) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer("ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps) + + @abstractmethod + def p_sampling_fn(self): + pass + + def dpm_sampling_fn(self): + pass + + def para_ddim_sampling_fn(self): + pass + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + parallelism=8, + tolerance=0.1, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs, + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): + ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f"Data shape for sampling is {size}, eta {eta}") + + if self.sampler is Sampler.DPM: + return self.dpm_sampling_fn( + shape=shape, + steps=S, + conditioning=conditioning, + unconditional_conditioning=unconditional_conditioning, + unconditional_guidance_scale=unconditional_guidance_scale, + x_T=x_T, + ) + + if self.sampler is Sampler.PARA_DDIM: + return self.para_ddim_sampling_fn( + cond=conditioning, + batch_size=batch_size, + per_latent_shape=shape, + x_T=x_T, + steps=S, + parallelism=parallelism, + tolerance=tolerance, + temperature=temperature, + noise_dropout=noise_dropout, + quantize_denoised=quantize_x0, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) + + samples, intermediates = self.sampling_fn( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def sampling_fn( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + ): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, generator=self.model.rng, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + intermediates = {"x_inter": [img], "pred_x0": [img]} + + # TODO: Is this needed + if self.sampler is Sampler.PLMS: + time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps) + else: + time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running {self.sampler.name} Sampling with {total_steps} timesteps") + iterator = tqdm(time_range, desc=f"{self.sampler.name} Sampler", total=total_steps) + old_eps = [] + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + if self.sampler is Sampler.PLMS: + ts_next = torch.full( + (b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long, + ) + else: + old_eps = None + ts_next = None + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1.0 - mask) * img + outs = self.p_sampling_fn( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, + t_next=ts_next, + ) + img, pred_x0 = outs[0], outs[1] + if self.sampler is Sampler.PLMS: + e_t = outs[2] + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + return img, intermediates + + def _get_model_output( + self, x, t, unconditional_conditioning, unconditional_guidance_scale, score_corrector, c, corrector_kwargs, + ): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: + model_output = self.model.apply_model(x, t, c) + elif isinstance(c, dict): + ### Contolnet conditioning is dict format + model_t = self.model.apply_model(x, t, c) + model_uncond = self.model.apply_model(x, t, unconditional_conditioning) + model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + model_output = e_t_uncond + unconditional_guidance_scale * (model_t - e_t_uncond) + if self.model.parameterization == "v": + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + return e_t, model_output + + def _get_x_prev_and_pred_x0( + self, + use_original_steps, + b, + index, + device, + x, + t, + model_output, + e_t, + quantize_denoised, + repeat_noise, + temperature, + noise_dropout, + ): + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + ) + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device) + # current prediction for x_0 + if self.model.parameterization != "v": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/ddim.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/ddim.py new file mode 100644 index 000000000000..13f692d27821 --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/ddim.py @@ -0,0 +1,119 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SAMPLING ONLY.""" + +import numpy as np +import torch +from tqdm import tqdm + +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.samplers import Sampler +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.samplers.base_sampler import AbstractBaseSampler +from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import extract_into_tensor +from nemo.collections.multimodal.parts.utils import randn_like + + +class DDIMSampler(AbstractBaseSampler): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__(model, sampler=Sampler.DDIM, schedule="linear", **kwargs) + + @torch.no_grad() + def p_sampling_fn( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + old_eps=None, + t_next=None, + ): + b, *_, device = *x.shape, x.device + e_t, model_output = self._get_model_output( + x, t, unconditional_conditioning, unconditional_guidance_scale, score_corrector, c, corrector_kwargs + ) + x_prev, pred_x0 = self._get_x_prev_and_pred_x0( + use_original_steps, + b, + index, + device, + x, + t, + model_output, + e_t, + quantize_denoised, + repeat_noise, + temperature, + noise_dropout, + ) + return x_prev, pred_x0 + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = randn_like(x0, generator=self.model.rng) + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + ) + + @torch.no_grad() + def decode( + self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + ): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return x_dec diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/dpmsolver.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/dpmsolver.py new file mode 100644 index 000000000000..b1b046a2c5db --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/dpmsolver.py @@ -0,0 +1,493 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import torch + +from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import expand_dims, interpolate_fn + + +class NoiseScheduleVP: + def __init__( + self, schedule="discrete", betas=None, alphas_cumprod=None, continuous_beta_0=0.1, continuous_beta_1=20.0, + ): + """Create a wrapper class for the forward SDE.""" + + if schedule not in ["discrete", "linear", "cosine"]: + raise ValueError( + "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( + schedule + ) + ) + + self.schedule = schedule + if schedule == "discrete": + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1.0 + self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape((1, -1,)) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999.0 + self.cosine_t_max = ( + math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi) + * 2.0 + * (1.0 + self.cosine_s) + / math.pi + - self.cosine_s + ) + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)) + self.schedule = schedule + if schedule == "cosine": + self.T = 0.9946 + else: + self.T = 1.0 + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == "discrete": + return interpolate_fn( + t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device), + ).reshape((-1)) + elif self.schedule == "linear": + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == "cosine": + + def log_alpha_fn(s): + return torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)) + + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == "linear": + tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0 ** 2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == "discrete": + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb) + t = interpolate_fn( + log_alpha.reshape((-1, 1)), + torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1]), + ) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + + def t_fn(log_alpha_t): + return ( + torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) + * 2.0 + * (1.0 + self.cosine_s) + / math.pi + - self.cosine_s + ) + + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1.0, + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model.""" + + def get_model_input_time(t_continuous): + if noise_schedule.schedule == "discrete": + return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0 + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = ( + noise_schedule.marginal_alpha(t_continuous), + noise_schedule.marginal_std(t_continuous), + ) + dims = x.dim() + return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) + elif model_type == "v": + alpha_t, sigma_t = ( + noise_schedule.marginal_alpha(t_continuous), + noise_schedule.marginal_std(t_continuous), + ) + dims = x.dim() + return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1.0 or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPMSolver: + def __init__( + self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.0, + ): + """Construct a DPM-Solver.""" + self.model = model_fn + self.noise_schedule = noise_schedule + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.max_val = max_val + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with thresholding). + """ + noise = self.noise_prediction_fn(x, t) + dims = x.dim() + alpha_t, sigma_t = ( + self.noise_schedule.marginal_alpha(t), + self.noise_schedule.marginal_std(t), + ) + x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) + if self.thresholding: + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling.""" + if skip_type == "logSNR": + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == "time_uniform": + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == "time_quadratic": + t_order = 2 + t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError( + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type) + ) + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.predict_x0: + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = expand_dims(sigma_t / sigma_s, dims) * x - expand_dims(alpha_t * phi_1, dims) * model_s + if return_intermediate: + return x_t, {"model_s": model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {"model_s": model_s} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + """ + if solver_type not in ["dpm_solver", "taylor"]: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + dims = x.dim() + model_prev_1, model_prev_0 = model_prev_list + t_prev_1, t_prev_0 = t_prev_list + lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) + log_alpha_prev_0, log_alpha_t = ( + ns.marginal_log_mean_coeff(t_prev_0), + ns.marginal_log_mean_coeff(t), + ) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1) + if self.predict_x0: + if solver_type == "dpm_solver": + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0 + - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * D1_0 + ) + elif solver_type == "taylor": + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1_0 + ) + else: + if solver_type == "dpm_solver": + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0 + - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * D1_0 + ) + elif solver_type == "taylor": + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_2), + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) + log_alpha_prev_0, log_alpha_t = ( + ns.marginal_log_mean_coeff(t_prev_0), + ns.marginal_log_mean_coeff(t), + ) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1.0 / r1, dims) * (model_prev_1 - model_prev_2) + D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1.0 / (r0 + r1), dims) * (D1_0 - D1_1) + if self.predict_x0: + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1 + - expand_dims(alpha_t * ((torch.exp(-h) - 1.0 + h) / h ** 2 - 0.5), dims) * D2 + ) + else: + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1 + - expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h ** 2 - 0.5), dims) * D2 + ) + return x_t + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpm_solver"): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def sample( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=3, + skip_type="time_uniform", + method="singlestep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpm_solver", + atol=0.0078, + rtol=0.05, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + """ + t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + device = x.device + + if method == "multistep": + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + with torch.no_grad(): + vec_t = timesteps[0].expand((x.shape[0])) + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in range(1, order): + vec_t = timesteps[init_order].expand(x.shape[0]) + x = self.multistep_dpm_solver_update( + x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type, + ) + model_prev_list.append(self.model_fn(x, vec_t)) + t_prev_list.append(vec_t) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in range(order, steps + 1): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final and steps < 15: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update( + x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type, + ) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, vec_t) + if denoise_to_zero: + x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) + return x diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/k_diffusion.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/k_diffusion.py new file mode 100644 index 000000000000..ac4f8f7ad73d --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/k_diffusion.py @@ -0,0 +1,838 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import torch +import torchsde +from scipy import integrate +from torch import nn +from torchdiffeq import odeint +from tqdm.auto import tqdm, trange + + +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + + +def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device='cpu'): + """Constructs the noise schedule of Karras et al. (2022).""" + ramp = torch.linspace(0, 1, n) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return append_zero(sigmas).to(device) + + +def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'): + """Constructs an exponential noise schedule.""" + sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp() + return append_zero(sigmas) + + +def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1.0, device='cpu'): + """Constructs an polynomial in log sigma noise schedule.""" + ramp = torch.linspace(1, 0, n, device=device) ** rho + sigmas = torch.exp(ramp * (math.log(sigma_max) - math.log(sigma_min)) + math.log(sigma_min)) + return append_zero(sigmas) + + +def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'): + """Constructs a continuous VP noise schedule.""" + t = torch.linspace(1, eps_s, n, device=device) + sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1) + return append_zero(sigmas) + + +def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / append_dims(sigma, x.ndim) + + +def get_ancestral_step(sigma_from, sigma_to, eta=1.0): + """Calculates the noise level (sigma_down) to step down to and the amount + of noise to add (sigma_up) when doing an ancestral sampling step.""" + if not eta: + return sigma_to, 0.0 + sigma_up = min(sigma_to, eta * (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5) + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + return sigma_down, sigma_up + + +def default_noise_sampler(x): + return lambda sigma, sigma_next: torch.randn_like(x) + + +class BatchedBrownianTree: + """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" + + def __init__(self, x, t0, t1, seed=None, **kwargs): + t0, t1, self.sign = self.sort(t0, t1) + w0 = kwargs.get('w0', torch.zeros_like(x)) + if seed is None: + seed = torch.randint(0, 2 ** 63 - 1, []).item() + self.batched = True + try: + assert len(seed) == x.shape[0] + w0 = w0[0] + except TypeError: + seed = [seed] + self.batched = False + self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + + @staticmethod + def sort(a, b): + return (a, b, 1) if a < b else (b, a, -1) + + def __call__(self, t0, t1): + t0, t1, sign = self.sort(t0, t1) + w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) + return w if self.batched else w[0] + + +class BrownianTreeNoiseSampler: + """A noise sampler backed by a torchsde.BrownianTree. + + Args: + x (Tensor): The tensor whose shape, device and dtype to use to generate + random samples. + sigma_min (float): The low end of the valid interval. + sigma_max (float): The high end of the valid interval. + seed (int or List[int]): The random seed. If a list of seeds is + supplied instead of a single integer, then the noise sampler will + use one BrownianTree per batch item, each with its own seed. + transform (callable): A function that maps sigma to the sampler's + internal timestep. + """ + + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + self.transform = transform + t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) + self.tree = BatchedBrownianTree(x, t0, t1, seed) + + def __call__(self, sigma, sigma_next): + t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) + return self.tree(t0, t1) / (t1 - t0).abs().sqrt() + + +@torch.no_grad() +def sample_euler( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_churn=0.0, + s_tmin=0.0, + s_tmax=float('inf'), + s_noise=1.0, +): + """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0 + eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + # Euler method + x = x + d * dt + return x + + +@torch.no_grad() +def sample_euler_ancestral( + model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None +): + """Ancestral sampling with Euler method steps.""" + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + d = to_d(x, sigmas[i], denoised) + # Euler method + dt = sigma_down - sigmas[i] + x = x + d * dt + if sigmas[i + 1] > 0: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + return x + + +@torch.no_grad() +def sample_heun( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_churn=0.0, + s_tmin=0.0, + s_tmax=float('inf'), + s_noise=1.0, +): + """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0 + eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + if sigmas[i + 1] == 0: + # Euler method + x = x + d * dt + else: + # Heun's method + x_2 = x + d * dt + denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) + d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + d_prime = (d + d_2) / 2 + x = x + d_prime * dt + return x + + +@torch.no_grad() +def sample_dpm_2( + model, + x, + sigmas, + extra_args=None, + callback=None, + disable=None, + s_churn=0.0, + s_tmin=0.0, + s_tmax=float('inf'), + s_noise=1.0, +): + """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.0 + eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Euler method + dt = sigmas[i + 1] - sigma_hat + x = x + d * dt + else: + # DPM-Solver-2 + sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp() + dt_1 = sigma_mid - sigma_hat + dt_2 = sigmas[i + 1] - sigma_hat + x_2 = x + d * dt_1 + denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + return x + + +@torch.no_grad() +def sample_dpm_2_ancestral( + model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None +): + """Ancestral sampling with DPM-Solver second-order steps.""" + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + d = to_d(x, sigmas[i], denoised) + if sigma_down == 0: + # Euler method + dt = sigma_down - sigmas[i] + x = x + d * dt + else: + # DPM-Solver-2 + sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp() + dt_1 = sigma_mid - sigmas[i] + dt_2 = sigma_down - sigmas[i] + x_2 = x + d * dt_1 + denoised_2 = model(x_2, sigma_mid * s_in, **extra_args) + d_2 = to_d(x_2, sigma_mid, denoised_2) + x = x + d_2 * dt_2 + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + return x + + +def linear_multistep_coeff(order, t, i, j): + if order - 1 > i: + raise ValueError(f'Order {order} too high for step {i}') + + def fn(tau): + prod = 1.0 + for k in range(order): + if j == k: + continue + prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) + return prod + + return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0] + + +@torch.no_grad() +def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + sigmas_cpu = sigmas.detach().cpu().numpy() + ds = [] + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + d = to_d(x, sigmas[i], denoised) + ds.append(d) + if len(ds) > order: + ds.pop(0) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + cur_order = min(i + 1, order) + coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)] + x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) + return x + + +@torch.no_grad() +def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4): + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + v = torch.randint_like(x, 2) * 2 - 1 + fevals = 0 + + def ode_fn(sigma, x): + nonlocal fevals + with torch.enable_grad(): + x = x[0].detach().requires_grad_() + denoised = model(x, sigma * s_in, **extra_args) + d = to_d(x, sigma, denoised) + fevals += 1 + grad = torch.autograd.grad((d * v).sum(), x)[0] + d_ll = (v * grad).flatten(1).sum(1) + return d.detach(), d_ll + + x_min = x, x.new_zeros([x.shape[0]]) + t = x.new_tensor([sigma_min, sigma_max]) + sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5') + latent, delta_ll = sol[0][-1], sol[1][-1] + ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1) + return ll_prior + delta_ll, {'fevals': fevals} + + +class PIDStepSizeController: + """A PID controller for ODE adaptive step size control.""" + + def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8): + self.h = h + self.b1 = (pcoeff + icoeff + dcoeff) / order + self.b2 = -(pcoeff + 2 * dcoeff) / order + self.b3 = dcoeff / order + self.accept_safety = accept_safety + self.eps = eps + self.errs = [] + + def limiter(self, x): + return 1 + math.atan(x - 1) + + def propose_step(self, error): + inv_error = 1 / (float(error) + self.eps) + if not self.errs: + self.errs = [inv_error, inv_error, inv_error] + self.errs[0] = inv_error + factor = self.errs[0] ** self.b1 * self.errs[1] ** self.b2 * self.errs[2] ** self.b3 + factor = self.limiter(factor) + accept = factor >= self.accept_safety + if accept: + self.errs[2] = self.errs[1] + self.errs[1] = self.errs[0] + self.h *= factor + return accept + + +class DPMSolver(nn.Module): + """DPM-Solver. See https://arxiv.org/abs/2206.00927.""" + + def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None): + super().__init__() + self.model = model + self.extra_args = {} if extra_args is None else extra_args + self.eps_callback = eps_callback + self.info_callback = info_callback + + def t(self, sigma): + return -sigma.log() + + def sigma(self, t): + return t.neg().exp() + + def eps(self, eps_cache, key, x, t, *args, **kwargs): + if key in eps_cache: + return eps_cache[key], eps_cache + sigma = self.sigma(t) * x.new_ones([x.shape[0]]) + eps = (x - self.model(x, sigma, *args, **self.extra_args, **kwargs)) / self.sigma(t) + if self.eps_callback is not None: + self.eps_callback() + return eps, {key: eps, **eps_cache} + + def dpm_solver_1_step(self, x, t, t_next, eps_cache=None): + eps_cache = {} if eps_cache is None else eps_cache + h = t_next - t + eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + x_1 = x - self.sigma(t_next) * h.expm1() * eps + return x_1, eps_cache + + def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None): + eps_cache = {} if eps_cache is None else eps_cache + h = t_next - t + eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + s1 = t + r1 * h + u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps + eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) + x_2 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / (2 * r1) * h.expm1() * (eps_r1 - eps) + return x_2, eps_cache + + def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None): + eps_cache = {} if eps_cache is None else eps_cache + h = t_next - t + eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + s1 = t + r1 * h + s2 = t + r2 * h + u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps + eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1) + u2 = ( + x + - self.sigma(s2) * (r2 * h).expm1() * eps + - self.sigma(s2) * (r2 / r1) * ((r2 * h).expm1() / (r2 * h) - 1) * (eps_r1 - eps) + ) + eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2) + x_3 = x - self.sigma(t_next) * h.expm1() * eps - self.sigma(t_next) / r2 * (h.expm1() / h - 1) * (eps_r2 - eps) + return x_3, eps_cache + + def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0.0, s_noise=1.0, noise_sampler=None): + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + if not t_end > t_start and eta: + raise ValueError('eta must be 0 for reverse sampling') + + m = math.floor(nfe / 3) + 1 + ts = torch.linspace(t_start, t_end, m + 1, device=x.device) + + if nfe % 3 == 0: + orders = [3] * (m - 2) + [2, 1] + else: + orders = [3] * (m - 1) + [nfe % 3] + + for i in range(len(orders)): + eps_cache = {} + t, t_next = ts[i], ts[i + 1] + if eta: + sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta) + t_next_ = torch.minimum(t_end, self.t(sd)) + su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5 + else: + t_next_, su = t_next, 0.0 + + eps, eps_cache = self.eps(eps_cache, 'eps', x, t) + denoised = x - self.sigma(t) * eps + if self.info_callback is not None: + self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised}) + + if orders[i] == 1: + x, eps_cache = self.dpm_solver_1_step(x, t, t_next_, eps_cache=eps_cache) + elif orders[i] == 2: + x, eps_cache = self.dpm_solver_2_step(x, t, t_next_, eps_cache=eps_cache) + else: + x, eps_cache = self.dpm_solver_3_step(x, t, t_next_, eps_cache=eps_cache) + + x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next)) + + return x + + def dpm_solver_adaptive( + self, + x, + t_start, + t_end, + order=3, + rtol=0.05, + atol=0.0078, + h_init=0.05, + pcoeff=0.0, + icoeff=1.0, + dcoeff=0.0, + accept_safety=0.81, + eta=0.0, + s_noise=1.0, + noise_sampler=None, + ): + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + if order not in {2, 3}: + raise ValueError('order should be 2 or 3') + forward = t_end > t_start + if not forward and eta: + raise ValueError('eta must be 0 for reverse sampling') + h_init = abs(h_init) * (1 if forward else -1) + atol = torch.tensor(atol) + rtol = torch.tensor(rtol) + s = t_start + x_prev = x + accept = True + pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety) + info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0} + + while s < t_end - 1e-5 if forward else s > t_end + 1e-5: + eps_cache = {} + t = torch.minimum(t_end, s + pid.h) if forward else torch.maximum(t_end, s + pid.h) + if eta: + sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta) + t_ = torch.minimum(t_end, self.t(sd)) + su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5 + else: + t_, su = t, 0.0 + + eps, eps_cache = self.eps(eps_cache, 'eps', x, s) + denoised = x - self.sigma(s) * eps + + if order == 2: + x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache) + x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache) + else: + x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache) + x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache) + delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs())) + error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5 + accept = pid.propose_step(error) + if accept: + x_prev = x_low + x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t)) + s = t + info['n_accept'] += 1 + else: + info['n_reject'] += 1 + info['nfe'] += order + info['steps'] += 1 + + if self.info_callback is not None: + self.info_callback( + { + 'x': x, + 'i': info['steps'] - 1, + 't': s, + 't_up': s, + 'denoised': denoised, + 'error': error, + 'h': pid.h, + **info, + } + ) + + return x, info + + +@torch.no_grad() +def sample_dpm_fast( + model, + x, + sigma_min, + sigma_max, + n, + extra_args=None, + callback=None, + disable=None, + eta=0.0, + s_noise=1.0, + noise_sampler=None, +): + """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.""" + if sigma_min <= 0 or sigma_max <= 0: + raise ValueError('sigma_min and sigma_max must not be 0') + with tqdm(total=n, disable=disable) as pbar: + dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) + if callback is not None: + dpm_solver.info_callback = lambda info: callback( + {'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info} + ) + return dpm_solver.dpm_solver_fast( + x, + dpm_solver.t(torch.tensor(sigma_max)), + dpm_solver.t(torch.tensor(sigma_min)), + n, + eta, + s_noise, + noise_sampler, + ) + + +@torch.no_grad() +def sample_dpm_adaptive( + model, + x, + sigma_min, + sigma_max, + extra_args=None, + callback=None, + disable=None, + order=3, + rtol=0.05, + atol=0.0078, + h_init=0.05, + pcoeff=0.0, + icoeff=1.0, + dcoeff=0.0, + accept_safety=0.81, + eta=0.0, + s_noise=1.0, + noise_sampler=None, + return_info=False, +): + """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.""" + if sigma_min <= 0 or sigma_max <= 0: + raise ValueError('sigma_min and sigma_max must not be 0') + with tqdm(disable=disable) as pbar: + dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) + if callback is not None: + dpm_solver.info_callback = lambda info: callback( + {'sigma': dpm_solver.sigma(info['t']), 'sigma_hat': dpm_solver.sigma(info['t_up']), **info} + ) + x, info = dpm_solver.dpm_solver_adaptive( + x, + dpm_solver.t(torch.tensor(sigma_max)), + dpm_solver.t(torch.tensor(sigma_min)), + order, + rtol, + atol, + h_init, + pcoeff, + icoeff, + dcoeff, + accept_safety, + eta, + s_noise, + noise_sampler, + ) + if return_info: + return x, info + return x + + +@torch.no_grad() +def sample_dpmpp_2s_ancestral( + model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None +): + """Ancestral sampling with DPM-Solver++(2S) second-order steps.""" + extra_args = {} if extra_args is None else extra_args + noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler + s_in = x.new_ones([x.shape[0]]) + sigma_fn = lambda t: t.neg().exp() + t_fn = lambda sigma: sigma.log().neg() + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigma_down == 0: + # Euler method + d = to_d(x, sigmas[i], denoised) + dt = sigma_down - sigmas[i] + x = x + d * dt + else: + # DPM-Solver++(2S) + t, t_next = t_fn(sigmas[i]), t_fn(sigma_down) + r = 1 / 2 + h = t_next - t + s = t + r * h + x_2 = (sigma_fn(s) / sigma_fn(t)) * x - (-h * r).expm1() * denoised + denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args) + x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2 + # Noise addition + if sigmas[i + 1] > 0: + x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + return x + + +@torch.no_grad() +def sample_dpmpp_sde( + model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None, r=1 / 2 +): + """DPM-Solver++ (stochastic).""" + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max) if noise_sampler is None else noise_sampler + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + sigma_fn = lambda t: t.neg().exp() + t_fn = lambda sigma: sigma.log().neg() + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Euler method + d = to_d(x, sigmas[i], denoised) + dt = sigmas[i + 1] - sigmas[i] + x = x + d * dt + else: + # DPM-Solver++ + t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) + h = t_next - t + s = t + h * r + fac = 1 / (2 * r) + + # Step 1 + sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta) + s_ = t_fn(sd) + x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised + x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su + denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args) + + # Step 2 + sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta) + t_next_ = t_fn(sd) + denoised_d = (1 - fac) * denoised + fac * denoised_2 + x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d + x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su + return x + + +@torch.no_grad() +def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None): + """DPM-Solver++(2M).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + sigma_fn = lambda t: t.neg().exp() + t_fn = lambda sigma: sigma.log().neg() + old_denoised = None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) + h = t_next - t + if old_denoised is None or sigmas[i + 1] == 0: + x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised + else: + h_last = t - t_fn(sigmas[i - 1]) + r = h_last / h + denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised + x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d + old_denoised = denoised + return x + + +class DiscreteSchedule(nn.Module): + """A mapping between continuous noise levels (sigmas) and a list of discrete noise + levels.""" + + def __init__(self, sigmas, quantize): + super().__init__() + self.register_buffer('sigmas', sigmas) + self.register_buffer('log_sigmas', sigmas.log()) + self.quantize = quantize + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def get_sigmas(self, n=None): + if n is None: + return append_zero(self.sigmas.flip(0)) + t_max = len(self.sigmas) - 1 + t = torch.linspace(t_max, 0, n, device=self.sigmas.device) + return append_zero(self.t_to_sigma(t)) + + def sigma_to_t(self, sigma, quantize=None): + quantize = self.quantize if quantize is None else quantize + log_sigma = sigma.log() + dists = log_sigma - self.log_sigmas[:, None] + if quantize: + return dists.abs().argmin(dim=0).view(sigma.shape) + low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx] + w = (low - log_sigma) / (low - high) + w = w.clamp(0, 1) + t = (1 - w) * low_idx + w * high_idx + return t.view(sigma.shape) + + def t_to_sigma(self, t): + t = t.float() + low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() + log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] + return log_sigma.exp() + + +class DiscreteEpsDDPMDenoiser(DiscreteSchedule): + """A wrapper for discrete schedule DDPM models that output eps (the predicted + noise).""" + + def __init__(self, model, quantize=False): + alphas_cumprod = model.alphas_cumprod + super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize) + self.inner_model = model + self.sigma_data = 1.0 + + def get_scalings(self, sigma): + c_out = -sigma + c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 + return c_out, c_in + + def get_eps(self, *args, **kwargs): + return self.inner_model.apply_model(*args, **kwargs) + + def loss(self, input, noise, sigma, **kwargs): + c_out, c_in = [append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + noised_input = input + noise * append_dims(sigma, input.ndim) + eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs) + return (eps - noise).pow(2).flatten(1).mean(1) + + def forward(self, input, sigma, **kwargs): + c_out, c_in = [append_dims(x, input.ndim) for x in self.get_scalings(sigma)] + eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) + return input + eps * c_out diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/para_ddim.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/para_ddim.py new file mode 100644 index 000000000000..f389b8eff4ff --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/para_ddim.py @@ -0,0 +1,231 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Tuple + +import numpy as np +import torch +import torch.nn as nn +from tqdm import tqdm + +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.samplers import Sampler +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.samplers.base_sampler import AbstractBaseSampler +from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import noise_like + + +class ParaDDIMSampler(AbstractBaseSampler): + """ Parallel version of DDIM sampler. Utilizes Parallel Sampling (https://arxiv.org/abs/2305.16317). + It reduces the latency of a model, but the total compute cost is increased. + + The main three parameters that affect the performance of the algorithm are: + Parallelism (int): Defines the maximal size of the window. That many diffusion steps can happen in + parallel. + Tolerance (float): Sets the maximal error tolerance defined as a ratio between drift of the trajectory + and noise. The larger the tolerance the faster the method is. The smaller the tolerance the better + quality output is achieved. + Number of GPUs (int): Number of GPUs utilizing DataParallel parallelism to compute diffusion steps in + parallel. + + Different combination of these parameters values can result in different latency-quality-compute trade-off. + For more details please refer to the Parallel Sampling paper (https://arxiv.org/abs/2305.16317). + """ + + def __init__(self, model, **kwargs): + super().__init__(model, sampler=Sampler.PARA_DDIM, **kwargs) + + @torch.no_grad() + def p_sampling_fn(self): + pass + + @torch.no_grad() + def para_ddim_sampling_fn( + self, + cond: torch.tensor, + batch_size: int, + per_latent_shape: Tuple[int, ...], + x_T: torch.tensor = None, + steps: int = 50, + parallelism: int = 8, + tolerance: float = 0.1, + temperature: float = 0.0, + noise_dropout: float = 0.0, + quantize_denoised: bool = False, + unconditional_guidance_scale: float = 1.0, + unconditional_conditioning: torch.tensor = None, + score_corrector=None, + corrector_kwargs=None, + ): + print( + f"Running {self.sampler.name} with {steps} timesteps, " + f"parallelism={parallelism}, " + f"and tolerance={tolerance}" + ) + + device = self.model.betas.device + size = (batch_size, *per_latent_shape) + x_T = torch.randn(size, generator=self.model.rng, device=device) if x_T is None else x_T + time_range = np.flip(self.ddim_timesteps).copy() # Make a copy to resolve issue with negative strides + + # Processing window of timesteps [window_start, window_end) in parallel + window_start = 0 + window_size = min(parallelism, steps) + window_end = window_size + + # Store the whole trajectory in memory; it will be iteratively improved + latents = torch.stack([x_T] * (steps + 1)) + + # Pre-computing noises to ensure noise is sampled once per diffusion step + noises = torch.zeros_like(latents) + for i in range(steps - 1, -1, -1): + gaussian_noise = torch.randn_like(x_T) + noise = (self.ddim_variance[i] ** 0.5) * gaussian_noise + noises[i] = noise.clone() + + # Store inverse of the variance to avoid division at every iteration + variance = [self.ddim_variance[i] for i in range(steps - 1, -1, -1)] + [0] + inverse_variance = 1.0 / torch.tensor(variance).to(noises.device) + latent_dim = noises[0, 0].numel() + inverse_variance_norm = inverse_variance[:, None] / latent_dim + + scaled_tolerance = tolerance ** 2 + + with tqdm(total=steps) as progress_bar: + while window_start < steps: + window_size = window_end - window_start + + # Prepare the input to the model. Model will perform window_size noise predictions in parallel + window_cond = torch.stack([cond] * window_size) + window_uncond_cond = torch.stack([unconditional_conditioning] * window_size) + window_latents = latents[window_start:window_end] + window_timesteps = torch.tensor(time_range[window_start:window_end], device=device).repeat( + 1, batch_size + ) + + # Reshape (w, b, ...) -> (w * b, ...) + latents_input = window_latents.flatten(0, 1) + timesteps_input = window_timesteps.flatten(0, 1) + cond_input = window_cond.flatten(0, 1) + uncond_cond_input = window_uncond_cond.flatten(0, 1) + + # Model call + e_t, _ = self._get_model_output( + latents_input, + timesteps_input, + uncond_cond_input, + unconditional_guidance_scale, + score_corrector, + cond_input, + corrector_kwargs, + ) + # Reshape back (w * b, ...) -> (w, b, ...) + e_t = e_t.reshape(window_size, batch_size, *per_latent_shape) + + # Perform Picard iteration + window_latents_picard_iteration = self._get_x_prev( + batch_size=batch_size, + steps=steps, + x=window_latents, + e_t=e_t, + temperature=temperature, + noise_dropout=noise_dropout, + quantize_denoised=quantize_denoised, + window_start=window_start, + window_end=window_end, + device=device, + ).reshape(window_latents.shape) + + # Calculate cumulative drift + delta = window_latents_picard_iteration - window_latents + delta_cum = torch.cumsum(delta, dim=0) + block_latents_new = latents[window_start][None,] + delta_cum + + # Calculate the error + error = torch.linalg.norm( + (block_latents_new - latents[window_start + 1 : window_end + 1]).reshape( + window_size, batch_size, -1 + ), + dim=-1, + ).pow(2) + + # Calculate error magnitude + error_magnitude = error * inverse_variance_norm[window_start + 1 : window_end + 1] + # Pad so at least one value exceeds tolerance + error_magnitude = nn.functional.pad(error_magnitude, (0, 0, 0, 1), value=1e9) + error_exceeding = torch.max(error_magnitude > scaled_tolerance, dim=1).values.int() + + # Find how many diffusion steps have error below given threshold tolerance and shift the window + ind = torch.argmax(error_exceeding).item() + new_window_start = window_start + min(1 + ind, window_size) + new_window_end = min(new_window_start + window_size, steps) + + # Update the trajectory + latents[window_start + 1 : window_end + 1] = block_latents_new + latents[window_end : new_window_end + 1] = latents[window_end][ + None, + ] + + progress_bar.update(new_window_start - window_start) + window_start = new_window_start + window_end = new_window_end + + intermediates = {"x_inter": [latents[i] for i in range(steps)]} + return latents[-1], intermediates + + def _get_x_prev( + self, + batch_size: int, + steps: int, + x: torch.tensor, + e_t: torch.tensor, + temperature: float, + noise_dropout: float, + quantize_denoised: bool, + window_start: int, + window_end: int, + device: Any, + ): + alphas = self.ddim_alphas + alphas_prev = self.ddim_alphas_prev + sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas + sigmas = self.ddim_sigmas + window_size = window_end - window_start + + def prepare_tensor(x): + x = torch.tensor(x, device=device).flip(dims=[0]) + x = x.unsqueeze(1).repeat(1, batch_size).reshape(window_size, batch_size, 1, 1, 1) + return x + + # Select parameters corresponding to the currently considered timesteps. Note that index_end < index_start, + # because during diffusion the time is reversed (we go from timestep step to 0) + index_start = steps - window_start + index_end = steps - window_end + a_t = prepare_tensor(alphas[index_end:index_start]) + a_prev = prepare_tensor(alphas_prev[index_end:index_start]) + sigma_t = prepare_tensor(sigmas[index_end:index_start]) + sqrt_one_minus_at = prepare_tensor(sqrt_one_minus_alphas[index_end:index_start]) + + # Current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + # Direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t + + noise = sigma_t * noise_like(x.shape, device) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/plms.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/plms.py new file mode 100644 index 000000000000..2a721d1f9ae0 --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/plms.py @@ -0,0 +1,105 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SAMPLING ONLY.""" + +import torch + +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.samplers import Sampler +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.samplers.base_sampler import AbstractBaseSampler + + +class PLMSSampler(AbstractBaseSampler): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__(model, sampler=Sampler.PLMS, schedule="linear", **kwargs) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=False): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + super().make_schedule(ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=False) + + @torch.no_grad() + def p_sampling_fn( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + old_eps=None, + t_next=None, + ): + b, *_, device = *x.shape, x.device + e_t, model_output = self._get_model_output( + x, t, unconditional_conditioning, unconditional_guidance_scale, score_corrector, c, corrector_kwargs + ) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = self._get_x_prev_and_pred_x0( + use_original_steps, + b, + index, + device, + x, + t, + model_output, + e_t, + quantize_denoised, + repeat_noise, + temperature, + noise_dropout, + ) + e_t_next, model_output = self._get_model_output( + x_prev, + t_next, + unconditional_conditioning, + unconditional_guidance_scale, + score_corrector, + c, + corrector_kwargs, + ) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = self._get_x_prev_and_pred_x0( + use_original_steps, + b, + index, + device, + x, + t, + model_output, + e_t_prime, + quantize_denoised, + repeat_noise, + temperature, + noise_dropout, + ) + + return x_prev, pred_x0, e_t diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/sampler_dpm.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/sampler_dpm.py new file mode 100644 index 000000000000..98a1b69b5b3b --- /dev/null +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/samplers/sampler_dpm.py @@ -0,0 +1,76 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""SAMPLING ONLY.""" + +import torch + +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.samplers import Sampler +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.samplers.base_sampler import AbstractBaseSampler +from .dpmsolver import DPMSolver, NoiseScheduleVP, model_wrapper + +MODEL_TYPES = {"eps": "noise", "v": "v"} + + +class DPMSolverSampler(AbstractBaseSampler): + def __init__(self, model, **kwargs): + + super().__init__(model, sampler=Sampler.DPM, **kwargs) + + def to_torch(x, model): + x_copy = x.clone() + x_detached = x_copy.detach() + x_float32 = x_detached.to(torch.float32) + x_device = x_float32.to(model.betas.device) + return x_device + + self.register_buffer("alphas_cumprod", to_torch(model.alphas_cumprod, model)) + + @torch.no_grad() + def p_sampling_fn(self): + pass + + @torch.no_grad() + def dpm_sampling_fn( + self, + shape, + steps, + conditioning=None, + unconditional_conditioning=None, + unconditional_guidance_scale=1.0, + x_T=None, + ): + + device = self.model.betas.device + if x_T is None: + img = torch.randn(shape, generator=self.model.rng, device=device) + else: + img = x_T + + ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod) + + model_fn = model_wrapper( + lambda x, t, c: self.model.apply_model(x, t, c), + ns, + model_type=MODEL_TYPES[self.model.parameterization], + guidance_type="classifier-free", + condition=conditioning, + unconditional_condition=unconditional_conditioning, + guidance_scale=unconditional_guidance_scale, + ) + dpm_solver = DPMSolver(model_fn, ns, predict_x0=True, thresholding=False) + x = dpm_solver.sample( + img, steps=steps, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True, + ) + + return x.to(device), None diff --git a/nemo/collections/multimodal/modules/imagen/__init__.py b/nemo/collections/multimodal/modules/imagen/__init__.py new file mode 100644 index 000000000000..aee951313044 --- /dev/null +++ b/nemo/collections/multimodal/modules/imagen/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.package_info import __version__ + +# Set collection version equal to NeMo version. +__version = __version__ + +# Authorship. +__author__ = "NVIDIA Corporation" + +# Set collection name. +__description__ = "Speech Computer Vision collection" diff --git a/nemo/collections/multimodal/modules/imagen/diffusionmodules/__init__.py b/nemo/collections/multimodal/modules/imagen/diffusionmodules/__init__.py new file mode 100644 index 000000000000..aee951313044 --- /dev/null +++ b/nemo/collections/multimodal/modules/imagen/diffusionmodules/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.package_info import __version__ + +# Set collection version equal to NeMo version. +__version = __version__ + +# Authorship. +__author__ = "NVIDIA Corporation" + +# Set collection name. +__description__ = "Speech Computer Vision collection" diff --git a/nemo/collections/multimodal/modules/imagen/diffusionmodules/attention.py b/nemo/collections/multimodal/modules/imagen/diffusionmodules/attention.py new file mode 100644 index 000000000000..de301e0bc038 --- /dev/null +++ b/nemo/collections/multimodal/modules/imagen/diffusionmodules/attention.py @@ -0,0 +1,317 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Adapted from: +https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py +""" +import math + +import numpy as np +import torch +import torch.nn as nn +from torch.cuda.amp import custom_bwd, custom_fwd + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += torch.DoubleTensor([matmul_ops]) + + +# Stable attention +class StableAttentionOp(torch.autograd.Function): + # This function defines the attention weight computation in a stable way + # The idea is to scale the gradients of weight matrix by the maximum absolute value. + # In case of overflow, this will prevent weight gradients from exploding. + # In case of underflow, since we clipped the scale to 1e-4, this will prevent underflow. + + @staticmethod + def forward(ctx, q, k): + w = torch.einsum('ncq,nck->nqk', q, k / math.sqrt(k.shape[1])).softmax(dim=2) + ctx.save_for_backward(q, k, w) + return w + + @staticmethod + def backward(ctx, dw): + q, k, w = ctx.saved_tensors + + s = dw.detach().norm(float('inf'), dim=[1, 2], keepdim=True).clip(min=1e-4) + dw = dw / s + + # Due to softmax, w is fp32, making db fp32. + # Type casting is required for amp to work. + db = torch._softmax_backward_data(grad_output=dw, output=w, dim=2, input_dtype=dw.dtype).to(q.dtype) + s = s / math.sqrt(k.shape[1]) + + dq = torch.einsum('nck,nqk->ncq', k, db) * s + dk = torch.einsum('ncq,nqk->nck', q, db) * s + + return dq, dk + + +class QKVStableAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + + # Reshaping q and k + # try: + # q = q.view(bs * self.n_heads, ch, length) + # k = k.view(bs * self.n_heads, ch, length) + # except Exception: + q = q.reshape(bs * self.n_heads, ch, length) + k = k.reshape(bs * self.n_heads, ch, length) + + weight = StableAttentionOp.apply(q, k) + a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length), weight + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length), weight + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class StableMaskedAttentionOp(torch.autograd.Function): + # Robust attention operation in case of masked attention + @staticmethod + @custom_fwd + def forward(ctx, q, k, mask): + max_neg_value = -float('inf') + w = torch.einsum('ncq,nck->nqk', q, k / math.sqrt(k.shape[1])) + w = w.masked_fill(mask, max_neg_value) + w = w.softmax(dim=2) + + # When we use an arbitrary mask, there is a possibility that we get nans in softmax. + # In this case, use nan_to_num to make it a stable number. + w = w.nan_to_num_() + ctx.save_for_backward(q, k, w, mask) + return w + + @staticmethod + @custom_bwd + def backward(ctx, dw): + q, k, w, mask = ctx.saved_tensors + max_neg_value = -torch.finfo(q.dtype).max + s = dw.detach().norm(float('inf'), dim=[1, 2], keepdim=True).clip(min=1e-4) + dw = dw / s + db = torch._softmax_backward_data(grad_output=dw, output=w, dim=2, input_dtype=dw.dtype) + + # Masking db + db_in = db.clone().masked_fill_(mask, 0) + + s = s / math.sqrt(k.shape[1]) + dq = torch.einsum('nck,nqk->ncq', k, db_in) * s + dk = torch.einsum('ncq,nqk->nck', q, db_in) * s + + # These are dummy derivatives since mask is a constant + dmask = (max_neg_value - w) * db.clone() * s + + return dq, dk, dmask + + +class QKVMaskedAttention(nn.Module): + """ + A module which performs QKV attention. + Attention mask is accepted as input. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, q, k, v, mask): + r""" + Apply QKV attention with attention mask. + + Args: + q: an [N x d x n_seq1] of queries. + k: an [N x d x n_seq2] of keys. + v: an [N x d x n_seq2] of values. + mask: Attention mask of size N x n_seq1 x n_seq2 + + Returns: an [N x d x n_seq1] tensor after attention. + """ + + bs, width, length_q = q.shape + _, _, length_k = k.shape + + assert width % self.n_heads == 0 + ch = width // self.n_heads + + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length_q), + (k * scale).view(bs * self.n_heads, ch, length_k), + ) # More stable with f16 than dividing afterwards + + # Duplicate mask n_heads times + mask = mask.repeat_interleave(self.n_heads, dim=0) + assert mask.shape == weight.shape + max_neg_value = -float('inf') + weight = weight.masked_fill(~mask, max_neg_value) + + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + + # When we use an arbitrary mask, there is a possibility that we get nans in softmax. + # In this case, use nan_to_num to make it a non-nan number. + weight = weight.nan_to_num_() + a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length_k)) + # We also return weight here for attention visualization. + return a.reshape(bs, -1, length_q), weight + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVStableMaskedAttention(nn.Module): + """ + A module which performs QKV attention. + Attention mask is accepted as input. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, q, k, v, mask): + r""" + Apply QKV attention with attention mask. + + Args: + q: an [N x d x n_seq1] of queries. + k: an [N x d x n_seq2] of keys. + v: an [N x d x n_seq2] of values. + mask: Attention mask of size N x n_seq1 x n_seq2 + + Returns: an [N x d x n_seq1] tensor after attention. + """ + + bs, width, length_q = q.shape + _, _, length_k = k.shape + + assert width % self.n_heads == 0 + ch = width // self.n_heads + + q = q.view(bs * self.n_heads, ch, length_q) + k = k.view(bs * self.n_heads, ch, length_k) + + # Forming attention mask + mask = mask.repeat_interleave(self.n_heads, dim=0) + + weight = StableMaskedAttentionOp.apply(q, k, ~mask) + + a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length_k)) + # We also return weight here for attention visualization. + return a.reshape(bs, -1, length_q), weight + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class SelfAttentionPooling(nn.Module): + """ + Implementation of SelfAttentionPooling + Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition + https://arxiv.org/pdf/2008.01077v1.pdf + Taken from: https://gist.github.com/pohanchi/c77f6dbfbcbc21c5215acde4f62e4362 + """ + + def __init__(self, input_dim): + super(SelfAttentionPooling, self).__init__() + self.W = nn.Linear(input_dim, 1) + + def forward(self, batch_rep): + """ + input: + batch_rep : size (N, T, H), N: batch size, T: sequence length, H: Hidden dimension + + attention_weight: + att_w : size (N, T, 1) + + return: + utter_rep: size (N, H) + """ + softmax = nn.functional.softmax + att_w = softmax(self.W(batch_rep).squeeze(-1), dim=1).unsqueeze(-1) + utter_rep = torch.sum(batch_rep * att_w, dim=1) + + return utter_rep diff --git a/nemo/collections/multimodal/modules/imagen/diffusionmodules/attention_alt.py b/nemo/collections/multimodal/modules/imagen/diffusionmodules/attention_alt.py new file mode 100644 index 000000000000..8927226c818e --- /dev/null +++ b/nemo/collections/multimodal/modules/imagen/diffusionmodules/attention_alt.py @@ -0,0 +1,321 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Adapted from: +https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py +""" +import math + +import numpy as np +import torch +import torch.nn as nn +from torch.cuda.amp import custom_bwd, custom_fwd + +USE_ALT = False + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += torch.DoubleTensor([matmul_ops]) + + +# Stable attention +class StableAttentionOp(torch.autograd.Function): + # This function defines the attention weight computation in a stable way + # The idea is to scale the gradients of weight matrix by the maximum absolute value. + # In case of overflow, this will prevent weight gradients from exploding. + # In case of underflow, since we clipped the scale to 1e-4, this will prevent underflow. + + @staticmethod + def forward(ctx, q, k): + w = torch.einsum('ncq,nck->nqk', q, k / math.sqrt(k.shape[1])).softmax(dim=2) + ctx.save_for_backward(q, k, w) + return w + + @staticmethod + def backward(ctx, dw): + q, k, w = ctx.saved_tensors + + s = dw.detach().norm(float('inf'), dim=[1, 2], keepdim=True).clip(min=1e-4) + dw = dw / s + + # Due to softmax, w is fp32, making db fp32. + # Type casting is required for amp to work. + db = torch._softmax_backward_data(grad_output=dw, output=w, dim=2, input_dtype=dw.dtype).to(q.dtype) + s = s / math.sqrt(k.shape[1]) + + dq = torch.einsum('nck,nqk->ncq', k, db) * s + dk = torch.einsum('ncq,nqk->nck', q, db) * s + + return dq, dk + + +class QKVStableAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + + # Reshaping q and k + # try: + # q = q.view(bs * self.n_heads, ch, length) + # k = k.view(bs * self.n_heads, ch, length) + # except Exception: + q = q.reshape(bs * self.n_heads, ch, length) + k = k.reshape(bs * self.n_heads, ch, length) + + weight = StableAttentionOp.apply(q, k) + a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length), weight + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length), weight + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class StableMaskedAttentionOp(torch.autograd.Function): + # Robust attention operation in case of masked attention + @staticmethod + @custom_fwd + def forward(ctx, q, k, mask): + max_neg_value = -float('inf') + w = torch.einsum('ncq,nck->nqk', q, k / math.sqrt(k.shape[1])) + w = w.masked_fill(mask, max_neg_value) + w = w.softmax(dim=2) + + # When we use an arbitrary mask, there is a possibility that we get nans in softmax. + # In this case, use nan_to_num to make it a stable number. + # w = w.nan_to_num_() + ctx.save_for_backward(q, k, w, mask) + return w + + @staticmethod + @custom_bwd + def backward(ctx, dw): + q, k, w, mask = ctx.saved_tensors + max_neg_value = -torch.finfo(q.dtype).max + s = dw.detach().norm(float('inf'), dim=[1, 2], keepdim=True).clip(min=1e-4) + dw = dw / s + db = torch._softmax_backward_data(grad_output=dw, output=w, dim=2, input_dtype=dw.dtype) + + # Masking db + db_in = db.clone().masked_fill_(mask, 0) + + s = s / math.sqrt(k.shape[1]) + dq = torch.einsum('nck,nqk->ncq', k, db_in) * s + dk = torch.einsum('ncq,nqk->nck', q, db_in) * s + + # These are dummy derivatives since mask is a constant + dmask = (max_neg_value - w) * db.clone() * s + + return dq, dk, dmask + + +class QKVMaskedAttention(nn.Module): + """ + A module which performs QKV attention. + Attention mask is accepted as input. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, q, k, v, mask): + r""" + Apply QKV attention with attention mask. + + Args: + q: an [N x d x n_seq1] of queries. + k: an [N x d x n_seq2] of keys. + v: an [N x d x n_seq2] of values. + mask: Attention mask of size N x n_seq1 x n_seq2 + + Returns: an [N x d x n_seq1] tensor after attention. + """ + + bs, width, length_q = q.shape + _, _, length_k = k.shape + + assert width % self.n_heads == 0 + ch = width // self.n_heads + + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length_q), + (k * scale).view(bs * self.n_heads, ch, length_k), + ) # More stable with f16 than dividing afterwards + + # Duplicate mask n_heads times + # mask = mask.repeat_interleave(self.n_heads, dim=0) + mask = mask.unsqueeze(0).repeat(self.n_heads, 1, 1, 1).transpose(0, 1).flatten(0, 1) + assert mask.shape == weight.shape + max_neg_value = -float('inf') + weight = weight.masked_fill(~mask, max_neg_value) + + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + + # When we use an arbitrary mask, there is a possibility that we get nans in softmax. + # In this case, use nan_to_num to make it a non-nan number. + # weight = weight.nan_to_num_() + a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length_k)) + # We also return weight here for attention visualization. + return a.reshape(bs, -1, length_q), weight + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVStableMaskedAttention(nn.Module): + """ + A module which performs QKV attention. + Attention mask is accepted as input. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, q, k, v, mask): + r""" + Apply QKV attention with attention mask. + + Args: + q: an [N x d x n_seq1] of queries. + k: an [N x d x n_seq2] of keys. + v: an [N x d x n_seq2] of values. + mask: Attention mask of size N x n_seq1 x n_seq2 + + Returns: an [N x d x n_seq1] tensor after attention. + """ + + bs, width, length_q = q.shape + _, _, length_k = k.shape + + assert width % self.n_heads == 0 + ch = width // self.n_heads + + q = q.view(bs * self.n_heads, ch, length_q) + k = k.view(bs * self.n_heads, ch, length_k) + + # Forming attention mask + # mask = mask.repeat_interleave(self.n_heads, dim=0) + mask = mask.unsqueeze(0).repeat(self.n_heads, 1, 1, 1).transpose(0, 1).flatten(0, 1) + + weight = StableMaskedAttentionOp.apply(q, k, ~mask) + + a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length_k)) + # We also return weight here for attention visualization. + return a.reshape(bs, -1, length_q), weight + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class SelfAttentionPooling(nn.Module): + """ + Implementation of SelfAttentionPooling + Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition + https://arxiv.org/pdf/2008.01077v1.pdf + Taken from: https://gist.github.com/pohanchi/c77f6dbfbcbc21c5215acde4f62e4362 + """ + + def __init__(self, input_dim): + super(SelfAttentionPooling, self).__init__() + self.W = nn.Linear(input_dim, 1) + + def forward(self, batch_rep): + """ + input: + batch_rep : size (N, T, H), N: batch size, T: sequence length, H: Hidden dimension + + attention_weight: + att_w : size (N, T, 1) + + return: + utter_rep: size (N, H) + """ + softmax = nn.functional.softmax + att_w = softmax(self.W(batch_rep).squeeze(-1), dim=1).unsqueeze(-1) + utter_rep = torch.sum(batch_rep * att_w, dim=1) + + return utter_rep diff --git a/nemo/collections/multimodal/modules/imagen/diffusionmodules/blocks.py b/nemo/collections/multimodal/modules/imagen/diffusionmodules/blocks.py new file mode 100644 index 000000000000..445c3c7a98de --- /dev/null +++ b/nemo/collections/multimodal/modules/imagen/diffusionmodules/blocks.py @@ -0,0 +1,905 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Adapted from: +https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py +""" +import math +from abc import abstractmethod + +import torch as th +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from einops import rearrange + +from nemo.collections.multimodal.modules.imagen.diffusionmodules import attention_alt + +if attention_alt.USE_ALT: + from nemo.collections.multimodal.modules.imagen.diffusionmodules.attention_alt import ( + QKVAttention, + QKVMaskedAttention, + QKVStableAttention, + QKVStableMaskedAttention, + ) +else: + from nemo.collections.multimodal.modules.imagen.diffusionmodules.attention import ( + QKVAttention, + QKVMaskedAttention, + QKVStableAttention, + QKVStableMaskedAttention, + ) +from nemo.collections.multimodal.modules.imagen.diffusionmodules.layers import ( + Downsample, + Upsample, + UpsampleLearnable, + conv_nd, + linear, + normalization, + zero_module, +) + + +def check_cuda(): + if not th.cuda.is_available(): + raise RuntimeError('CUDA is not available') + cur_device = th.cuda.current_device() + dprops = th.cuda.get_device_properties(cur_device) + + is_sm75 = dprops.major == 7 and dprops.minor == 5 + is_sm8x = dprops.major == 8 and dprops.minor >= 0 + is_sm90 = dprops.major == 9 and dprops.minor >= 0 + + return is_sm8x or is_sm75 or is_sm90 + + +try: + from flash_attn import flash_attn_varlen_func, flash_attn_varlen_kvpacked_func + + flash_attn_installed = check_cuda() +except ImportError: + flash_attn_installed = False + + +class TextConditionedBlock(nn.Module): + r""" + Any module where forward() takes text embeddings as arguments. + """ + + @abstractmethod + def forward(self, x, text_emb, text_mask): + """ + Apply the module to `x` given `text_emb` text embedding and 'text_mask' text valid mask. + """ + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class ConditionalSequential(nn.Sequential, TimestepBlock, TextConditionedBlock): + r""" + A sequential module that accepts timestep embeddings, text embedding and text mask in addition to the input x. + Depending on the type of block, we either pass timestep embedding or text embeddings as inputs. + """ + + def forward(self, x, emb, text_emb, text_mask): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, TextConditionedBlock): + x = layer(x, text_emb, text_mask) + else: + x = layer(x) + return x + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + learnable_upsampling=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), nn.SiLU(), conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + if learnable_upsampling: + upsample_fn = UpsampleLearnable + else: + upsample_fn = Upsample + + if up: + self.h_upd = upsample_fn(channels, False, dims) + self.x_upd = upsample_fn(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), linear(emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels,), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + if self.use_checkpoint: + return checkpoint.checkpoint(self._forward, x, emb) + else: + return self._forward(x, emb) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class EfficientResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + Follow Figure A.27 in Imagen Paper. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + out_channels=None, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + skip_connection_scaling=False, + ): + super().__init__() + + out_channels = out_channels or channels + + self.use_scale_shift_norm = use_scale_shift_norm + self.use_checkpoint = use_checkpoint + + self.in_layers = nn.Sequential( + normalization(channels), nn.SiLU(), conv_nd(dims, channels, out_channels, 3, padding=1) + ) + + self.emb_layers = nn.Sequential( + nn.SiLU(), nn.Linear(emb_channels, 2 * out_channels if use_scale_shift_norm else out_channels,), + ) + + self.out_layers = nn.Sequential( + normalization(out_channels), + nn.SiLU(), + zero_module(conv_nd(dims, out_channels, out_channels, 3, padding=1)), + ) + + self.shortcut = conv_nd(dims, channels, out_channels, 1) + self.shortcut_scale = 1 / math.sqrt(2) if skip_connection_scaling else 1 + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + if self.use_checkpoint: + return checkpoint.checkpoint(self._forward, x, emb) + else: + return self._forward(x, emb) + + def _forward(self, x, emb): + h = self.in_layers(x) + emb_out = self.emb_layers(emb) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + + return h + self.shortcut(x) * self.shortcut_scale + + +class Block(nn.Module): + def __init__( + self, + channels, + emb_channels, + out_channels=None, + use_scale_shift_norm=True, + num_resblocks=2, + attention_type=None, + text_embed_dim=0, + stable_attention=True, + flash_attention=False, + num_head_channels=-1, + num_heads=8, + dims=2, + use_checkpoint=False, + skip_connection_scaling=False, + ): + super().__init__() + + out_channels = out_channels or channels + + self.attention_type = attention_type + self.text_embed_dim = text_embed_dim + + blocks = [ + EfficientResBlock( + channels, + emb_channels, + out_channels=out_channels, + use_scale_shift_norm=use_scale_shift_norm, + dims=dims, + use_checkpoint=use_checkpoint, + skip_connection_scaling=skip_connection_scaling, + ) + ] + + blocks += [ + EfficientResBlock( + out_channels, + emb_channels, + out_channels=out_channels, + use_scale_shift_norm=use_scale_shift_norm, + dims=dims, + use_checkpoint=use_checkpoint, + skip_connection_scaling=skip_connection_scaling, + ) + for _ in range(num_resblocks - 1) + ] + + self.blocks = nn.ModuleList(blocks) + + # Attention blocks + # Self - Self-attention blocks + # fused - Single attention layer for fusing self and cross attention. + if self.attention_type is not None: + assert self.attention_type in ('self', 'cross', 'fused', 'stacked') + attention_kwargs = dict() + + if self.attention_type == 'self': + attention_fn = SelfAttentionBlock + elif self.attention_type == 'cross': + attention_fn = CrossAttentionBlock + attention_kwargs['context_dim'] = self.text_embed_dim + elif self.attention_type == 'stacked': + attention_fn = StackedCrossAttentionBlock + attention_kwargs['context_dim'] = self.text_embed_dim + else: + attention_fn = FusedCrossAttentionBlock + attention_kwargs['context_dim'] = self.text_embed_dim + + self.attention_layer = attention_fn( + out_channels, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_checkpoint=use_checkpoint, + stable_attention=stable_attention, + flash_attention=flash_attention, + **attention_kwargs, + ) + + @abstractmethod + def forward(self, x, emb, text_embed=None, text_mask=None): + pass + + +class DBlock(Block): + def __init__( + self, + channels, + emb_channels, + out_channels=None, + use_scale_shift_norm=True, + conv_down=True, + stride=2, + num_resblocks=2, + attention_type=None, + text_embed_dim=0, + stable_attention=True, + flash_attention=False, + num_head_channels=-1, + num_heads=8, + dims=2, + use_checkpoint=False, + skip_connection_scaling=False, + ): + super().__init__( + channels, + emb_channels, + out_channels=out_channels, + use_scale_shift_norm=use_scale_shift_norm, + num_resblocks=num_resblocks, + attention_type=attention_type, + text_embed_dim=text_embed_dim, + stable_attention=stable_attention, + flash_attention=flash_attention, + num_head_channels=num_head_channels, + num_heads=num_heads, + dims=dims, + use_checkpoint=use_checkpoint, + skip_connection_scaling=skip_connection_scaling, + ) + + self.conv_down = conv_down + if self.conv_down: + # self.conv = nn.Conv2d(channels, channels, 3, stride=stride, padding=1) + self.conv = nn.Conv2d(channels, channels, 4, stride=stride, padding=1) + + def forward(self, x, emb, text_embed=None, text_mask=None): + if self.conv_down: + x = self.conv(x) + + for block in self.blocks: + x = block(x, emb) + + if self.attention_type in ('cross', 'fused', 'stacked'): + x = self.attention_layer(x, text_embed, text_mask) + elif self.attention_type == 'self': + x = self.attention_layer(x) + + return x + + +class UBlock(Block): + def __init__( + self, + channels, + emb_channels, + out_channels=None, + use_scale_shift_norm=True, + conv_up=True, + stride=2, + num_resblocks=2, + attention_type=None, + text_embed_dim=0, + stable_attention=True, + flash_attention=False, + num_head_channels=-1, + num_heads=8, + dims=2, + use_checkpoint=False, + skip_connection_scaling=False, + ): + super().__init__( + channels, + emb_channels, + out_channels=out_channels, + use_scale_shift_norm=use_scale_shift_norm, + num_resblocks=num_resblocks, + attention_type=attention_type, + text_embed_dim=text_embed_dim, + stable_attention=stable_attention, + flash_attention=flash_attention, + num_head_channels=num_head_channels, + num_heads=num_heads, + dims=dims, + use_checkpoint=use_checkpoint, + skip_connection_scaling=skip_connection_scaling, + ) + + self.conv_up = conv_up + if self.conv_up: + self.conv = nn.ConvTranspose2d(out_channels, out_channels, 4, stride, 1) + + def forward(self, x, emb, text_embed=None, text_mask=None): + for block in self.blocks: + x = block(x, emb) + + if self.attention_type in ('cross', 'fused', 'stacked'): + x = self.attention_layer(x, text_embed, text_mask) + elif self.attention_type == 'self': + x = self.attention_layer(x) + + if self.conv_up: + x = self.conv(x) + + return x + + +class FusedCrossAttentionBlock(TextConditionedBlock): + """ + An attention block that fuses self-attention and cross-attention + in a single block. + """ + + def __init__( + self, + channels, + context_dim, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + stable_attention=True, + flash_attention=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.flash_attention = flash_attention + self.norm = normalization(channels) + self.norm_context = normalization(context_dim) + self.norm_self = normalization(channels) + + # For image features + self.q = conv_nd(1, channels, channels, 1) + + # For context + self.kv_context = conv_nd(1, context_dim, channels * 2, 1) + + # For spatial + self.kv_self = conv_nd(1, channels, channels * 2, 1) + + if flash_attention: + assert flash_attn_installed, "FlashAttention is not installed." + assert not stable_attention, "FlashAttention doesn't support the stable form." + + elif stable_attention: + self.attention = QKVStableMaskedAttention(self.num_heads) + else: + self.attention = QKVMaskedAttention(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x, context, mask): + if self.use_checkpoint: + return checkpoint.checkpoint(self._forward, x, context, mask) + else: + return self._forward(x, context, mask) + + def _forward(self, x, context, mask): + + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + + q = self.q(self.norm(x)) + + # Key-value pairs for self-attention + kv_self = self.kv_self(self.norm_self(x)) + k_self, v_self = kv_self.chunk(2, dim=1) + k_self = k_self.contiguous() + v_self = v_self.contiguous() + + # Key-value pairs for cross-attention + context = th.permute(context, (0, 2, 1)) + context_n = self.norm_context(context) + kv_context = self.kv_context(context_n) + k_context, v_context = kv_context.chunk(2, dim=1) + k_context = k_context.contiguous() + v_context = v_context.contiguous() + + # Appending key-value pairs + k_full = th.cat([k_self, k_context], dim=2) + v_full = th.cat([v_self, v_context], dim=2) + + if self.flash_attention: + # q: b (h d) s, k_context: b (h d) s + batch_size = q.shape[0] + max_seqlen_q, max_seqlen_k = q.shape[2], q.shape[2] + k_context.shape[2] + q = rearrange(q, 'b (h d) s -> (b s) h d', h=self.num_heads) + + mask_self = th.ones((batch_size, max_seqlen_q), device=q.device, dtype=th.bool) + mask_context = mask.bool() + mask_full = th.cat([mask_self, mask_context], dim=1) + + k_full_unpadded = k_full.transpose(1, 2)[mask_full] + total_k = k_full_unpadded.shape[0] + k_full_unpadded = k_full_unpadded.view(total_k, self.num_heads, -1) + + v_full_unpadded = v_full.transpose(1, 2)[mask_full] + v_full_unpadded = v_full_unpadded.view(total_k, self.num_heads, -1) + + # (b s) t h d + kv_full_unpadded = th.stack([k_full_unpadded, v_full_unpadded], dim=1) + + cu_seqlens_q = th.arange( + 0, (batch_size + 1) * max_seqlen_q, step=max_seqlen_q, dtype=th.int32, device=q.device + ) + cu_seqlens_k = th.zeros((batch_size + 1), dtype=th.int32, device=k_full.device) + cu_seqlens_k[1:] = th.cumsum(mask.sum(dim=1), dim=0) + cu_seqlens_k += cu_seqlens_q + + out = flash_attn_varlen_kvpacked_func( + q, kv_full_unpadded, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, 0.0 + ) + h = rearrange(out, '(b s) h d -> b (h d) s', b=batch_size, h=self.num_heads) + else: + # Computing mask for self attention + mask_self = th.ones(k_self.shape[0], q.shape[2], k_self.shape[2], device=mask.device) + + # Mask for cross attention + mask_context = mask.view(mask.shape[0], 1, mask.shape[1]) + mask_context = mask_context.repeat(1, q.shape[2], 1) + + # Fused mask + mask_full = th.cat([mask_self, mask_context], dim=2) + mask_full = mask_full.to(th.bool) + + h, _ = self.attention(q, k_full, v_full, mask_full) + + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +class SelfAttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + stable_attention=False, + flash_attention=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + self.flash_attention = flash_attention + if flash_attention: + assert flash_attn_installed, "FlashAttention is not installed." + assert not stable_attention, "FlashAttention doesn't support the stable form." + elif stable_attention: + self.attention = QKVStableAttention(self.num_heads) + else: + self.attention = QKVAttention(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + if self.use_checkpoint: + return checkpoint.checkpoint(self._forward, x) + else: + return self._forward(x) + + def _forward(self, x): + + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + + if self.flash_attention: + # qkv shape: (b, (3 h d) s), need to reshape to (b, s, h, d) for each q, k, v + b, _, _ = qkv.shape + q, k, v = qkv.chunk(3, dim=1) + max_seqlen_q, max_seqlen_k = q.shape[2], k.shape[2] + q = rearrange(q, 'b (h d) s -> (b s) h d', h=self.num_heads) + k = rearrange(k, 'b (h d) s -> (b s) h d', h=self.num_heads) + v = rearrange(v, 'b (h d) s -> (b s) h d', h=self.num_heads) + cu_seqlens_q = th.arange(0, (b + 1) * max_seqlen_q, step=max_seqlen_q, dtype=th.int32, device=q.device) + cu_seqlens_k = th.arange(0, (b + 1) * max_seqlen_k, step=max_seqlen_k, dtype=th.int32, device=k.device) + h = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, 0.0) + h = rearrange(h, '(b s) h d -> b (h d) s', b=b, h=self.num_heads) + else: + h, _ = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +######################################################################### +# These are the attention blocks as implemented by Stable Diffusion +# https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L196 + + +class CrossAttentionBlock(TextConditionedBlock): + """ + An attention block that allows spatial positions to attend to context. + In our case, context is the token-wise text embeddings. + """ + + def __init__( + self, + channels, + context_dim, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + stable_attention=True, + flash_attention=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.norm_context = normalization(context_dim) + self.flash_attention = flash_attention + # For image features + self.q = conv_nd(1, channels, channels, 1) + + # For context + self.kv = conv_nd(1, context_dim, channels * 2, 1) + + if flash_attention: + assert flash_attn_installed, "FlashAttention is not installed." + assert not stable_attention, "FlashAttention doesn't support the stable form." + elif stable_attention: + self.attention = QKVStableMaskedAttention(self.num_heads) + else: + self.attention = QKVMaskedAttention(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x, context, mask): + if self.use_checkpoint: + return checkpoint.checkpoint(self._forward, x, context, mask) + else: + return self._forward(x, context, mask) + + def _forward(self, x, context, mask): + + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + + q = self.q(self.norm(x)) + context = th.permute(context, (0, 2, 1)) + context_n = self.norm_context(context) + kv = self.kv(context_n) + k, v = kv.chunk(2, dim=1) + k = k.contiguous() + v = v.contiguous() + + if self.flash_attention: + batch_size = q.shape[0] + max_seqlen_q, max_seqlen_k = q.shape[2], k.shape[2] + q = rearrange(q, 'b (h d) s -> (b s) h d', h=self.num_heads) + mask = mask.to(th.bool) + k_unpadded = k.transpose(1, 2)[mask] + total_k = k_unpadded.shape[0] + k_unpadded = k_unpadded.view(total_k, self.num_heads, -1) + v_unpadded = v.transpose(1, 2)[mask] + v_unpadded = v_unpadded.view(total_k, self.num_heads, -1) + kv_unpadded = th.stack([k_unpadded, v_unpadded], dim=1) + cu_seqlens_q = th.arange( + 0, (batch_size + 1) * max_seqlen_q, step=max_seqlen_q, dtype=th.int32, device=q.device + ) + cu_seqlens_k = th.zeros((batch_size + 1), dtype=th.int32, device=q.device) + cu_seqlens_k[1:] = th.cumsum(mask.sum(dim=1), dim=0) + + out = flash_attn_varlen_kvpacked_func( + q, kv_unpadded, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, 0.0 + ) + h = rearrange(out, '(b s) h d -> b (h d) s', b=batch_size, h=self.num_heads) + else: + # Computing mask for cross attention + mask = mask.view(mask.shape[0], 1, mask.shape[1]) + mask = mask.repeat(1, q.shape[-1], 1) + mask = mask.to(th.bool) + + h, _ = self.attention(q, k, v, mask) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.norm = normalization(dim) + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim)) + + def forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + + h = self.norm(x) + + # Reshape so that the channel dim moves to last + # Linear function operates on the last dimension + h = th.permute(h, (0, 2, 1)) + + h = self.net(h) + + # Permute it back + h = th.permute(h, (0, 2, 1)) + + return (x + h).reshape(b, c, *spatial) + + +class StackedCrossAttentionBlock(TextConditionedBlock): + """ + An attention block that stacks self-attention and cross-attention layers + in a single block. + """ + + def __init__( + self, + channels, + context_dim, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + stable_attention=True, + flash_attention=False, + ): + super().__init__() + self.proj_in = conv_nd(2, channels, channels, 1) + self.norm = normalization(channels) + self.use_checkpoint = use_checkpoint + + self.self_attention_block = SelfAttentionBlock( + channels=channels, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_checkpoint=use_checkpoint, + stable_attention=stable_attention, + flash_attention=flash_attention, + ) + + self.cross_attention_block = CrossAttentionBlock( + channels=channels, + context_dim=context_dim, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_checkpoint=use_checkpoint, + stable_attention=stable_attention, + flash_attention=flash_attention, + ) + + self.ff = FeedForward(dim=channels, glu=True) + self.proj_out = zero_module(conv_nd(2, channels, channels, 1)) + + def forward(self, x, context, mask): + if self.use_checkpoint: + return checkpoint.checkpoint(self._forward, x, context, mask) + else: + return self._forward(x, context, mask) + + def _forward(self, x, context, mask): + + h = self.norm(x) + h = self.proj_in(h) + + h = self.self_attention_block(h) + h = self.cross_attention_block(h, context, mask) + h = self.ff(h) + + h = self.proj_out(h) + return h + x diff --git a/nemo/collections/multimodal/modules/imagen/diffusionmodules/embs.py b/nemo/collections/multimodal/modules/imagen/diffusionmodules/embs.py new file mode 100644 index 000000000000..12ba4941041e --- /dev/null +++ b/nemo/collections/multimodal/modules/imagen/diffusionmodules/embs.py @@ -0,0 +1,69 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import torch +import torch.nn as nn +from einops import rearrange + + +class LearnedSinusoidalPosEmb(nn.Module): + """ following @crowsonkb 's lead with learned sinusoidal pos emb """ + + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, dim): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x): + x = rearrange(x, 'b -> b 1') + freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + return fouriered + + +class UnLearnedSinusoidalPosEmb(nn.Module): + def __init__(self, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + super().__init__() + self.dim = dim + self.max_period = max_period + print(f'Unlearned Timestep Embedding Schedule: dim={dim}, max_period={max_period}') + + def forward(self, timesteps): + dim = self.dim + half = dim // 2 + max_period = self.max_period + dtype = timesteps.dtype + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) + args = timesteps[:, None].float() * freqs[None] + args = args.to(dtype=dtype) + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding diff --git a/nemo/collections/multimodal/modules/imagen/diffusionmodules/layers.py b/nemo/collections/multimodal/modules/imagen/diffusionmodules/layers.py new file mode 100644 index 000000000000..72e70250f0d7 --- /dev/null +++ b/nemo/collections/multimodal/modules/imagen/diffusionmodules/layers.py @@ -0,0 +1,240 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Copyright (c) 2021 OpenAI +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Brought from: +https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/nn.py + +Various utilities for neural networks. +""" + +import math + +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from apex.contrib.group_norm import GroupNorm + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels, act=""): + """ + Make a standard normalization layer. + + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm(32, channels, act=act) + + +def timestep_embedding(timesteps, dim, max_period=10000, dtype=th.float32): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to( + device=timesteps.device + ) + args = timesteps[:, None].float() * freqs[None] + args = args.to(dtype=dtype) + embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) + if dim % 2: + embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +# Native ADM nearest neighbor upsampling +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class UpsampleLearnable(nn.Module): + """ + Upsampling based on ConvTranspose2d. This is needed for bfloat support. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + + if self.dims == 2: + self.conv = nn.ConvTranspose2d(self.channels, self.out_channels, 4, 2, 1) + elif self.dims == 3: + self.conv = nn.ConvTranspose3d( + self.channels, self.out_channels, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1) + ) + else: + raise ValueError('Upsampling support only for 2D and 3D') + + def forward(self, x): + assert x.shape[1] == self.channels + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) diff --git a/nemo/collections/multimodal/modules/imagen/diffusionmodules/nets.py b/nemo/collections/multimodal/modules/imagen/diffusionmodules/nets.py new file mode 100644 index 000000000000..96b1a5dfeefc --- /dev/null +++ b/nemo/collections/multimodal/modules/imagen/diffusionmodules/nets.py @@ -0,0 +1,698 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn +import torch.nn.functional as F + +from nemo.collections.multimodal.modules.imagen.diffusionmodules.attention import SelfAttentionPooling +from nemo.collections.multimodal.modules.imagen.diffusionmodules.blocks import ( + ConditionalSequential, + DBlock, + FusedCrossAttentionBlock, + ResBlock, + StackedCrossAttentionBlock, + UBlock, +) +from nemo.collections.multimodal.modules.imagen.diffusionmodules.embs import ( + LearnedSinusoidalPosEmb, + UnLearnedSinusoidalPosEmb, +) +from nemo.collections.multimodal.modules.imagen.diffusionmodules.layers import Downsample +from nemo.collections.multimodal.modules.imagen.diffusionmodules.layers import UpsampleLearnable as Upsample +from nemo.collections.multimodal.modules.imagen.diffusionmodules.layers import linear, normalization, zero_module + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding used for Imagen Base and SR model. + + :param embed_dim: Dimension of embeddings. Also used to calculate the number of channels in ResBlock. + :param image_size: Input image size. Used to calculate where to inject attention layers in UNet. + :param channels: Input channel number, defaults to 3. + :param text_embed_dim: Dimension of conditioned text embedding. Different text encoders and different model versions have different values, defaults to 512 + :param num_res_blocks: Number of ResBlock in each level of UNet, defaults to 3. + :param channel_mult: Used with embed_dim to calculate the number of channels for each level of UNet, defaults to [1, 2, 3, 4] + :param num_attn_heads: The number of heads in the attention layer, defaults to 4. + :param per_head_channels: The number of channels per attention head, defaults to 64. + :param cond_dim: Dimension of Conditioning projections, defaults to 512. + :param attention_type: Type of attention layer, defaults to 'fused'. + :param feature_pooling_type: Type of pooling, defaults to 'attention'. + :param learned_sinu_pos_emb_dim: Dimension of learned time positional embedding. 0 for unlearned timestep embeddings. Defaults to 16 + :param attention_resolutions: List of resolutions to inject attention layers. Defaults to [8, 16, 32] + :param dropout: The rate of dropout, defaults to 0. + :param use_null_token: Whether to create a learned null token for attention, defaults to False. + :param init_conv_kernel_size: Initial Conv kernel size, defaults to 3. + :param gradient_checkpointing: Whether to use gradient checkpointing, defaults to False. + :param scale_shift_norm: Whether to use scale shift norm, defaults to False. + :param stable_attention: Whether to use numerically-stable attention calculation, defaults to True. + :param flash_attention: Whether to use flash attention calculation, defaults to False. + :param resblock_updown: Whether to use ResBlock or Downsample/Upsample, defaults to False. + :param resample_with_conv: When resblock_updown=False, whether to use conv in addition to Pooling&ConvTranspose. Defaults to True. + :param low_res_cond: Whether conditioned on low-resolution input, used for SR model. Defaults to False. + :param noise_cond_aug: Whether to add noise conditioned augmentation with low-resolution input. Defaults to False. + """ + + def __init__( + self, + embed_dim, # Dimension of embeddings. Also used to calculate the number of channels in ResBlock + image_size, # Input image size. Used to calculate where to inject attention layers in UNet + channels=3, # Input channel number + text_embed_dim=512, # Dimension of conditioned text embedding. Different text encoders and different model versions have different values + num_res_blocks=3, # Number of ResBlock in each level of UNet + channel_mult=[1, 2, 3, 4], # Used with embed_dim to calculate the number of channels for each level of UNet + num_attn_heads=4, # The number of heads in the attention layer + per_head_channels=64, # The number of channels per attention head + cond_dim=512, # Dimension of Conditioning projections + attention_type='fused', # Type of attention layer + feature_pooling_type='attention', # Type of pooling + learned_sinu_pos_emb_dim=16, # Dimension of learned time positional embedding. 0 for unlearned timestep embeddings. + attention_resolutions=[8, 16, 32], # List of resolutions to inject attention layers + dropout=False, # The rate of dropout + use_null_token=False, # Whether to create a learned null token for attention + init_conv_kernel_size=3, # Initial Conv kernel size. imagen_pytorch uses 7 + gradient_checkpointing=False, # Whether to use gradient checkpointing + scale_shift_norm=True, # Whether to use scale shift norm + stable_attention=True, # Whether to use numerically-stable attention calculation + flash_attention=False, # Whether to use flash attention calculation + resblock_updown=False, # Whether to use ResBlock or Downsample/Upsample + resample_with_conv=True, # When resblock_updown=False, whether to use conv in addition to Pooling&ConvTranspose + low_res_cond=False, + noise_cond_aug=False, + ): + super().__init__() + + # Attention Class + if attention_type == 'stacked': + attention_fn = StackedCrossAttentionBlock + elif attention_type == 'fused': + attention_fn = FusedCrossAttentionBlock + else: + raise ValueError('Attention {} not defined'.format(attention_type)) + + # Time embedding for log(snr) noise from continous version + time_embed_dim = embed_dim * 4 + assert learned_sinu_pos_emb_dim >= 0 + if learned_sinu_pos_emb_dim > 0: + sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim) + sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1 + self.time_embed = nn.Sequential( + sinu_pos_emb, + nn.Linear(sinu_pos_emb_input_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + else: + # Unlearned Time Embedding + sinu_pos_emb = UnLearnedSinusoidalPosEmb(embed_dim) + self.time_embed = nn.Sequential( + sinu_pos_emb, linear(embed_dim, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim) + ) + + # Pooling + assert feature_pooling_type == 'attention' or feature_pooling_type == 'mean' + self.feature_pooling_type = feature_pooling_type + if feature_pooling_type == 'attention': + self.attention_pooling = nn.Sequential( + SelfAttentionPooling(input_dim=text_embed_dim), + nn.LayerNorm(text_embed_dim), + nn.Linear(text_embed_dim, cond_dim), + ) + + # Context Projections + self.text_to_cond = linear(text_embed_dim, cond_dim) + self.to_text_non_attn_cond = nn.Sequential( + nn.LayerNorm(cond_dim), + nn.Linear(cond_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + # Register for Null Token + if use_null_token: + self.null_text_embedding = nn.Parameter(torch.randn(1, 1, cond_dim, dtype=self.text_to_cond.weight.dtype)) + self.use_null_token = use_null_token + + # Converting attention resolutions to downsampling factor + attention_ds = [] + attention_resolutions = sorted(attention_resolutions) + self.image_size = image_size + for res in attention_resolutions: + attention_ds.append(image_size // int(res)) + + self.low_res_cond = low_res_cond + # Low res noise conditioning augmentation + self.noise_cond_aug = noise_cond_aug + if self.noise_cond_aug: + assert ( + self.low_res_cond + ), 'noise conditioning augmentation should only be enabled when training with low-res cond' + if learned_sinu_pos_emb_dim > 0: + lowres_sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim) + lowres_sinu_pos_emb_dim = learned_sinu_pos_emb_dim + 1 + else: + lowres_sinu_pos_emb = UnLearnedSinusoidalPosEmb(embed_dim) + lowres_sinu_pos_emb_dim = embed_dim + self.lowres_time_embed = nn.Sequential( + lowres_sinu_pos_emb, + nn.Linear(lowres_sinu_pos_emb_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + # Initial Convolution + in_channels = 2 * channels if low_res_cond else channels + init_dim = embed_dim * channel_mult[0] + self.init_conv = ConditionalSequential( + nn.Conv2d(in_channels, init_dim, init_conv_kernel_size, padding=init_conv_kernel_size // 2) + ) + + if isinstance(num_res_blocks, int): + res_blocks_list = [num_res_blocks] * len(channel_mult) + else: + res_blocks_list = num_res_blocks + # UNet Init + # Downsampling Layers + # We use Conv2D for UNet + CONV_DIM = 2 + ch = init_dim + ds = 1 + self.input_blocks = nn.ModuleList([self.init_conv]) + num_input_block_channels = [ch] + for level, mult in enumerate(channel_mult): + num_res_blocks = res_blocks_list[level] + for _ in range(num_res_blocks): + out_channels = mult * embed_dim + layers = [ + ResBlock( + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=out_channels, + dims=CONV_DIM, + use_checkpoint=gradient_checkpointing, + use_scale_shift_norm=scale_shift_norm, + learnable_upsampling=True, + ) + ] + ch = out_channels + if ds in attention_ds: + layers.append( + attention_fn( + channels=ch, + num_heads=num_attn_heads, + num_head_channels=per_head_channels, + use_checkpoint=gradient_checkpointing, + stable_attention=stable_attention, + flash_attention=flash_attention, + context_dim=cond_dim, + ) + ) + self.input_blocks.append(ConditionalSequential(*layers)) + num_input_block_channels.append(ch) + is_last_level = level == len(channel_mult) - 1 + if not is_last_level: + # DownSampling + self.input_blocks.append( + ConditionalSequential( + ResBlock( + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=ch, + dims=CONV_DIM, + use_checkpoint=gradient_checkpointing, + use_scale_shift_norm=scale_shift_norm, + down=True, + learnable_upsampling=True, + ) + if resblock_updown + else Downsample(channels=ch, use_conv=resample_with_conv, dims=CONV_DIM, out_channels=ch,) + ) + ) + num_input_block_channels.append(ch) + ds *= 2 + + # Middle Layers + self.middle_block = ConditionalSequential( + # Mid Block 1 + ResBlock( + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + dims=CONV_DIM, + use_checkpoint=gradient_checkpointing, + use_scale_shift_norm=scale_shift_norm, + learnable_upsampling=True, + ), + # Attention Layer + attention_fn( + channels=ch, + num_heads=num_attn_heads, + num_head_channels=per_head_channels, + use_checkpoint=gradient_checkpointing, + stable_attention=stable_attention, + flash_attention=flash_attention, + context_dim=cond_dim, + ), + # Mid Block 2 + ResBlock( + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + dims=CONV_DIM, + use_checkpoint=gradient_checkpointing, + use_scale_shift_norm=scale_shift_norm, + learnable_upsampling=True, + ), + ) + + # Upsampling Layers + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + num_res_blocks = res_blocks_list[level] + for i in range(num_res_blocks + 1): + ich = num_input_block_channels.pop() + out_channels = embed_dim * mult + layers = [ + ResBlock( + channels=ch + ich, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=out_channels, + dims=CONV_DIM, + use_checkpoint=gradient_checkpointing, + use_scale_shift_norm=scale_shift_norm, + learnable_upsampling=True, + ) + ] + ch = out_channels + + if ds in attention_ds: + layers.append( + attention_fn( + channels=ch, + num_heads=-1, # TODO + num_head_channels=per_head_channels, + use_checkpoint=gradient_checkpointing, + stable_attention=stable_attention, + flash_attention=flash_attention, + context_dim=cond_dim, + ) + ) + is_last_block = i == num_res_blocks + if level and is_last_block: + layers.append( + ResBlock( + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=ch, + dims=CONV_DIM, + use_checkpoint=gradient_checkpointing, + use_scale_shift_norm=scale_shift_norm, + up=True, + learnable_upsampling=True, + ) + if resblock_updown + else Upsample(channels=ch, use_conv=resample_with_conv, dims=CONV_DIM, out_channels=ch) + ) + ds //= 2 + self.output_blocks.append(ConditionalSequential(*layers)) + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(nn.Conv2d(init_dim, channels, init_conv_kernel_size, padding=init_conv_kernel_size // 2)), + ) + + def forward( + self, x, time, text_embed=None, text_mask=None, x_low_res=None, time_low_res=None, + ): + if self.low_res_cond: + assert x_low_res is not None, 'x_low_res cannot be None' + else: + assert x_low_res is None, 'x_low_res cannot be presented' + if self.noise_cond_aug: + assert time_low_res is not None, 'time_low_res cannot be None when training with noise conditioning aug' + else: + assert time_low_res is None, 'time_low_res cannot be presented' + # Concatenating low resolution images + if x_low_res is not None: + if x_low_res.shape != x.shape: + # Upscale if not done in the trainer + _, _, new_height, new_width = x.shape + x_low_res = F.interpolate(x_low_res, (new_height, new_width), mode="bicubic") + x = torch.cat([x, x_low_res], dim=1) + batch_size, device = x.shape[0], x.device + + if x.dtype != time.dtype or time.dtype != text_embed.dtype: + dtype = text_embed.dtype + x = x.to(dtype=dtype) + time = time.to(dtype=dtype) + if x_low_res is not None: + x_low_res = x_low_res.to(dtype=dtype) + if time_low_res is not None: + time_low_res = time_low_res.to(dtype=dtype) + # Time Conditioning + t = self.time_embed(time) + # Add lowres time conditioning + if self.noise_cond_aug: + lowres_t = self.lowres_time_embed(time_low_res) + t += lowres_t + # Text Conditioning + text_cond = self.text_to_cond(text_embed) + + # Context Embedding + # TODO We may want to concat time token here + if self.use_null_token: + # Null Context (Helpful when text_embed is drop) + null_context = self.null_text_embedding.repeat(batch_size, 1, 1) + context_emb = torch.cat([text_cond, null_context], dim=1) + context_mask = torch.cat([text_mask, torch.ones(batch_size, 1).to(device)], dim=1) + else: + context_emb = text_cond + context_mask = text_mask + + # Add pooled text embeddings to the diffusion timestep + # TODO We may only want to calculated the pooled feature based on text token length + if self.feature_pooling_type == 'mean': + pooled_text_cond = text_cond.mean(dim=-2) + elif self.feature_pooling_type == 'attention': + pooled_text_cond = self.attention_pooling(text_embed) + text_hiddens = self.to_text_non_attn_cond(pooled_text_cond) + t += text_hiddens + + h = x + hs = [] + # UNet Forward + for module in self.input_blocks: + h = module(h, t, context_emb, context_mask) + hs.append(h) + h = self.middle_block(h, t, context_emb, context_mask) + for module in self.output_blocks: + h_prev = hs.pop() + h = torch.cat([h, h_prev], dim=1) + h = module(h, t, context_emb, context_mask) + return self.out(h) + + def forward_with_cond_scale(self, *args, text_embed=None, cond_scale=1.0, **kwargs): + logits = self.forward(*args, text_embed=text_embed, **kwargs) + if cond_scale == 1.0: + return logits + null_logits = self.forward(*args, text_embed=torch.zeros_like(text_embed), **kwargs) + return null_logits + (logits - null_logits) * cond_scale + + +class EfficientUNetModel(nn.Module): + """ + The full Efficient UNet model with attention and timestep embedding used for Imagen SR model. + + :param embed_dim: Dimension of embeddings. Also used to calculate the number of channels in ResBlock. + :param image_size: Input image size. Used to calculate where to inject attention layers in UNet. + :param channels: Input channel number, defaults to 3. + :param text_embed_dim: Dimension of conditioned text embedding. Different text encoders and different model versions have different values, defaults to 512 + :param channel_mult: Used with embed_dim to calculate the number of channels for each level of UNet, defaults to [1, 1, 2, 4, 8]. + :param num_attn_heads: The number of heads in the attention layer, defaults to 8. + :param per_head_channels: The number of channels per attention head, defaults to 64. + :param attention_type: Type of attention layer, defaults to 'fused'. + :param atnn_enabled_at: Whether to enable attention at each level, defaults to [0, 0, 0, 0, 1]. + :param feature_pooling_type: Type of pooling, defaults to 'attention'. + :param stride: Stride in ResBlock, defaults to 2. + :param num_resblocks: Used with num_res_blocks to calculate the number of residual blocks at each level of Efficient-UNet. Defaults to [1, 2, 4, 8, 8]. + :param learned_sinu_pos_emb_dim: Dimension of learned time positional embedding. 0 for unlearned timestep embeddings. Defaults to 16 + :param use_null_token: Whether to create a learned null token for attention, defaults to False. + :param init_conv_kernel_size: Initial Conv kernel size, defaults to 3. + :param gradient_checkpointing: Whether to use gradient checkpointing, defaults to False. + :param scale_shift_norm: Whether to use scale shift norm, defaults to False. + :param stable_attention: Whether to use numerically-stable attention calculation, defaults to True. + :param flash_attention: Whether to use flash attention calculation, defaults to False. + :param skip_connection_scaling: Whether to use 1/sqrt(2) scaling for ResBlock skip connection, defaults to False. + :param noise_cond_aug: Whether to add noise conditioned augmentation with low-resolution input. Defaults to False. + """ + + def __init__( + self, + embed_dim, + image_size, + channels=3, + text_embed_dim=512, # Dimension of conditioned text embedding. Different text encoders and different model versions have different values + channel_mult=[ + 1, + 1, + 2, + 4, + 8, + ], # Used with embed_dim to calculate the number of channels for each level of Efficient-UNet + num_attn_heads=8, # The number of heads in the attention layer + per_head_channels=64, # The number of channels per attention head + attention_type='fused', # Type of attention layer + atnn_enabled_at=[0, 0, 0, 0, 1], # Whether to enable attention at each level + feature_pooling_type='attention', # Type of pooling + stride=2, # Stride in ResBlock + num_resblocks=[ + 1, + 2, + 4, + 8, + 8, + ], # Used with num_res_blocks to calculate the number of residual blocks at each level of Efficient-UNet + learned_sinu_pos_emb_dim=16, # Dimension of learned time positional embedding. 0 for unlearned timestep embeddings. + use_null_token=False, # Whether to create a learned null token for attention + init_conv_kernel_size=3, # Initial Conv kernel size. imagen_pytorch uses 7 + gradient_checkpointing=False, # Whether to use gradient checkpointing + scale_shift_norm=True, # Whether to use scale shift norm + stable_attention=True, # Whether to use numerically-stable attention calculation + flash_attention=False, # Whether to use flash attention calculation + skip_connection_scaling=False, # Whether to use 1/sqrt(2) scaling for ResBlock skip connection + noise_cond_aug=False, + ): + + super().__init__() + + self.n_levels = len(channel_mult) + self.image_size = image_size + # Time embedding for log(snr) noise from continous version + time_embed_dim = embed_dim * 4 + assert learned_sinu_pos_emb_dim >= 0 + if learned_sinu_pos_emb_dim > 0: + sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim) + sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1 + self.time_embed = nn.Sequential( + sinu_pos_emb, + nn.Linear(sinu_pos_emb_input_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + else: + # Unlearned Time Embedding + sinu_pos_emb = UnLearnedSinusoidalPosEmb(embed_dim) + self.time_embed = nn.Sequential( + sinu_pos_emb, linear(embed_dim, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim) + ) + + self.noise_cond_aug = noise_cond_aug + if self.noise_cond_aug: + if learned_sinu_pos_emb_dim > 0: + lowres_sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim) + lowres_sinu_pos_emb_dim = learned_sinu_pos_emb_dim + 1 + else: + lowres_sinu_pos_emb = UnLearnedSinusoidalPosEmb(embed_dim) + lowres_sinu_pos_emb_dim = embed_dim + self.lowres_time_embed = nn.Sequential( + lowres_sinu_pos_emb, + nn.Linear(lowres_sinu_pos_emb_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + cond_dim = text_embed_dim # time_embed_dim + # Pooling + assert feature_pooling_type == 'attention' or feature_pooling_type == 'mean' + self.feature_pooling_type = feature_pooling_type + if feature_pooling_type == 'attention': + self.attention_pooling = nn.Sequential( + SelfAttentionPooling(input_dim=text_embed_dim), + nn.LayerNorm(text_embed_dim), + nn.Linear(text_embed_dim, cond_dim), + ) + + # Context Projections + self.text_to_cond = linear(text_embed_dim, cond_dim) + self.to_text_non_attn_cond = nn.Sequential( + nn.LayerNorm(cond_dim), + nn.Linear(cond_dim, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + # Register for Null Token + if use_null_token: + self.null_text_embedding = nn.Parameter(torch.randn(1, 1, cond_dim, dtype=self.text_to_cond.weight.dtype)) + self.use_null_token = use_null_token + + # Initial Convolution + # Multiply in_channels by 2 because we concatenate with low res inputs. + in_channels = channels * 2 + init_dim = embed_dim * channel_mult[0] + self.init_conv = nn.Conv2d(in_channels, init_dim, init_conv_kernel_size, padding=init_conv_kernel_size // 2) + # Efficient-UNet Init + self.DBlocks = nn.ModuleDict() + self.UBlocks = nn.ModuleDict() + ch = init_dim + for level, mult in enumerate(channel_mult): + # Different level has different num of res blocks + num_resblock = num_resblocks[level] + # Only perform upsample/downsample if it is not the last (deepest) level + is_last_level = level == len(channel_mult) - 1 + level_attention_type = attention_type if atnn_enabled_at[level] else None + + level_key = str(level) # TODO Change to more meaningful naming + self.DBlocks[level_key] = DBlock( + channels=ch, + emb_channels=time_embed_dim, + out_channels=int(mult * embed_dim), + use_scale_shift_norm=scale_shift_norm, + conv_down=not is_last_level, + stride=stride, + num_resblocks=num_resblock, + attention_type=level_attention_type, + text_embed_dim=cond_dim, + num_heads=num_attn_heads, + num_head_channels=per_head_channels, + use_checkpoint=gradient_checkpointing, + stable_attention=stable_attention, + flash_attention=flash_attention, + skip_connection_scaling=skip_connection_scaling, + ) + self.UBlocks[level_key] = UBlock( + channels=int(mult * embed_dim), + emb_channels=time_embed_dim, + out_channels=ch, + use_scale_shift_norm=scale_shift_norm, + conv_up=not is_last_level, + stride=stride, + num_resblocks=num_resblock, + attention_type=level_attention_type, + text_embed_dim=cond_dim, + num_heads=num_attn_heads, + num_head_channels=per_head_channels, + use_checkpoint=gradient_checkpointing, + stable_attention=stable_attention, + flash_attention=flash_attention, + skip_connection_scaling=skip_connection_scaling, + ) + ch = int(mult * embed_dim) + self.out = nn.Conv2d(channel_mult[0] * embed_dim, channels, 1) + + def forward( + self, x, time, text_embed, text_mask, x_low_res, time_low_res=None, + ): + if self.noise_cond_aug: + assert time_low_res is not None, 'time_low_res cannot be None when training with noise conditioning aug' + else: + assert time_low_res is None, 'time_low_res cannot be presented' + + if x.dtype != time.dtype or time.dtype != text_embed.dtype: + dtype = text_embed.dtype + x = x.to(dtype=dtype) + time = time.to(dtype=dtype) + if x_low_res is not None: + x_low_res = x_low_res.to(dtype=dtype) + if time_low_res is not None: + time_low_res = time_low_res.to(dtype=dtype) + + batch_size, device = x.shape[0], x.device + # Time Conditioning + t = self.time_embed(time) + # Text Conditioning + text_cond = self.text_to_cond(text_embed) + # Concatenating low resolution images + if x_low_res.shape != x.shape: + # Upscale if not done in the trainer + _, _, new_height, new_width = x.shape + x_low_res = F.interpolate(x_low_res, (new_height, new_width), mode="bicubic") + x = torch.cat([x, x_low_res], dim=1) + + # Add lowres time conditioning + if self.noise_cond_aug: + lowres_t = self.lowres_time_embed(time_low_res) + t += lowres_t + # Context Embedding + # TODO We may want to concat time token here + if self.use_null_token: + # Null Context (Helpful when text_embed is drop) + null_context = self.null_text_embedding.repeat(batch_size, 1, 1) + context_emb = torch.cat([text_cond, null_context], dim=1) + context_mask = torch.cat([text_mask, torch.ones(batch_size, 1).to(device)], dim=1) + else: + context_emb = text_cond + context_mask = text_mask + + # Add pooled text embeddings to the diffusion timestep + # TODO We may only want to calculated the pooled feature based on text token length + if self.feature_pooling_type == 'mean': + pooled_text_cond = text_cond.mean(dim=-2) + elif self.feature_pooling_type == 'attention': + pooled_text_cond = self.attention_pooling(text_embed) + text_hiddens = self.to_text_non_attn_cond(pooled_text_cond) + t += text_hiddens + + # UNet forward + x = self.init_conv(x) + feats = dict() + for level in range(self.n_levels): + level_key = str(level) + x = self.DBlocks[level_key](x, t, context_emb, context_mask) + # Save feats for UBlocks + if level < self.n_levels - 1: + feats[level_key] = x + for level in range(self.n_levels - 1, -1, -1): + level_key = str(level) + if level < self.n_levels - 1: + x = x + feats[level_key] + x = self.UBlocks[level_key](x, t, context_emb, context_mask) + return self.out(x) + + def forward_with_cond_scale(self, *args, text_embed=None, cond_scale=1.0, **kwargs): + logits = self.forward(*args, text_embed=text_embed, **kwargs) + if cond_scale == 1.0: + return logits + null_logits = self.forward(*args, text_embed=torch.zeros_like(text_embed), **kwargs) + return null_logits + (logits - null_logits) * cond_scale + + +if __name__ == '__main__': + model = UNetModel(embed_dim=512, image_size=64,) + + pytorch_total_params = sum(p.numel() for p in model.parameters()) + print(pytorch_total_params) + + image_batch = torch.rand(4, 3, 64, 64) + text_cond = torch.rand(4, 88, 512) + text_mask = torch.ones(4, 88) + time = torch.ones(4) + + output = model(image_batch, time, text_cond, text_mask,) + + print(output.shape) + + model_sr = EfficientUNetModel(embed_dim=128, image_size=256) + pytorch_total_params = sum(p.numel() for p in model_sr.parameters()) + print(pytorch_total_params) + output = model_sr( + torch.randn(4, 3, 256, 256), + torch.randn(4, 3, 256, 256), + torch.ones(4), + torch.randn(4, 88, 512), + torch.ones(4, 88), + ) + print(output.shape) diff --git a/nemo/collections/multimodal/modules/imagen/encoder/__init__.py b/nemo/collections/multimodal/modules/imagen/encoder/__init__.py new file mode 100644 index 000000000000..aee951313044 --- /dev/null +++ b/nemo/collections/multimodal/modules/imagen/encoder/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.package_info import __version__ + +# Set collection version equal to NeMo version. +__version = __version__ + +# Authorship. +__author__ = "NVIDIA Corporation" + +# Set collection name. +__description__ = "Speech Computer Vision collection" diff --git a/nemo/collections/multimodal/modules/imagen/encoder/t5encoder.json b/nemo/collections/multimodal/modules/imagen/encoder/t5encoder.json new file mode 100644 index 000000000000..3fb4ffdac7f1 --- /dev/null +++ b/nemo/collections/multimodal/modules/imagen/encoder/t5encoder.json @@ -0,0 +1,51 @@ +{ + "architectures": [ + "T5WithLMHeadModel" + ], + "d_ff": 65536, + "d_kv": 128, + "d_model": 1024, + "decoder_start_token_id": 0, + "dropout_rate": 0.1, + "eos_token_id": 1, + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "n_positions": 512, + "num_heads": 128, + "num_layers": 24, + "output_past": true, + "pad_token_id": 0, + "relative_attention_num_buckets": 32, + "task_specific_params": { + "summarization": { + "early_stopping": true, + "length_penalty": 2.0, + "max_length": 200, + "min_length": 30, + "no_repeat_ngram_size": 3, + "num_beams": 4, + "prefix": "summarize: " + }, + "translation_en_to_de": { + "early_stopping": true, + "max_length": 300, + "num_beams": 4, + "prefix": "translate English to German: " + }, + "translation_en_to_fr": { + "early_stopping": true, + "max_length": 300, + "num_beams": 4, + "prefix": "translate English to French: " + }, + "translation_en_to_ro": { + "early_stopping": true, + "max_length": 300, + "num_beams": 4, + "prefix": "translate English to Romanian: " + } + }, + "vocab_size": 32128 +} diff --git a/nemo/collections/multimodal/modules/imagen/encoder/t5encoder.py b/nemo/collections/multimodal/modules/imagen/encoder/t5encoder.py new file mode 100644 index 000000000000..c660bc0c54f3 --- /dev/null +++ b/nemo/collections/multimodal/modules/imagen/encoder/t5encoder.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import torch +from transformers import T5Config, T5EncoderModel, T5Tokenizer + + +class T5Encoder(torch.nn.Module): + def __init__(self, max_seq_len=512, encoder_path=None): + """ + Initialize the T5 Encoder. + + :param max_seq_len: Maximum token length, defaults to 512 + :param encoder_path: Optional if loaded T5 on the disk, defaults to None + """ + super().__init__() + self.max_seq_len = max_seq_len + + self.model_seq_len = 512 + # Initializing T5 model + self.tokenizer = T5Tokenizer.from_pretrained("t5-11b", model_max_length=self.model_seq_len) + + if encoder_path is None: + self.model = T5EncoderModel.from_pretrained("t5-11b", low_cpu_mem_usage=True) + else: + print(f'Load T5 encoder from {encoder_path}') + hard_coded_encoder_weight_location = os.path.join(encoder_path, "t5xxl-encoder.bin") + hard_coded_encoder_config_location = os.path.join(os.path.dirname(__file__), "t5encoder.json") + self.model = T5EncoderModel.from_pretrained( + hard_coded_encoder_weight_location, + config=T5Config.from_json_file(hard_coded_encoder_config_location), + low_cpu_mem_usage=True, + ) + + def encode(self, text_batch, device='cuda'): + ''' + Encode a batch of text to T5 embeddings. + ''' + encoded = self.tokenizer.batch_encode_plus( + text_batch, return_tensors="pt", padding="max_length", max_length=self.model_seq_len, truncation=True + ) + # We expect all the processing is done in GPU. + input_ids = encoded.input_ids.to(device=device) + attn_mask = encoded.attention_mask.to(device=device) + + with torch.no_grad(): + output = self.model(input_ids=input_ids, attention_mask=attn_mask) + encoded_text = output.last_hidden_state.detach() + + encoded_text = encoded_text[:, 0 : self.max_seq_len] + attn_mask = attn_mask[:, 0 : self.max_seq_len] + for bnum in range(encoded_text.shape[0]): + nvalid_elem = attn_mask[bnum].sum().item() + encoded_text[bnum][nvalid_elem:] = 0 + + return encoded_text, attn_mask diff --git a/nemo/collections/multimodal/modules/imagen/sampler/__init__.py b/nemo/collections/multimodal/modules/imagen/sampler/__init__.py new file mode 100644 index 000000000000..aee951313044 --- /dev/null +++ b/nemo/collections/multimodal/modules/imagen/sampler/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.package_info import __version__ + +# Set collection version equal to NeMo version. +__version = __version__ + +# Authorship. +__author__ = "NVIDIA Corporation" + +# Set collection name. +__description__ = "Speech Computer Vision collection" diff --git a/nemo/collections/multimodal/modules/imagen/sampler/batch_ops.py b/nemo/collections/multimodal/modules/imagen/sampler/batch_ops.py new file mode 100644 index 000000000000..029bbf60ffbc --- /dev/null +++ b/nemo/collections/multimodal/modules/imagen/sampler/batch_ops.py @@ -0,0 +1,57 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Functions for performing operations with broadcasting to the right axis +# +# Example +# input1: tensor of size (N1, N2) +# input2: tensor of size (N1, N2, N3, N4) +# batch_mul(input1, input2) = input1[:, :, None, None] * input2 +# +# If the common dimensions don't match, we raise an assertion error. + + +def common_broadcast(x, y): + ndims1 = x.ndim + ndims2 = y.ndim + + common_ndims = min(ndims1, ndims2) + for axis in range(common_ndims): + assert x.shape[axis] == y.shape[axis], 'Dimensions not equal at axis {}'.format(axis) + + if ndims1 < ndims2: + x = x.reshape(x.shape + (1,) * (ndims2 - ndims1)) + elif ndims2 < ndims1: + y = y.reshape(y.shape + (1,) * (ndims1 - ndims2)) + + return x, y + + +def batch_add(x, y): + x, y = common_broadcast(x, y) + return x + y + + +def batch_mul(x, y): + x, y = common_broadcast(x, y) + return x * y + + +def batch_sub(x, y): + x, y = common_broadcast(x, y) + return x - y + + +def batch_div(x, y): + x, y = common_broadcast(x, y) + return x / y diff --git a/nemo/collections/multimodal/modules/imagen/sampler/continuous_ddpm.py b/nemo/collections/multimodal/modules/imagen/sampler/continuous_ddpm.py new file mode 100644 index 000000000000..2b48f28ce9c9 --- /dev/null +++ b/nemo/collections/multimodal/modules/imagen/sampler/continuous_ddpm.py @@ -0,0 +1,168 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from functools import partial, wraps + +import torch +import torch.nn as nn +from einops import repeat +from torch.special import expm1 + +from nemo.collections.multimodal.parts.utils import randn_like + + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +def maybe(fn): + @wraps(fn) + def inner(x): + if not exists(x): + return x + return fn(x) + + return inner + + +def log(t, eps: float = 1e-12): + return torch.log(t.clamp(min=eps)) + + +def right_pad_dims_to(x, t): + padding_dims = x.ndim - t.ndim + if padding_dims <= 0: + return t + return t.view(*t.shape, *((1,) * padding_dims)) + + +@torch.jit.script +def beta_linear_log_snr(t): + return -torch.log(expm1(1e-4 + 10 * (t ** 2))) + + +@torch.jit.script +def alpha_cosine_log_snr(t, s: float = 0.008): + return -log( + (torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps=1e-5 + ) # not sure if this accounts for beta being clipped to 0.999 in discrete version + + +def log_snr_to_alpha_sigma(log_snr): + return torch.sqrt(torch.sigmoid(log_snr)), torch.sqrt(torch.sigmoid(-log_snr)) + + +class GaussianDiffusionContinuousTimes(nn.Module): + def __init__(self, *, noise_schedule, timesteps=1000, rng=None): + super().__init__() + + if noise_schedule == "linear": + self.log_snr = beta_linear_log_snr + elif noise_schedule == "cosine": + self.log_snr = alpha_cosine_log_snr + else: + raise ValueError(f'invalid noise schedule {noise_schedule}') + + self.num_timesteps = timesteps + self.rng = rng + + def get_times(self, batch_size, noise_level, *, device): + return torch.full((batch_size,), noise_level, device=device, dtype=torch.float32) + + def sample_random_times(self, batch_size, *, device): + return torch.rand((batch_size,), device=device, generator=self.rng, dtype=torch.float32) + + def get_condition(self, times): + return maybe(self.log_snr)(times) + + def get_sampling_timesteps(self, batch, *, device): + times = torch.linspace(1.0, 0.0, self.num_timesteps + 1, device=device) + times = repeat(times, 't -> b t', b=batch) + times = torch.stack((times[:, :-1], times[:, 1:]), dim=0) + times = times.unbind(dim=-1) + return times + + def q_posterior(self, x_start, x_t, t, *, t_next=None): + t_next = default(t_next, lambda: (t - 1.0 / self.num_timesteps).clamp(min=0.0)) + + """ https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material """ + log_snr = self.log_snr(t) + log_snr_next = self.log_snr(t_next) + log_snr, log_snr_next = map(partial(right_pad_dims_to, x_t), (log_snr, log_snr_next)) + + alpha, sigma = log_snr_to_alpha_sigma(log_snr) + alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next) + + # c - as defined near eq 33 + c = -expm1(log_snr - log_snr_next) + posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start) + + # following (eq. 33) + posterior_variance = (sigma_next ** 2) * c + posterior_log_variance_clipped = log(posterior_variance, eps=1e-20) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def q_sample(self, x_start, t, noise=None): + dtype = x_start.dtype + + if isinstance(t, float): + batch = x_start.shape[0] + t = torch.full((batch,), t, device=x_start.device, dtype=dtype) + + noise = default(noise, lambda: randn_like(x_start, generator=self.rng)) + log_snr = self.log_snr(t).type(dtype) + log_snr_padded_dim = right_pad_dims_to(x_start, log_snr) + alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim) + + return alpha * x_start + sigma * noise, log_snr, alpha, sigma + + def q_sample_from_to(self, x_from, from_t, to_t, noise=None): + shape, device, dtype = x_from.shape, x_from.device, x_from.dtype + batch = shape[0] + + if isinstance(from_t, float): + from_t = torch.full((batch,), from_t, device=device, dtype=dtype) + + if isinstance(to_t, float): + to_t = torch.full((batch,), to_t, device=device, dtype=dtype) + + noise = default(noise, lambda: randn_like(x_from, generator=self.rng)) + + log_snr = self.log_snr(from_t) + log_snr_padded_dim = right_pad_dims_to(x_from, log_snr) + alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim) + + log_snr_to = self.log_snr(to_t) + log_snr_padded_dim_to = right_pad_dims_to(x_from, log_snr_to) + alpha_to, sigma_to = log_snr_to_alpha_sigma(log_snr_padded_dim_to) + + return x_from * (alpha_to / alpha) + noise * (sigma_to * alpha - sigma * alpha_to) / alpha + + def predict_start_from_v(self, x_t, t, v): + log_snr = self.log_snr(t) + log_snr = right_pad_dims_to(x_t, log_snr) + alpha, sigma = log_snr_to_alpha_sigma(log_snr) + return alpha * x_t - sigma * v + + def predict_start_from_noise(self, x_t, t, noise): + log_snr = self.log_snr(t) + log_snr = right_pad_dims_to(x_t, log_snr) + alpha, sigma = log_snr_to_alpha_sigma(log_snr) + return (x_t - sigma * noise) / alpha.clamp(min=1e-8) diff --git a/nemo/collections/multimodal/modules/imagen/sampler/sampler.py b/nemo/collections/multimodal/modules/imagen/sampler/sampler.py new file mode 100644 index 000000000000..2fd05faf814d --- /dev/null +++ b/nemo/collections/multimodal/modules/imagen/sampler/sampler.py @@ -0,0 +1,250 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import torch +from einops import rearrange +from tqdm import tqdm + +from nemo.collections.multimodal.modules.imagen.sampler.batch_ops import batch_div, batch_mul +from nemo.collections.multimodal.modules.imagen.sampler.continuous_ddpm import GaussianDiffusionContinuousTimes + + +def right_pad_dims_to(x, t): + padding_dims = x.ndim - t.ndim + if padding_dims <= 0: + return t + return t.view(*t.shape, *((1,) * padding_dims)) + + +def thresholding_x0(x0, method='dynamic', th=0.995): + if method is None: + return x0 + elif method == 'static': + return x0.clamp(-1.0, 1.0) + elif method == 'dynamic': + # torch.quantile only suppoprt either float or double dtype + # we need to manual cast it if running in FP16/AMP mode + original_dtype = x0.dtype + if original_dtype not in [torch.float, torch.double]: + x0 = x0.float() + s = torch.quantile(rearrange(x0, 'b ... -> b (...)').abs(), th, dim=-1) # From Figure A.10 (b) + s.clamp_(min=1.0) + s = right_pad_dims_to(x0, s) + x0 = x0.clamp(-s, s) / s + return x0.type(original_dtype) + else: + raise RuntimeError(f'Thresholding method: {method} not supported.') + + +def thresholding_derivative(x, t, d, thresholding_method='dynamic'): + x0 = x - batch_mul(d, t) + corrected_x0 = thresholding_x0(x0, thresholding_method) + corrected_d = batch_div(x - corrected_x0, t) + return corrected_d + + +class Sampler(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, model, model_kwargs, shape, z=None): + pass + + +class DDPMSampler(Sampler): + def __init__(self, unet_type, denoiser): + super().__init__() + self.unet_type = unet_type + self.noise_scheduler = denoiser + self.pred_objective = 'noise' + + def p_mean_variance( + self, unet, x, t, t_next, text_embeds, text_mask, x_low_res=None, cond_scale=1.0, thresholding_method='dynamic' + ): + + if self.unet_type == 'base': + pred = unet.forward_with_cond_scale( + x=x, time=t, text_embed=text_embeds, text_mask=text_mask, cond_scale=cond_scale + ) + elif self.unet_type == 'sr': + pred = unet.forward_with_cond_scale( + x=x, x_low_res=x_low_res, time=t, text_embed=text_embeds, text_mask=text_mask, cond_scale=cond_scale + ) + + if self.pred_objective == 'noise': + x_start = self.noise_scheduler.predict_start_from_noise(x, t=t, noise=pred) + elif self.pred_objective == 'x_start': + x_start = pred + elif self.pred_objective == 'v': + x_start = self.noise_scheduler.predict_start_from_v(x, t=t, v=pred) + else: + raise ValueError(f'unknown objective {self.pred_objective}') + + x_start = thresholding_x0(x_start, method=thresholding_method) + mean_and_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t, t_next=t_next) + return mean_and_variance, x_start + + @torch.no_grad() + def p_sample( + self, unet, x, t, t_next, text_embeds, text_mask, x_low_res=None, cond_scale=1.0, thresholding_method='dynamic' + ): + (model_mean, _, model_log_variance), x_start = self.p_mean_variance( + unet=unet, + x=x, + t=t, + t_next=t_next, + text_embeds=text_embeds, + text_mask=text_mask, + cond_scale=cond_scale, + x_low_res=x_low_res, + thresholding_method=thresholding_method, + ) + noise = torch.randn_like(x) + # no noise when t == 0 + b = x.shape[0] + is_last_sampling_timestep = ( + (t_next == 0) if isinstance(self.noise_scheduler, GaussianDiffusionContinuousTimes) else (t == 0) + ) + nonzero_mask = (1 - is_last_sampling_timestep.type_as(x)).reshape(b, *((1,) * (len(x.shape) - 1))) + pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + return pred, x_start + + def forward( + self, + model, + noise_map, + text_encoding, + text_mask, + x_low_res=None, + cond_scale=1.0, + sampling_steps=None, + thresholding_method='dynamic', + ): + batch = noise_map.shape[0] + device = noise_map.device + dtype = noise_map.dtype + original_steps = self.noise_scheduler.num_timesteps + if sampling_steps: + self.noise_scheduler.num_timesteps = sampling_steps + timesteps = self.noise_scheduler.get_sampling_timesteps(batch, device=device) + img = noise_map + for times, times_next in tqdm(timesteps, total=len(timesteps)): + img, x_start = self.p_sample( + unet=model, + x=img.type(dtype), + t=times.type(dtype), + t_next=times_next.type(dtype), + text_embeds=text_encoding, + text_mask=text_mask, + cond_scale=cond_scale, + x_low_res=x_low_res.type(dtype) if x_low_res is not None else None, + thresholding_method=thresholding_method, + ) + self.noise_scheduler.num_timesteps = original_steps + return img + + +class EDMSampler(Sampler): + def __init__( + self, + unet_type, + num_steps=50, + sigma_min=0.002, + sigma_max=80, + rho=7, + S_churn=0, + S_min=0, + S_max=float('inf'), + S_noise=1, + ): + super().__init__() + self.unet_type = unet_type + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.rho = rho + self.S_churn = S_churn + self.S_min = S_min + self.S_max = S_max + self.S_noise = S_noise + self.num_steps = num_steps + + def forward( + self, + unet, + noise_map, + text_encoding, + text_mask, + x_low_res=None, + cond_scale=1.0, + sampling_steps=None, + thresholding_method='dynamic', + ): + if self.unet_type == 'base': + assert x_low_res is None + elif self.unet_type == 'sr': + assert x_low_res is not None + low_res_cond = {'x_low_res': x_low_res} if x_low_res is not None else {} + thresholding_method = 'dynamic' + sigma_min = self.sigma_min + sigma_max = self.sigma_max + print(f'Sampling with sigma in [{sigma_min}, {sigma_max}], cfg={cond_scale}') + # Time step discretization + num_steps = sampling_steps if sampling_steps else self.num_steps + step_indices = torch.arange(num_steps, device=noise_map.device) + # Table 1: Sampling - Time steps + t_steps = ( + sigma_max ** (1 / self.rho) + + step_indices / (num_steps - 1) * (sigma_min ** (1 / self.rho) - sigma_max ** (1 / self.rho)) + ) ** self.rho + t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0 + + # Main sampling loop. + x_next = noise_map * t_steps[0] + for i, (t_cur, t_next) in tqdm( + enumerate(zip(t_steps[:-1], t_steps[1:])), total=len(t_steps[:-1]) + ): # 0, ..., N-1 + x_cur = x_next + + # Increase noise temporarily. + gamma = min(self.S_churn / num_steps, np.sqrt(2) - 1) if self.S_min <= t_cur <= self.S_max else 0 + t_hat = (t_cur + gamma * t_cur).to(x_cur.device) + x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * self.S_noise * torch.randn_like(x_cur) + + # Euler step. + denoised = unet.forward_with_cond_scale( + x=x_hat.to(torch.float32), + time=t_hat.to(torch.float32), + text_embed=text_encoding, + text_mask=text_mask, + cond_scale=cond_scale, + **low_res_cond, + ) + d_cur = (x_hat - denoised) / t_hat + d_cur = thresholding_derivative(x_hat, t_hat, d_cur, thresholding_method=thresholding_method) + x_next = x_hat + (t_next - t_hat) * d_cur + + # Apply 2nd order correction. + if i < num_steps - 1: + denoised = unet.forward_with_cond_scale( + x=x_next.to(torch.float32), + time=t_next.to(torch.float32), + text_embed=text_encoding, + text_mask=text_mask, + cond_scale=cond_scale, + **low_res_cond, + ) + d_prime = (x_next - denoised) / t_next + d_prime = thresholding_derivative(x_next, t_next, d_prime, thresholding_method=thresholding_method) + x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) + return x_next diff --git a/nemo/collections/multimodal/modules/nerf/__init__.py b/nemo/collections/multimodal/modules/nerf/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/modules/nerf/background/__init__.py b/nemo/collections/multimodal/modules/nerf/background/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/background/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/modules/nerf/background/nerf_background_base.py b/nemo/collections/multimodal/modules/nerf/background/nerf_background_base.py new file mode 100644 index 000000000000..90b98d083b19 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/background/nerf_background_base.py @@ -0,0 +1,35 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn + +# TODO(ahmadki): abstract class +class NeRFBackgroundBase(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, rays_d: torch.Tensor) -> torch.Tensor: + """ + positions = [B*N, 3] + """ + raise NotImplementedError + + def forward_net(self, rays_d_encoding: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def forward(self, rays_d: torch.Tensor) -> torch.Tensor: + rays_d_encoding = self.encode(rays_d) + features = self.forward_net(rays_d_encoding) + features = torch.sigmoid(features) + return features diff --git a/nemo/collections/multimodal/modules/nerf/background/random_background.py b/nemo/collections/multimodal/modules/nerf/background/random_background.py new file mode 100644 index 000000000000..2b725f6b7ffa --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/background/random_background.py @@ -0,0 +1,32 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random +from typing import Tuple + +import torch +import torch.nn as nn + + +class RandomBackground(nn.Module): + def __init__(self, base_background: Tuple, random_ratio: float) -> None: + super().__init__() + self.random_ratio = random_ratio + self.num_output_dims = len(base_background) + self.register_buffer("base_background", torch.tensor(base_background)) + + def forward(self, rays_d: torch.Tensor) -> torch.Tensor: + if random.random() < self.random_ratio: + return torch.rand(rays_d.shape[0], self.num_output_dims).to(rays_d) + else: + return self.base_background.to(rays_d).expand(rays_d.shape[0], -1) diff --git a/nemo/collections/multimodal/modules/nerf/background/static_background.py b/nemo/collections/multimodal/modules/nerf/background/static_background.py new file mode 100644 index 000000000000..a8ac33c61940 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/background/static_background.py @@ -0,0 +1,27 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple + +import torch +import torch.nn as nn + + +class StaticBackground(nn.Module): + def __init__(self, background: Tuple) -> None: + super().__init__() + self.register_buffer("background", torch.tensor(background)) + + def forward(self, rays_d: torch.Tensor) -> torch.Tensor: + background = self.background.to(rays_d) + return background.expand(rays_d.shape[0], -1) diff --git a/nemo/collections/multimodal/modules/nerf/background/tcnn_background.py b/nemo/collections/multimodal/modules/nerf/background/tcnn_background.py new file mode 100644 index 000000000000..8ffc62303a6b --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/background/tcnn_background.py @@ -0,0 +1,45 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict + +import numpy as np +import tinycudann as tcnn +import torch + +from nemo.collections.multimodal.modules.nerf.background.nerf_background_base import NeRFBackgroundBase + + +class TCNNBackground(NeRFBackgroundBase): + def __init__( + self, + bound: int, + encoder_num_input_dims: int, + encoder_cfg: Dict, + background_net_num_output_dims: int, + background_net_cfg: Dict, + ): + super().__init__() + self.bound = bound + if encoder_cfg.get('per_level_scale') is None: + encoder_cfg['per_level_scale'] = np.exp2(np.log2(2048 * self.bound / 16) / (16 - 1)) + self.encoder = tcnn.Encoding(n_input_dims=encoder_num_input_dims, encoding_config=dict(encoder_cfg)) + self.background_net = tcnn.Network( + self.encoder.n_output_dims, background_net_num_output_dims, network_config=dict(background_net_cfg) + ) + + def encode(self, rays_d: torch.Tensor) -> torch.Tensor: + return self.encoder(rays_d) + + def forward_net(self, rays_d_encoding: torch.Tensor) -> torch.Tensor: + return self.background_net(rays_d_encoding) diff --git a/nemo/collections/multimodal/modules/nerf/background/torchngp_background.py b/nemo/collections/multimodal/modules/nerf/background/torchngp_background.py new file mode 100644 index 000000000000..18a5c6d38a49 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/background/torchngp_background.py @@ -0,0 +1,44 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict + +import torch + +from nemo.collections.multimodal.modules.nerf.background.nerf_background_base import NeRFBackgroundBase +from nemo.collections.multimodal.modules.nerf.geometry.layers import MLP +from nemo.collections.multimodal.modules.nerf.utils.torch_ngp.encoding import get_encoder + + +class TorchNGPBackground(NeRFBackgroundBase): + def __init__( + self, encoder_type: str, encoder_input_dims: int, encoder_multi_res: int, num_output_dims: int, net_cfg: Dict + ): + super().__init__() + + self.encoder, self.encoder_output_dims = get_encoder( + encoder_type, input_dim=encoder_input_dims, multires=encoder_multi_res + ) + self.background_net = MLP( + num_input_dims=self.encoder_output_dims, + num_output_dims=num_output_dims, + num_hidden_dims=net_cfg.num_hidden_dims, + num_layers=net_cfg.num_layers, + bias=net_cfg.bias, + ) + + def encode(self, rays_d: torch.Tensor) -> torch.Tensor: + return self.encoder(rays_d) + + def forward_net(self, rays_d_encoding: torch.Tensor) -> torch.Tensor: + return self.background_net(rays_d_encoding) diff --git a/nemo/collections/multimodal/modules/nerf/geometry/__init__.py b/nemo/collections/multimodal/modules/nerf/geometry/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/geometry/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/modules/nerf/geometry/dmtet.py b/nemo/collections/multimodal/modules/nerf/geometry/dmtet.py new file mode 100644 index 000000000000..f6bd7700859e --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/geometry/dmtet.py @@ -0,0 +1,163 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + + +class DeepMarchingTetrahedra: + """ + Class for Deep Marching Tetrahedra (DMTet). + + Attributes: + device (torch.device): Device to place the tensors. + triangle_table (Tensor): Lookup table for the triangles. + num_triangles_table (Tensor): Table for the number of triangles. + base_tet_edges (Tensor): The base edges for the tetrahedrons. + """ + + def __init__(self, device: torch.device) -> None: + """Initialize DMTet instance with the given device. + + Args: + device (torch.device): The device to place the tensors on. + """ + self.device = device + self.triangle_table = self._create_triangle_table() + self.num_triangles_table = self._create_num_triangles_table() + self.base_tet_edges = self._create_base_tet_edges() + + def _create_triangle_table(self) -> torch.Tensor: + """Create the lookup table for triangles. + + Returns: + Tensor: The triangle lookup table. + """ + return torch.tensor( + [ + [-1, -1, -1, -1, -1, -1], + [1, 0, 2, -1, -1, -1], + [4, 0, 3, -1, -1, -1], + [1, 4, 2, 1, 3, 4], + [3, 1, 5, -1, -1, -1], + [2, 3, 0, 2, 5, 3], + [1, 4, 0, 1, 5, 4], + [4, 2, 5, -1, -1, -1], + [4, 5, 2, -1, -1, -1], + [4, 1, 0, 4, 5, 1], + [3, 2, 0, 3, 5, 2], + [1, 3, 5, -1, -1, -1], + [4, 1, 2, 4, 3, 1], + [3, 0, 4, -1, -1, -1], + [2, 0, 1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1], + ], + dtype=torch.long, + device=self.device, + ) + + def _create_num_triangles_table(self) -> torch.Tensor: + """Create the table for number of triangles. + + Returns: + Tensor: The number of triangles table. + """ + return torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=self.device) + + def _create_base_tet_edges(self) -> torch.Tensor: + """Create the base edges for the tetrahedrons. + + Returns: + Tensor: The base edges for tetrahedrons. + """ + return torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=self.device) + + def _sort_edges(self, edges_ex2: torch.Tensor) -> torch.Tensor: + """Sort the given edges. + + Args: + edges_ex2 (Tensor): The edges to be sorted. + + Returns: + Tensor: The sorted edges. + """ + with torch.no_grad(): + order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long() + order = order.unsqueeze(dim=1) + a = torch.gather(input=edges_ex2, index=order, dim=1) + b = torch.gather(input=edges_ex2, index=1 - order, dim=1) + return torch.stack([a, b], -1) + + # TODO(ahmadki): rename to forward ? return mesh ? + def __call__(self, positions: torch.Tensor, sdf_n: torch.Tensor, tet_fx4: torch.Tensor) -> tuple: + """ + Process the provided data to generate vertices and faces. + + Args: + positions (Tensor): Position tensor with shape [N, 3]. + sdf_n (Tensor): SDF tensor with shape [N]. + tet_fx4 (Tensor): Tetrahedron faces tensor with shape [F, 4]. + + Returns: + tuple: Vertices and faces tensors. + """ + with torch.no_grad(): + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum > 0) & (occ_sum < 4) + occ_sum = occ_sum[valid_tets] + + # find all vertices + all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2) + all_edges = self._sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=self.device) + idx_map = mapping[idx_map] # map edges to verts + + interp_v = unique_edges[mask_edges] + + edges_to_interp = positions[interp_v.reshape(-1)].reshape(-1, 2, 3) + edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) + edges_to_interp_sdf[:, -1] *= -1 + + denominator = edges_to_interp_sdf.sum(1, keepdim=True) + edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator + verts = (edges_to_interp * edges_to_interp_sdf).sum(1) + + idx_map = idx_map.reshape(-1, 6) + v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=self.device)) + tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) + num_triangles = self.num_triangles_table[tetindex] + + # Generate triangle indices + faces = torch.cat( + ( + torch.gather( + input=idx_map[num_triangles == 1], + dim=1, + index=self.triangle_table[tetindex[num_triangles == 1]][:, :3], + ).reshape(-1, 3), + torch.gather( + input=idx_map[num_triangles == 2], + dim=1, + index=self.triangle_table[tetindex[num_triangles == 2]][:, :6], + ).reshape(-1, 3), + ), + dim=0, + ) + + return verts, faces diff --git a/nemo/collections/multimodal/modules/nerf/geometry/layers.py b/nemo/collections/multimodal/modules/nerf/geometry/layers.py new file mode 100644 index 000000000000..294bcfc427e2 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/geometry/layers.py @@ -0,0 +1,142 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, List, Type, Union + +import torch +import torch.nn as nn + +BlockBuilder = Union[Callable[[int, int, bool], nn.Module], Type[nn.Module], None] + + +class MLP(nn.Module): + """ + A Multi-Layer Perceptron (MLP) module. + + Args: + num_input_dims (int): Number of input dimensions. + num_output_dims (int): Number of output dimensions. + num_hidden_dims (int): Number of hidden dimensions. + num_layers (int): Number of layers in the MLP. + bias (bool): If True, enables the bias in Linear layers. Default is True. + block (BlockBuilder): A callable or class for constructing a block. Default is None. + """ + + def __init__( + self, + num_input_dims: int, + num_output_dims: int, + num_hidden_dims: int, + num_layers: int, + bias: bool = True, + block: BlockBuilder = None, + ): + super().__init__() + + # Initialize the network as an empty list + network = [] + + # Add input layer + network.append(nn.Linear(num_input_dims, num_hidden_dims, bias=bias)) + network.append(nn.ReLU(inplace=True)) + + # Add hidden layers + for _ in range(1, num_layers - 1): + network.extend(self.build_layer(num_hidden_dims, num_hidden_dims, bias, block)) + + # Add output layer + network.append(nn.Linear(num_hidden_dims, num_output_dims, bias=bias)) + + # Wrap layers in ModuleList for proper registration + self.net = nn.ModuleList(network) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the MLP. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + for module in self.net: + x = module(x) + return x + + @staticmethod + def build_layer( + num_input_dims: int, num_output_dims: int, bias: bool = True, block_builder: BlockBuilder = None + ) -> List[nn.Module]: + """ + Build a single layer for the MLP. + + Args: + num_input_dims (int): Number of input dimensions. + num_output_dims (int): Number of output dimensions. + bias (bool): If True, enables the bias in Linear layers. Default is True. + block_builder (BlockBuilder): A callable or class for constructing a block. Default is None. + + Returns: + List[nn.Module]: A list containing the layer's modules. + """ + if block_builder is None: + return [nn.Linear(num_input_dims, num_output_dims, bias=bias), nn.ReLU(inplace=True)] + else: + return [block_builder(num_input_dims, num_output_dims, bias=bias)] + + +class ResBlock(nn.Module): + """ + A residual block module. + + Args: + num_input_dims (int): Number of input dimensions. + num_output_dims (int): Number of output dimensions. + bias (bool): If True, enables the bias in Linear layers. Default is True. + """ + + def __init__(self, num_input_dims: int, num_output_dims: int, bias: bool = True): + super().__init__() + + self.dense = nn.Linear(num_input_dims, num_output_dims, bias=bias) + self.norm = nn.LayerNorm(num_output_dims) + self.activation = nn.SiLU(inplace=True) + + if num_input_dims != num_output_dims: + self.skip = nn.Linear(num_input_dims, num_output_dims, bias=False) + else: + self.skip = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the residual block. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + identity = x + + out = self.dense(x) + out = self.norm(out) + + if self.skip is not None: + identity = self.skip(identity) + + out += identity + out = self.activation(out) + + return out diff --git a/nemo/collections/multimodal/modules/nerf/geometry/nerf_base.py b/nemo/collections/multimodal/modules/nerf/geometry/nerf_base.py new file mode 100644 index 000000000000..c539a4c17771 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/geometry/nerf_base.py @@ -0,0 +1,373 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from enum import Enum +from typing import Optional, Tuple + +import mcubes +import numpy as np +import pymeshlab +import torch +import torch.nn as nn +import torch.nn.functional as F +import trimesh + +from nemo.collections.multimodal.modules.nerf.utils.activation import trunc_exp + + +class DensityActivationEnum(str, Enum): + EXP = "exp" + SOFTPLUS = "softplus" + + +class NormalTypeEnum(str, Enum): + AUTOGRAD = "autograd" + FORWARD_FINITE_DIFFERENCE = "forward_finite_difference" + BACKWARD_FINITE_DIFFERENCE = "backward_finite_difference" + CENTRAL_FINITE_DIFFERENCE = "central_finite_difference" + + +# TODO(ahmadki): make abstract +class NeRFBase(nn.Module): + """ + A base class for Neural Radiance Fields (NeRF) models. + + Args: + num_input_dims (int): Number of input dimensions. + bound (torch.Tensor): The bounding box tensor. + density_activation (DensityActivationEnum): Activation function for density. + blob_radius (float): Radius for the blob. + blob_density (float): Density for the blob. + normal_type (Optional[NormalTypeEnum]): Method to compute normals. + """ + + def __init__( + self, + num_input_dims: int, + bound: torch.Tensor, + density_activation: DensityActivationEnum, + blob_radius: float, + blob_density: float, + normal_type: Optional[NormalTypeEnum] = NormalTypeEnum.CENTRAL_FINITE_DIFFERENCE, + ) -> None: + super().__init__() + self.num_input_dims = num_input_dims + self.bound = bound + self.density_activation = density_activation + self.blob_radius = blob_radius + self.blob_density = blob_density + self.normal_type = normal_type + + def encode(self, positions: torch.Tensor) -> torch.Tensor: + """Encode 3D positions. To be implemented by subclasses.""" + raise NotImplementedError + + def sigma_net(self, positions_encoding: torch.Tensor) -> torch.Tensor: + """Calculate sigma (density). To be implemented by subclasses.""" + raise NotImplementedError + + def features_net(self, positions_encoding: torch.Tensor) -> torch.Tensor: + """Calculate features. To be implemented by subclasses.""" + raise NotImplementedError + + def forward( + self, positions: torch.Tensor, return_normal: bool = True + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Forward pass for the NeRF model. + + Args: + positions (torch.Tensor): The positions. + return_normal (bool): Flag to indicate whether to return normals or not. + + Returns: + Tuple containing density, features, and possibly normals. + """ + + if return_normal: + if self.normal_type == NormalTypeEnum.AUTOGRAD: + with torch.enable_grad(): + positions.requires_grad_(True) + sigma, features = self.forward_density_features(positions) + normal = -torch.autograd.grad(torch.sum(sigma), positions, create_graph=True)[0] # [N, D] + elif self.normal_type in [ + NormalTypeEnum.CENTRAL_FINITE_DIFFERENCE, + NormalTypeEnum.FORWARD_FINITE_DIFFERENCE, + NormalTypeEnum.BACKWARD_FINITE_DIFFERENCE, + ]: + sigma, features = self.forward_density_features(positions) + normal = self.normal_finite_differences(positions) + else: + raise NotImplementedError("Invalid normal type.") + + normal = F.normalize(normal) + normal = torch.nan_to_num(normal) + else: + sigma, features = self.forward_density_features(positions) + normal = None + + return sigma, features, normal + + def forward_density_features(self, positions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate both density and features based on the input positions. + + This function takes into account edge cases like empty input tensors and calculates + the density and features accordingly. See GitHub issues for details: + - https://github.com/KAIR-BAIR/nerfacc/issues/207#issuecomment-1653621720 + - https://github.com/ashawkey/torch-ngp/issues/176 + + Args: + positions (torch.Tensor): Input positions tensor with shape [B*N, D]. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing density and features tensors. + """ + + # Handle empty positions + if positions.shape[0] == 0: + sigma = torch.zeros(0, device=positions.device) + features = torch.zeros(0, self.num_input_dims, device=positions.device) + return sigma, features + + # Encode positions + positions_encoding = self.encode(positions) + + # Compute density + density = self.forward_density(positions, positions_encoding) + + # Compute features + features = self.forward_features(positions, positions_encoding) + + return density, features + + def forward_density( + self, positions: torch.Tensor, positions_encoding: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Calculate the density based on the input positions and their encoding. + + Args: + positions (torch.Tensor): Input positions tensor with shape [B*N, D]. + positions_encoding (Optional[torch.Tensor]): Optional encoded positions. + Will be computed from `positions` if not provided. + + Returns: + torch.Tensor: Density tensor. + """ + + # Handle empty positions + if positions.shape[0] == 0: + sigma = torch.zeros(0, device=positions.device) + return sigma + + # Compute encoded positions if not provided + if positions_encoding is None: + positions_encoding = self.encode(positions) + + # Compute sigma using the neural network + sigma = self.sigma_net(positions_encoding) + + # Compute density using activation function + if self.density_activation == DensityActivationEnum.EXP: + density = trunc_exp(sigma + self.density_blob(positions)) + elif self.density_activation == DensityActivationEnum.SOFTPLUS: + density = F.softplus(sigma + self.density_blob(positions)) + else: + raise NotImplementedError("Invalid density activation.") + + return density + + def forward_features( + self, positions: torch.Tensor, positions_encoding: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Compute the features based on the input positions and their encoding. + + Args: + positions (torch.Tensor): Input positions tensor with shape [B*N, D]. + positions_encoding (Optional[torch.Tensor]): Optional encoded positions. + Will be computed from `positions` if not provided. + + Returns: + torch.Tensor: Features tensor with shape [B*N, num_features_dims]. + """ + + # Handle empty positions + if positions.shape[0] == 0: + features = torch.zeros(0, self.num_features_dims, device=positions.device) + return features + + # Compute encoded positions if not provided + if positions_encoding is None: + positions_encoding = self.encode(positions) + + # Compute features using the neural network + features = self.features_net(positions_encoding) + + # Apply the sigmoid activation function to the features + features = torch.sigmoid(features) + + return features + + @torch.no_grad() + def density_blob(self, positions: torch.Tensor) -> torch.Tensor: + """ + Compute the density blob for the given positions. + + This method computes a density blob for each position in the tensor. It is + used to add a density value based on the distance of each position from the origin. + + Args: + positions (torch.Tensor): Input positions tensor with shape [B*N, D]. + + Returns: + torch.Tensor: Density blob tensor with shape [B*N, 1]. + """ + + # Compute the squared distance for each position + d = (positions ** 2).sum(-1) + + # Compute the density blob based on the activation function + if self.density_activation == DensityActivationEnum.EXP: + g = self.blob_density * torch.exp(-d / (2 * self.blob_radius ** 2)) + elif self.density_activation == DensityActivationEnum.SOFTPLUS: + g = self.blob_density * (1 - torch.sqrt(d) / self.blob_radius) + else: + raise NotImplementedError("Invalid density activation.") + + return g + + def normal_finite_differences(self, positions: torch.Tensor, eps: float = 1e-2) -> torch.Tensor: + """ + Calculate normals using finite differences. + + Args: + positions (torch.Tensor): Input positions tensor with shape [B*N, D]. + eps (float): A small value for finite difference calculation. Default is 1e-2. + + Returns: + torch.Tensor: Calculated normals tensor [B*N, D] + """ + # Create perturbation tensor + perturb = torch.eye(self.num_input_dims).to(positions.device).float() * eps # Shape (D, D) + + # Expand dims for batched operation + positions_expanded = positions[:, None, :] # (B*N, 1, D) + perturb_expanded = perturb[None, :, :] # (1, D, D) + + # Compute perturbed points + if self.normal_type == NormalTypeEnum.FORWARD_FINITE_DIFFERENCE: + positions_perturbed = positions_expanded + perturb_expanded # (B*N, D, D) + elif self.normal_type == NormalTypeEnum.BACKWARD_FINITE_DIFFERENCE: + positions_perturbed = positions_expanded - perturb_expanded # (B*N, D, D) + elif self.normal_type == NormalTypeEnum.CENTRAL_FINITE_DIFFERENCE: + positions_perturbed_pos = positions_expanded + perturb_expanded # (B*N, D, D) + positions_perturbed_neg = positions_expanded - perturb_expanded # (B*N, D, D) + positions_perturbed = torch.cat([positions_perturbed_pos, positions_perturbed_neg], dim=1) # (B*N, 2*D, D) + + # Reshape perturbed points for batched function call + positions_perturbed_reshaped = positions_perturbed.view(-1, self.num_input_dims) # (B*N * {D or 2*D}, D) + + # Evaluate function at perturbed points + perturbed_sigma = self.forward_density(positions_perturbed_reshaped) # (B*N * {D or 2*D}, 1) + + # Reshape function values + if self.normal_type == NormalTypeEnum.CENTRAL_FINITE_DIFFERENCE: + perturbed_sigma = perturbed_sigma.view(-1, 2 * self.num_input_dims) # (B*N, 2*D) + sigma_pos, sigma_neg = torch.chunk(perturbed_sigma, 2, dim=1) # (B*N, D) each + normal = 0.5 * (sigma_pos - sigma_neg) / eps # (B*N, D) + else: + perturbed_sigma = perturbed_sigma.view(-1, self.num_input_dims) # (B*N, D) + sigma = self.forward_density(positions) # (B*N,) # TODO(ahmadki): use the value from forward ? + if self.normal_type == NormalTypeEnum.FORWARD_FINITE_DIFFERENCE: + normal = (perturbed_sigma - sigma[:, None]) / eps # (B*N, D) + else: # self.normal_type == BACKWARD_FINITE_DIFFERENCE + normal = (sigma[:, None] - perturbed_sigma) / eps # (B*N, D) + + return -normal + + # TODO(ahmadki): needs ar ework: + # 1. texture/vertices are off-axis, needs a fix. + # 2. device='cuda' is hardcoded + # 3. DMTet needs to go through a different code path ? create a base volume nerf, and a base dmtet nerf class ? + @torch.no_grad() + def mesh( + self, resolution: Optional[int] = 128, batch_size: int = 128, density_thresh: Optional[float] = None + ) -> pymeshlab.Mesh: + """ + Generate a mesh from the nerf. + + Args: + resolution (Optional[int]): Resolution of the mesh grid. Default is 128. + batch_size (int): Batch size for the mesh generation. Default is 128. + density_thresh (Optional[float]): Density threshold for the mesh generation. Default is None, will be calculated from mean density. + + Returns: + pymeshlab.Mesh: Mesh object. + """ + # Generate a grid of 3D points + x = np.linspace(-self.bound, self.bound, resolution) + y = np.linspace(-self.bound, self.bound, resolution) + z = np.linspace(-self.bound, self.bound, resolution) + xx, yy, zz = np.meshgrid(x, y, z) + + grid = np.stack((xx, yy, zz), axis=-1) # Shape (resolution, resolution, resolution, 3) + torch_grid = torch.tensor(grid, dtype=torch.float32).reshape(-1, 3).to(device="cuda") + + def batch_process(fn, input, batch_size): + num_points = input.shape[0] + batches = [input[i : i + batch_size] for i in range(0, num_points, batch_size)] + results = [fn(batch) for batch in batches] + results = [result.detach().cpu().numpy() for result in results] + return np.concatenate(results, axis=0) + + density = batch_process(fn=self.forward_density, input=torch_grid, batch_size=batch_size) + density = density.reshape(resolution, resolution, resolution) + + # If not provided set density_thresh based on mean density + if density_thresh is None: + density_thresh = density[density > 1e-3].mean().item() + + # Apply Marching Cubes + vertices, triangles = mcubes.marching_cubes(density, density_thresh) + + # Create a new Mesh + ms = pymeshlab.MeshSet() + + # Create Mesh using vertices and faces + m = pymeshlab.Mesh(vertices.copy(), triangles.copy()) + + # Add mesh to the MeshSet + ms.add_mesh(m, "generated_mesh") + + # Filters + ms.meshing_remove_unreferenced_vertices() + ms.meshing_remove_duplicate_faces() + ms.meshing_remove_null_faces() + ms.meshing_repair_non_manifold_edges(method=0) + ms.meshing_repair_non_manifold_vertices(vertdispratio=0) + + m = ms.current_mesh() + vertices = m.vertex_matrix() + faces = m.face_matrix() + + scaled_vertice = ( + -self.bound + (vertices / resolution) * 2 * self.bound + ) # scale vertices back to [-self.bound, self.bound] + scaled_vertices_torch = torch.tensor(scaled_vertice, dtype=torch.float32).to(device="cuda") + color = batch_process(fn=self.forward_features, input=scaled_vertices_torch, batch_size=batch_size) + + # Create the final mesh from cleaned vertices and faces and with color + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=color) + return mesh diff --git a/nemo/collections/multimodal/modules/nerf/geometry/tcnn_nerf.py b/nemo/collections/multimodal/modules/nerf/geometry/tcnn_nerf.py new file mode 100644 index 000000000000..a7db0eea47f2 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/geometry/tcnn_nerf.py @@ -0,0 +1,121 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional + +import numpy as np +import tinycudann as tcnn +import torch + +from nemo.collections.multimodal.modules.nerf.geometry.nerf_base import DensityActivationEnum, NeRFBase, NormalTypeEnum + + +# Don't fuse sigma_net with features_net: +# 1. performance benefit is questionable, especially that we sometimes require only density or features +# 2. we sacrifice generality +class TCNNNerf(NeRFBase): + """ + NeRF model with TCNN encoding and MLPs for sigma and features. + + Args: + num_input_dims (int): Number of input dimensions. + bound (torch.Tensor): The bounding box tensor. + density_activation (DensityActivationEnum): Activation function for density. + blob_radius (float): Radius for the blob. + blob_density (float): Density for the blob. + normal_type (Optional[NormalTypeEnum]): Method to compute normals. + encoder_cfg (Dict): Configuration for the TCNN encoder. + sigma_net_num_output_dims (int): Number of output dimensions for the sigma network. + sigma_net_cfg (Dict): Configuration for the sigma network. + features_net_num_output_dims (int): Number of output dimensions for the features network. + features_net_cfg (Optional[Dict]): Configuration for the features network. + """ + + def __init__( + self, + num_input_dims: int, + bound: torch.Tensor, + density_activation: DensityActivationEnum, + blob_radius: float, + blob_density: float, + normal_type: Optional[NormalTypeEnum], + encoder_cfg: Dict, + sigma_net_num_output_dims: int, + sigma_net_cfg: Dict, + features_net_num_output_dims: int, + features_net_cfg: Optional[Dict], + ) -> None: + super().__init__( + num_input_dims=num_input_dims, + bound=bound, + density_activation=density_activation, + blob_radius=blob_radius, + blob_density=blob_density, + normal_type=normal_type, + ) + + # Set per_level_scale if not set + if encoder_cfg.get('per_level_scale') is None: + encoder_cfg['per_level_scale'] = np.exp2(np.log2(2048 * self.bound / 16) / (16 - 1)) + # Build the TCNN encoder + self.encoder = tcnn.Encoding(n_input_dims=num_input_dims, encoding_config=dict(encoder_cfg)) + + # Build the sigma network + assert sigma_net_num_output_dims == 1, "sigma_net_num_output_dims!=1 is not supported" + self.sigma_tcnn = tcnn.Network( + self.encoder.n_output_dims, sigma_net_num_output_dims, network_config=dict(sigma_net_cfg) + ) + + # Build the features network + self.features_tcnn = None + if features_net_cfg is not None: + self.features_tcnn = tcnn.Network( + self.encoder.n_output_dims, features_net_num_output_dims, network_config=dict(features_net_cfg) + ) + + def encode(self, positions: torch.Tensor) -> torch.Tensor: + """ + Encode the positions using the TCNN encoder. + + Args: + positions (torch.Tensor): The positions tensor. + + Returns: + torch.Tensor: The encoded positions tensor. + """ + # TODO(ahmadki): is it safe to do with FP16 ? + return self.encoder((positions + self.bound) / (2 * self.bound)) + + def sigma_net(self, positions_encoding: torch.Tensor) -> torch.Tensor: + """ + Compute the sigma using the TCNN network. + + Args: + positions_encoding (torch.Tensor): The encoded positions tensor. + + Returns: + torch.Tensor: The sigma tensor. + """ + return self.sigma_tcnn(positions_encoding).squeeze() + + def features_net(self, positions_encoding: torch.Tensor) -> torch.Tensor: + """ + Compute the features using the TCNN network. + + Args: + positions_encoding (torch.Tensor): The encoded positions tensor. + + Returns: + torch.Tensor: The features tensor. + """ + return self.features_tcnn(positions_encoding) diff --git a/nemo/collections/multimodal/modules/nerf/geometry/torchngp_nerf.py b/nemo/collections/multimodal/modules/nerf/geometry/torchngp_nerf.py new file mode 100644 index 000000000000..4b1d5e34b22e --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/geometry/torchngp_nerf.py @@ -0,0 +1,127 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional + +import torch + +from nemo.collections.multimodal.modules.nerf.geometry.layers import MLP +from nemo.collections.multimodal.modules.nerf.geometry.nerf_base import DensityActivationEnum, NeRFBase, NormalTypeEnum +from nemo.collections.multimodal.modules.nerf.utils.torch_ngp.encoding import get_encoder + + +# Don't fuse sigma_net with features_net: +# 1. performance benefit is questionable, especially that we sometimes require only density or features +# 2. we sacrifice generality +class TorchNGPNerf(NeRFBase): + """ + NeRF model with Torch-NGP encoding and MLPs for sigma and features. + + Args: + num_input_dims (int): Number of input dimensions. + bound (torch.Tensor): The bounding box tensor. + density_activation (DensityActivationEnum): Activation function for density. + blob_radius (float): Radius for the blob. + blob_density (float): Density for the blob. + normal_type (Optional[NormalTypeEnum]): Method to compute normals. + encoder_type (str): Type of the encoder. + encoder_max_level (int): Maximum level of the encoder. + sigma_net_num_output_dims (int): Number of output dimensions for the sigma network. + sigma_net_cfg (Dict): Configuration for the sigma network. + features_net_num_output_dims (int): Number of output dimensions for the features network. + features_net_cfg (Optional[Dict]): Configuration for the features network. + """ + + def __init__( + self, + num_input_dims: int, + bound: torch.Tensor, + density_activation: DensityActivationEnum, + blob_radius: float, + blob_density: float, + normal_type: Optional[NormalTypeEnum], + encoder_cfg: Dict, + sigma_net_num_output_dims: int, + sigma_net_cfg: Dict, + features_net_num_output_dims: int, + features_net_cfg: Optional[Dict], + ): + super().__init__( + num_input_dims=num_input_dims, + bound=bound, + density_activation=density_activation, + blob_radius=blob_radius, + blob_density=blob_density, + normal_type=normal_type, + ) + + # Build the Torch-NGP encoder + self.encoder_max_level = encoder_cfg.get('encoder_max_level', None) + self.encoder, self.encoder_output_dims = get_encoder(input_dim=num_input_dims, **encoder_cfg) + + # Build the sigma network + assert sigma_net_num_output_dims == 1, "sigma_net_num_output_dims must be equal to 1" + self.sigma_mlp = MLP( + num_input_dims=self.encoder_output_dims, + num_output_dims=sigma_net_num_output_dims, + num_hidden_dims=sigma_net_cfg.num_hidden_dims, + num_layers=sigma_net_cfg.num_layers, + bias=sigma_net_cfg.bias, + ) + + # Build the features network + self.features_mlp = None + if features_net_cfg is not None: + self.features_mlp = MLP( + num_input_dims=self.encoder_output_dims, + num_output_dims=features_net_num_output_dims, + num_hidden_dims=features_net_cfg.num_hidden_dims, + num_layers=features_net_cfg.num_layers, + bias=features_net_cfg.bias, + ) + + def encode(self, positions: torch.Tensor) -> torch.Tensor: + """ + Encode the positions. + + Args: + positions (torch.Tensor): The positions tensor. + + Returns: + torch.Tensor: The encoded positions tensor. + """ + return self.encoder(positions, bound=self.bound, max_level=self.encoder_max_level) + + def sigma_net(self, positions_encoding: torch.Tensor) -> torch.Tensor: + """ + Compute the sigma using the sigma network. + + Args: + positions_encoding (torch.Tensor): The encoded positions tensor. + + Returns: + torch.Tensor: The sigma tensor. + """ + return self.sigma_mlp(positions_encoding).squeeze() + + def features_net(self, positions_encoding: torch.Tensor) -> torch.Tensor: + """ + Compute the features using the features network. + + Args: + positions_encoding (torch.Tensor): The encoded positions tensor. + + Returns: + torch.Tensor: The features tensor. + """ + return self.features_mlp(positions_encoding) diff --git a/nemo/collections/multimodal/modules/nerf/guidance/__init__.py b/nemo/collections/multimodal/modules/nerf/guidance/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/guidance/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/modules/nerf/guidance/stablediffusion_huggingface_pipeline.py b/nemo/collections/multimodal/modules/nerf/guidance/stablediffusion_huggingface_pipeline.py new file mode 100644 index 000000000000..ed5a2fd1fa1e --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/guidance/stablediffusion_huggingface_pipeline.py @@ -0,0 +1,155 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List + +import torch +import torch.nn.functional as F +from diffusers import DDIMScheduler, StableDiffusionPipeline + +from nemo.collections.multimodal.modules.nerf.guidance.txt2img_guidance_base import Txt2ImgGuidanceBase + + +class StableDiffusion(Txt2ImgGuidanceBase): + def __init__( + self, + model_key: str = "stabilityai/stable-diffusion-2-1-base", + t_range: List[float] = [0.02, 0.98], + precision: str = "16", + device: torch.device = torch.device('cuda'), + ): + """ + Initialize StableDiffusion with model_key, t_range, precision and device. + + Parameters: + model_key (str): Pre-trained model key. + t_range (List[float]): Range for timesteps. + precision (str): Model precision ("16", "bf16" or other for float32). + device (torch.device): Device for torch tensor. + """ + super().__init__() + + self.device = device + self.model_key = model_key + self.precision_t = self._get_precision_type(precision) + + # Create model + pipe = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=self.precision_t).to(self.device) + if self.precision_t in [torch.float16, torch.bfloat16]: + pipe.unet.to(memory_format=torch.channels_last) + + self.vae = pipe.vae + self.tokenizer = pipe.tokenizer + self.text_encoder = pipe.text_encoder + self.unet = pipe.unet + self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler", torch_dtype=self.precision_t) + + del pipe + + self.num_train_timesteps = self.scheduler.config.num_train_timesteps + self.min_step = int(self.num_train_timesteps * t_range[0]) + self.max_step = int(self.num_train_timesteps * t_range[1]) + self.alphas = self.scheduler.alphas_cumprod.to(self.device) + + def _get_precision_type(self, precision: str) -> torch.dtype: + """ + Map string precision representation to torch dtype. + + Parameters: + precision (str): String representation of precision. + + Returns: + torch.dtype: Corresponding torch dtype. + """ + precision_map = {"16": torch.float16, "bf16": torch.bfloat16} + return precision_map.get(precision, torch.float32) + + @torch.no_grad() + def get_text_embeds(self, prompt: str) -> torch.Tensor: + """ + Get text embeddings from the given prompt. + + Parameters: + prompt (str): Input text. + + Returns: + torch.Tensor: Text embeddings tensor [B, 77, 1024]. + """ + inputs = self.tokenizer( + prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt' + ) + embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] + return embeddings + + # @torch.compile() # TODO(ahmadki) + def train_step( + self, + text_embeddings: torch.Tensor, + pred_rgb: torch.Tensor, + guidance_scale: float = 100.0, + as_latent: bool = False, + ) -> float: + """ + Train step function for StableDiffusion. + + Parameters: + text_embeddings (torch.Tensor): Embeddings tensor [B, 512]. + pred_rgb (torch.Tensor): Predicted RGB tensor [B, 3, 512, 512]. + guidance_scale (float): Guidance scaling factor. + as_latent (bool): If True, considers pred_rgb as latent. + + Returns: + float: Loss value. + """ + if as_latent: + latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1 + else: + pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) + latents = self.encode_imgs(pred_rgb_512) + + t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device) + + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2) + td = torch.cat([t] * 2) + noise_pred = self.unet(latent_model_input, td, encoder_hidden_states=text_embeddings).sample + + noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond) + + w = 1 - self.alphas[t] + grad = w[:, None, None, None] * (noise_pred - noise) + grad = torch.nan_to_num(grad) + + targets = (latents - grad).detach() + loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0] + return loss + + def encode_imgs(self, imgs: torch.Tensor) -> torch.Tensor: + """ + Encode images into latent representations. + + Parameters: + imgs (torch.Tensor): Image tensor [B, 3, H, W]. + + Returns: + torch.Tensor: Encoded latent tensor. + """ + imgs = 2 * imgs - 1 + posterior = self.vae.encode(imgs).latent_dist + latents = posterior.sample() * self.vae.config.scaling_factor + return latents diff --git a/nemo/collections/multimodal/modules/nerf/guidance/stablediffusion_nemo_pipeline.py b/nemo/collections/multimodal/modules/nerf/guidance/stablediffusion_nemo_pipeline.py new file mode 100644 index 000000000000..6c2f96dc2dde --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/guidance/stablediffusion_nemo_pipeline.py @@ -0,0 +1,141 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import tempfile + +import torch +import torch.nn.functional as F +from omegaconf import OmegaConf + +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import LatentDiffusion +from nemo.collections.multimodal.modules.nerf.guidance.txt2img_guidance_base import Txt2ImgGuidanceBase +from nemo.collections.multimodal.modules.stable_diffusion.distributions.distributions import ( + DiagonalGaussianDistribution, +) +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector + + +class StableDiffusion(Txt2ImgGuidanceBase): + def __init__( + self, checkpoint, sampler_type="DDIM", t_range=[0.02, 0.98], precision="16", device=torch.device('cuda') + ): + super().__init__() + + self.device = device + self.checkpoint = checkpoint + self.sampler_type = sampler_type + + cfg, state_dict = self.load_config_and_state_from_nemo(checkpoint) + + cfg.precision = precision + cfg.ckpt_path = None + cfg.unet_config.from_pretrained = None + cfg.first_stage_config.from_pretrained = None + + self.model = LatentDiffusion(cfg).to(device) + + sd_state_dict = {} + # Remove Megatron wrapper and inductor + for key, value in state_dict.items(): + key = key[6:] + sd_state_dict[key] = value + self.model.load_state_dict(sd_state_dict) + self.first_stage_model = self.model.first_stage_model + self.text_encoder = self.model.cond_stage_model.encode + + self.num_train_timesteps = self.model.num_timesteps + self.min_step = int(self.num_train_timesteps * t_range[0]) + self.max_step = int(self.num_train_timesteps * t_range[1]) + self.alphas = self.model.alphas_cumprod.to(self.device) + + @torch.no_grad() + def get_text_embeds(self, prompt): + return self.text_encoder(prompt) + + @torch.autocast(device_type="cuda") + def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=False): + + if as_latent: + latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1 + else: + # interp to 512x512 to be fed into vae. + pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) + # encode image into latents with vae, requires grad! + latents = self.encode_imgs(pred_rgb_512) + + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device) + + with torch.no_grad(): + noise = torch.randn_like(latents) + latents_noisy = self.model.q_sample(x_start=latents, t=t, noise=noise) + latent_model_input = torch.cat([latents_noisy] * 2) + td = torch.cat([t] * 2) + noise_pred = self.model.apply_model(latent_model_input, td, text_embeddings) + + noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond) + + # w(t), sigma_t^2 + w = 1 - self.alphas[t] + grad = w[:, None, None, None] * (noise_pred - noise) + grad = torch.nan_to_num(grad) + + targets = (latents - grad).detach() + loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0] + return loss + + def image_encoder(self, x): + h = self.first_stage_model.encoder(x) + moments = self.first_stage_model.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def encode_imgs(self, imgs): + # imgs: [B, 3, H, W] + + imgs = 2 * imgs - 1 + + posterior = self.image_encoder(imgs) + latents = ( + posterior.sample() * self.image_encoder.config.scaling_factor + ) # self.vae.config.scaling_factor==0.18215 + + return latents + + def load_config_and_state_from_nemo(self, nemo_path): + if torch.cuda.is_available(): + map_location = torch.device('cuda') + else: + map_location = torch.device('cpu') + save_restore_connector = NLPSaveRestoreConnector() + cwd = os.getcwd() + + with tempfile.TemporaryDirectory() as tmpdir: + try: + save_restore_connector._unpack_nemo_file(path2file=nemo_path, out_folder=tmpdir) + + # Change current working directory to + os.chdir(tmpdir) + config_yaml = os.path.join(tmpdir, save_restore_connector.model_config_yaml) + cfg = OmegaConf.load(config_yaml) + + model_weights = os.path.join(tmpdir, save_restore_connector.model_weights_ckpt) + state_dict = save_restore_connector._load_state_dict_from_disk( + model_weights, map_location=map_location + ) + finally: + os.chdir(cwd) + + return cfg, state_dict diff --git a/nemo/collections/multimodal/modules/nerf/guidance/stablediffusion_trt_pipeline.py b/nemo/collections/multimodal/modules/nerf/guidance/stablediffusion_trt_pipeline.py new file mode 100644 index 000000000000..884c86254ab5 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/guidance/stablediffusion_trt_pipeline.py @@ -0,0 +1,234 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +import tempfile + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from omegaconf import OmegaConf +from polygraphy import cuda +from transformers import CLIPTokenizer + +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import LatentDiffusion +from nemo.collections.multimodal.modules.nerf.guidance.txt2img_guidance_base import Txt2ImgGuidanceBase +from nemo.collections.multimodal.modules.nerf.utils.trt_engine import Engine, device_view +from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import ( + extract_into_tensor, + make_beta_schedule, +) +from nemo.collections.multimodal.parts.stable_diffusion.utils import default +from nemo.collections.multimodal.parts.utils import randn_like +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector + + +class LatentDiffusionWrapper(Txt2ImgGuidanceBase): + def __init__(self, plan_dir, checkpoint): + super().__init__() + with open(os.path.join(plan_dir, "conf.yaml"), "rb") as fp: + config = OmegaConf.load(fp.name) + max_batch_size = config.batch_size + + self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + self.max_length = config.clip.max_length + self.rng = torch.Generator(device=torch.cuda.current_device(),) + + self.set_beta_schedule() + + stream = cuda.Stream() + + self.image_encoder = self.load_vae_from_checkpoint(checkpoint) + + self.text_encoder = Engine(os.path.join(plan_dir, "clip.plan")) + shape_dict = {'tokens': config.clip.tokens, 'logits': config.clip.logits} + self.text_encoder.set_engine(stream, shape_dict) + + self.unet = Engine(os.path.join(plan_dir, "unet.plan")) + shape_dict = { + 'x': config.unet.x, + 't': (max_batch_size * 2,), + 'context': config.unet.context, + 'logits': config.unet.logits, + } + self.unet.set_engine(stream, shape_dict) + + def set_beta_schedule(self): + betas = make_beta_schedule("linear", 1000, linear_start=0.00085, linear_end=0.0120, cosine_s=0.008) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + betas = torch.tensor(betas) + alphas = torch.tensor(alphas) + alphas_cumprod = torch.tensor(alphas_cumprod) + to_torch = lambda x: x.clone().detach().to(torch.float32).to(torch.cuda.current_device()) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1.0 - alphas_cumprod.cpu()))) + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: randn_like(x_start, generator=self.rng)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def encode_imgs(self, imgs): + imgs = 2 * imgs - 1 + posterior = self.image_encoder(imgs) + latents = posterior.sample() * 0.18215 + return latents + + def clip_encode(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to("cuda", non_blocking=True) + z = self.text_encoder.infer({"tokens": device_view(tokens.type(torch.int32))})['logits'].clone() + seq_len = (z.shape[1] + 8 - 1) // 8 * 8 + z = torch.nn.functional.pad(z, (0, 0, 0, seq_len - z.shape[1]), value=0.0) + return z + + def apply_model(self, x, t, cond, return_ids=False): + self.conditioning_key = "crossattn" + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + # key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + key = 'c_crossattn' + cond = {key: cond} + # UNET TRT + cc = torch.cat(cond['c_crossattn'], 1) # needs to be changed I think + out = self.unet.infer( + { + "x": device_view(x.contiguous()), + "t": device_view(t.type(torch.int32).contiguous()), + "context": device_view(cc.contiguous()), + } + )['logits'].clone() + if isinstance(out, tuple) and not return_ids: + return out[0] + else: + return out + + def load_vae_from_checkpoint(self, checkpoint): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + cfg, state_dict = self.load_config_and_state_from_nemo(checkpoint) + + if cfg.get('unet_config') and cfg.get('unet_config').get('from_pretrained'): + cfg.unet_config.from_pretrained = None + if cfg.get('first_stage_config') and cfg.get('first_stage_config').get('from_pretrained'): + cfg.first_stage_config.from_pretrained = None + + model = LatentDiffusion(cfg).to(device) + + sd_state_dict = {} + for key, value in state_dict.items(): + key = key[6:] + sd_state_dict[key] = value + model.load_state_dict(sd_state_dict) + + return model.first_stage_model.encode + + def load_config_and_state_from_nemo(self, nemo_path): + if torch.cuda.is_available(): + map_location = torch.device('cuda') + else: + map_location = torch.device('cpu') + save_restore_connector = NLPSaveRestoreConnector() + cwd = os.getcwd() + + with tempfile.TemporaryDirectory() as tmpdir: + try: + save_restore_connector._unpack_nemo_file(path2file=nemo_path, out_folder=tmpdir) + + # Change current working directory to + os.chdir(tmpdir) + config_yaml = os.path.join(tmpdir, save_restore_connector.model_config_yaml) + cfg = OmegaConf.load(config_yaml) + + model_weights = os.path.join(tmpdir, save_restore_connector.model_weights_ckpt) + state_dict = save_restore_connector._load_state_dict_from_disk( + model_weights, map_location=map_location + ) + finally: + os.chdir(cwd) + + return cfg, state_dict + + +class StableDiffusion(nn.Module): + def __init__(self, plan_dir, checkpoint, sampler_type="DDIM", t_range=[0.02, 0.98], device=torch.device('cuda')): + super().__init__() + logging.info(f'loading stable diffusion...') + + self.device = device + self.sampler_type = sampler_type + self.model = LatentDiffusionWrapper(plan_dir, checkpoint) + + self.text_encoder = self.model.clip_encode + + self.num_train_timesteps = self.model.num_timesteps + self.min_step = int(self.num_train_timesteps * t_range[0]) + self.max_step = int(self.num_train_timesteps * t_range[1]) + self.alphas = self.model.alphas_cumprod.to(self.device) # for convenience + + logging.info(f'loaded stable diffusion!') + + @torch.no_grad() + def get_text_embeds(self, prompt): + return self.text_encoder(prompt) + + def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, as_latent=False): + + if as_latent: + latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1 + else: + # interp to 512x512 to be fed into vae. + pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) + # encode image into latents with vae, requires grad! + latents = self.model.encode_imgs(pred_rgb_512) + + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device) + + with torch.no_grad(): + noise = torch.randn_like(latents) + latents_noisy = self.model.q_sample(x_start=latents, t=t, noise=noise) + latent_model_input = torch.cat([latents_noisy] * 2) + td = torch.cat([t] * 2) + noise_pred = self.model.apply_model(latent_model_input, td, text_embeddings) + + noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond) + + # w(t), sigma_t^2 + w = 1 - self.alphas[t] + grad = w[:, None, None, None] * (noise_pred - noise) + grad = torch.nan_to_num(grad) + + targets = (latents - grad).detach() + loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0] + return loss diff --git a/nemo/collections/multimodal/modules/nerf/guidance/txt2img_guidance_base.py b/nemo/collections/multimodal/modules/nerf/guidance/txt2img_guidance_base.py new file mode 100644 index 000000000000..db82584ba78e --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/guidance/txt2img_guidance_base.py @@ -0,0 +1,19 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch.nn as nn + + +class Txt2ImgGuidanceBase(nn.Module): + def __init__(self): + super().__init__() diff --git a/nemo/collections/multimodal/modules/nerf/loss/__init__.py b/nemo/collections/multimodal/modules/nerf/loss/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/loss/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/modules/nerf/loss/laplacian_smooth_loss.py b/nemo/collections/multimodal/modules/nerf/loss/laplacian_smooth_loss.py new file mode 100644 index 000000000000..93b93981304c --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/loss/laplacian_smooth_loss.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn + + +class LaplacianSmoothLoss(nn.Module): + def __init__(self): + super(LaplacianSmoothLoss, self).__init__() + + @torch.cuda.amp.autocast(enabled=False) + def forward(self, verts, faces): + with torch.no_grad(): + L = self.laplacian_uniform(verts, faces.long()) + loss = L.mm(verts) + loss = loss.norm(dim=1) + loss = loss.mean() + return loss + + # TODO(ahmadki): should be moved to a separate mesh class + def laplacian_uniform(self, verts, faces): + V = verts.shape[0] + F = faces.shape[0] + + # Neighbor indices + ii = faces[:, [1, 2, 0]].flatten() + jj = faces[:, [2, 0, 1]].flatten() + adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(dim=1) + adj_values = torch.ones(adj.shape[1], device=verts.device, dtype=torch.float) + + # Diagonal indices + diag_idx = adj[0] + + # Build the sparse matrix + idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1) + values = torch.cat((-adj_values, adj_values)) + + # The coalesce operation sums the duplicate indices, resulting in the + # correct diagonal + return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce() diff --git a/nemo/collections/multimodal/modules/nerf/loss/normal_consistency_loss.py b/nemo/collections/multimodal/modules/nerf/loss/normal_consistency_loss.py new file mode 100644 index 000000000000..ef0c31da4783 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/loss/normal_consistency_loss.py @@ -0,0 +1,69 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn + + +class NormalConsistencyLoss(nn.Module): + def __init__(self): + super(NormalConsistencyLoss, self).__init__() + + # TODO(ahmadki): is this safe to do in FP16 ? + def forward(self, face_normals, t_pos_idx): + tris_per_edge = self.compute_edge_to_face_mapping(t_pos_idx) + + # Fetch normals for both faces sharind an edge + n0 = face_normals[tris_per_edge[:, 0], :] + n1 = face_normals[tris_per_edge[:, 1], :] + + # Compute error metric based on normal difference + term = torch.clamp(torch.sum(n0 * n1, -1, keepdim=True), min=-1.0, max=1.0) + term = 1.0 - term + + return torch.mean(torch.abs(term)) + + # TODO(ahmadki): should belog to mesh class + def compute_edge_to_face_mapping(self, attr_idx): + with torch.no_grad(): + # Get unique edges + # Create all edges, packed by triangle + all_edges = torch.cat( + ( + torch.stack((attr_idx[:, 0], attr_idx[:, 1]), dim=-1), + torch.stack((attr_idx[:, 1], attr_idx[:, 2]), dim=-1), + torch.stack((attr_idx[:, 2], attr_idx[:, 0]), dim=-1), + ), + dim=-1, + ).view(-1, 2) + + # Swap edge order so min index is always first + order = (all_edges[:, 0] > all_edges[:, 1]).long().unsqueeze(dim=1) + sorted_edges = torch.cat( + (torch.gather(all_edges, 1, order), torch.gather(all_edges, 1, 1 - order)), dim=-1 + ) + + # Elliminate duplicates and return inverse mapping + unique_edges, idx_map = torch.unique(sorted_edges, dim=0, return_inverse=True) + + tris = torch.arange(attr_idx.shape[0]).repeat_interleave(3).cuda() + + tris_per_edge = torch.zeros((unique_edges.shape[0], 2), dtype=torch.int64).cuda() + + # Compute edge to face table + mask0 = order[:, 0] == 0 + mask1 = order[:, 0] == 1 + tris_per_edge[idx_map[mask0], 0] = tris[mask0] + tris_per_edge[idx_map[mask1], 1] = tris[mask1] + + return tris_per_edge diff --git a/nemo/collections/multimodal/modules/nerf/materials/__init__.py b/nemo/collections/multimodal/modules/nerf/materials/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/materials/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/modules/nerf/materials/basic_shading.py b/nemo/collections/multimodal/modules/nerf/materials/basic_shading.py new file mode 100644 index 000000000000..45d41b262ecb --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/materials/basic_shading.py @@ -0,0 +1,79 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch + +from nemo.collections.multimodal.modules.nerf.materials.materials_base import MaterialsBase, ShadingEnum + + +class BasicShading(MaterialsBase): + """ + Material model for handling various shading types. + """ + + def __init__(self): + super(BasicShading, self).__init__() + self.specular = torch.nn.Parameter(torch.rand(3)) + self.shininess = torch.nn.Parameter(torch.rand(1)) + + def forward( + self, + albedo: torch.Tensor, + normals: torch.Tensor, + light_d: torch.Tensor, + ambient_ratio: float, + shading_type: Optional[ShadingEnum] = None, + ) -> torch.Tensor: + """ + Apply material and shading to the input RGB tensor. + + Args: + albedo (Tensor): Base albedo values. + normals (Tensor): Normal vectors at each ray intersection. + light_d (Tensor): Light direction. + ambient_ratio (float): Ratio for ambient lighting. + shading_type (ShadingEnum): The type of shading to apply + + Returns: + Tensor: The output RGB tensor after applying material and shading. + """ + if shading_type is None: + return albedo + elif shading_type == ShadingEnum.TEXTURELESS: + return torch.ones_like(albedo) * ambient_ratio + elif shading_type == ShadingEnum.NORMAL: + return (normals + 1) / 2 # Map normals from [-1, 1] to [0, 1] + elif shading_type in [ShadingEnum.LAMBERTIAN, ShadingEnum.PHONG]: + # Ambient light + ambient_light = ambient_ratio * albedo + # Dot product between light direction and normals + dot_product = torch.sum(normals * light_d, dim=1, keepdim=True) + # Lambertian term + diffuse_term = albedo * torch.clamp(dot_product, min=0) + + if shading_type == ShadingEnum.LAMBERTIAN: + return ambient_light + diffuse_term + elif shading_type == ShadingEnum.PHONG: + # Phong specular term + specular_term = ( + self.specular + * (self.shininess + 2) + * torch.pow(torch.clamp(dot_product, min=0), self.shininess) + / (2 * 3.14159) + ) + + return ambient_light + diffuse_term + specular_term + else: + raise ValueError(f"Unknown shading_type: {shading_type}") diff --git a/nemo/collections/multimodal/modules/nerf/materials/materials_base.py b/nemo/collections/multimodal/modules/nerf/materials/materials_base.py new file mode 100644 index 000000000000..be8e81682a5f --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/materials/materials_base.py @@ -0,0 +1,41 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from enum import Enum + +from torch import nn + + +class ShadingEnum(str, Enum): + TEXTURELESS = "textureless" + NORMAL = "normal" + LAMBERTIAN = "lambertian" + PHONG = "phong" + + # TODO(ahmadki): + # Oren–Nayar + # Minnaert + # Cook–Torrance + # Ward anisotropic + # Hanrahan–Krueger + # Cel shading + # Gooch shading + + +class MaterialsBase(nn.Module): + """ + Base class for materials. + """ + + def __init__(self): + super(MaterialsBase, self).__init__() diff --git a/nemo/collections/multimodal/modules/nerf/renderers/__init__.py b/nemo/collections/multimodal/modules/nerf/renderers/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/renderers/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/modules/nerf/renderers/base_renderer.py b/nemo/collections/multimodal/modules/nerf/renderers/base_renderer.py new file mode 100644 index 000000000000..61753bc088cd --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/renderers/base_renderer.py @@ -0,0 +1,31 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn + +# TODO(ahmadki): make abstract +class BaseRenderer(nn.Module): + def __init__(self, bound, update_interval): + super().__init__() + self.bound = bound + aabb = torch.FloatTensor([-bound, -bound, -bound, bound, bound, bound]) + self.register_buffer('aabb', aabb) + self.update_interval = update_interval + + @torch.no_grad() + def update_step(self, epoch: int, global_step: int, decay: float = 0.95, **kwargs): + raise NotImplementedError + + def forward(self, rays_o, rays_d, return_normal_image=False, return_normal_perturb=False, **kwargs): + raise NotImplementedError diff --git a/nemo/collections/multimodal/modules/nerf/renderers/base_sdf_renderer.py b/nemo/collections/multimodal/modules/nerf/renderers/base_sdf_renderer.py new file mode 100644 index 000000000000..48450fc311ba --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/renderers/base_sdf_renderer.py @@ -0,0 +1,33 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from nemo.collections.multimodal.modules.renderer.base_renderer import RendererBase + + +class BaseSDFRenderer(RendererBase): + def __init__(self, bound): + super().__init__(bound) + + # TODO(ahmadki): needs a rework + @torch.no_grad() + def get_vertices_and_triangles(self, resolution=None, S=128): + deform = torch.tanh(self.deform) / self.grid_size + + vertices, triangles = self.dmtet(self.verts + deform, self.sdf, self.indices) + + vertices = vertices.detach().cpu().numpy() + triangles = triangles.detach().cpu().numpy() + + return vertices, triangles diff --git a/nemo/collections/multimodal/modules/nerf/renderers/base_volume_renderer.py b/nemo/collections/multimodal/modules/nerf/renderers/base_volume_renderer.py new file mode 100644 index 000000000000..4801b0e9e5f3 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/renderers/base_volume_renderer.py @@ -0,0 +1,19 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nemo.collections.multimodal.modules.renderer.base_renderer import RendererBase + + +class BaseVolumeRenderer(RendererBase): + def __init__(self, bound, update_interval): + super().__init__(bound, update_interval) diff --git a/nemo/collections/multimodal/modules/nerf/renderers/nerfacc_volume_renderer.py b/nemo/collections/multimodal/modules/nerf/renderers/nerfacc_volume_renderer.py new file mode 100644 index 000000000000..3bf74b8fa5dd --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/renderers/nerfacc_volume_renderer.py @@ -0,0 +1,376 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import collections +from typing import Optional + +import torch +from nerfacc.estimators.occ_grid import OccGridEstimator +from nerfacc.grid import ray_aabb_intersect, traverse_grids +from nerfacc.volrend import accumulate_along_rays_, render_weight_from_density, rendering + +from nemo.collections.multimodal.modules.renderer.base_renderer import BaseRenderer + +Rays = collections.namedtuple("Rays", ("origins", "viewdirs")) + + +def namedtuple_map(fn, tup): + """Apply `fn` to each element of `tup` and cast to `tup`'s namedtuple.""" + return type(tup)(*(None if x is None else fn(x) for x in tup)) + + +def render_image_with_occgrid( + # scene + nerf: torch.nn.Module, + estimator: OccGridEstimator, + rays: Rays, + # rendering options + near_plane: float = 0.0, + far_plane: float = 1e10, + render_step_size: float = 1e-3, + render_bkgd: Optional[torch.Tensor] = None, + cone_angle: float = 0.0, + alpha_thre: float = 0.0, + # test options + test_chunk_size: int = 8192, +): + """Render the pixels of an image.""" + rays_shape = rays.origins.shape + if len(rays_shape) == 3: + height, width, _ = rays_shape + num_rays = height * width + rays = namedtuple_map(lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays) + else: + num_rays, _ = rays_shape + + # TODO(ahmadki): optimize, cache result between sigma_fn and rgb_sigma_fn + def sigma_fn(t_starts, t_ends, ray_indices): + t_origins = chunk_rays.origins[ray_indices] + t_dirs = chunk_rays.viewdirs[ray_indices] + positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0 + sigmas = nerf.density(positions)['sigma'] + return sigmas + + def rgb_sigma_fn(t_starts, t_ends, ray_indices): + t_origins = chunk_rays.origins[ray_indices] + t_dirs = chunk_rays.viewdirs[ray_indices] + positions = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0 + sigmas, rgbs, normal = nerf( + positions=positions, view_dirs=None, light_dirs=t_dirs + ) # TODO(ahmadki): t_dirs is incorrect + return rgbs, sigmas + + results = [] + chunk = torch.iinfo(torch.int32).max if nerf.training else test_chunk_size + + for i in range(0, num_rays, chunk): + chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays) + ray_indices, t_starts, t_ends = estimator.sampling( + chunk_rays.origins, + chunk_rays.viewdirs, + sigma_fn=sigma_fn, + near_plane=near_plane, + far_plane=far_plane, + render_step_size=render_step_size, + stratified=nerf.training, + cone_angle=cone_angle, + alpha_thre=alpha_thre, + ) + rgb, opacity, depth, extras = rendering( + t_starts, + t_ends, + ray_indices, + n_rays=chunk_rays.origins.shape[0], + rgb_sigma_fn=rgb_sigma_fn, + render_bkgd=render_bkgd, + ) + + weight = extras["weights"] + alpha = extras["alphas"] + + chunk_results = [rgb, opacity, depth, weight, alpha, len(t_starts)] + results.append(chunk_results) + + colors, opacities, depths, weights, alphas, n_rendering_samples = [ + torch.cat(r, dim=0) if isinstance(r[0], torch.Tensor) else r for r in zip(*results) + ] + + return ( + colors.view((*rays_shape[:-1], -1)), + opacities.view((*rays_shape[:-1], -1)), + depths.view((*rays_shape[:-1], -1)), + weights, + alphas, + sum(n_rendering_samples), + ) + + +@torch.no_grad() +def render_image_with_occgrid_test( + max_samples: int, + # scene + nerf: torch.nn.Module, + estimator: OccGridEstimator, + rays: Rays, + # rendering options + near_plane: float = 0.0, + far_plane: float = 1e10, + render_step_size: float = 1e-3, + render_bkgd: Optional[torch.Tensor] = None, + cone_angle: float = 0.0, + alpha_thre: float = 0.0, + early_stop_eps: float = 1e-4, +): + """Render the pixels of an image.""" + rays_shape = rays.origins.shape + if len(rays_shape) == 3: + height, width, _ = rays_shape + num_rays = height * width + rays = namedtuple_map(lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays) + else: + num_rays, _ = rays_shape + + def rgb_sigma_fn(t_starts, t_ends, ray_indices): + t_origins = rays.origins[ray_indices] + t_dirs = rays.viewdirs[ray_indices] + positions = t_origins + t_dirs * (t_starts[:, None] + t_ends[:, None]) / 2.0 + sigmas, rgbs, normal = nerf( + positions=positions, view_dirs=None, light_dirs=t_dirs + ) # TODO(ahmadki): t_dirs is incorrect ? + return rgbs, sigmas + + device = rays.origins.device + opacity = torch.zeros(num_rays, 1, device=device) + depth = torch.zeros(num_rays, 1, device=device) + rgb = torch.zeros(num_rays, 3, device=device) + + ray_mask = torch.ones(num_rays, device=device).bool() + + # 1 for synthetic scenes, 4 for real scenes + min_samples = 1 if cone_angle == 0 else 4 + + iter_samples = total_samples = 0 + + rays_o = rays.origins + rays_d = rays.viewdirs + + near_planes = torch.full_like(rays_o[..., 0], fill_value=near_plane) + far_planes = torch.full_like(rays_o[..., 0], fill_value=far_plane) + + t_mins, t_maxs, hits = ray_aabb_intersect(rays_o, rays_d, estimator.aabbs) + + n_grids = estimator.binaries.size(0) + + if n_grids > 1: + t_sorted, t_indices = torch.sort(torch.cat([t_mins, t_maxs], -1), -1) + else: + t_sorted = torch.cat([t_mins, t_maxs], -1) + t_indices = torch.arange(0, n_grids * 2, device=t_mins.device, dtype=torch.int64).expand(num_rays, n_grids * 2) + + opc_thre = 1 - early_stop_eps + + while iter_samples < max_samples: + + n_alive = ray_mask.sum().item() + if n_alive == 0: + break + + # the number of samples to add on each ray + n_samples = max(min(num_rays // n_alive, 64), min_samples) + iter_samples += n_samples + + # ray marching + (intervals, samples, termination_planes) = traverse_grids( + # rays + rays_o, # [n_rays, 3] + rays_d, # [n_rays, 3] + # grids + estimator.binaries, # [m, resx, resy, resz] + estimator.aabbs, # [m, 6] + # options + near_planes, # [n_rays] + far_planes, # [n_rays] + render_step_size, + cone_angle, + n_samples, + True, + ray_mask, + # pre-compute intersections + t_sorted, # [n_rays, m*2] + t_indices, # [n_rays, m*2] + hits, # [n_rays, m] + ) + t_starts = intervals.vals[intervals.is_left] + t_ends = intervals.vals[intervals.is_right] + ray_indices = samples.ray_indices[samples.is_valid] + packed_info = samples.packed_info + + # get rgb and sigma from radiance field + rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices) + # volume rendering using native cuda scan + weights, _, alphas = render_weight_from_density( + t_starts, + t_ends, + sigmas, + ray_indices=ray_indices, + n_rays=num_rays, + prefix_trans=1 - opacity[ray_indices].squeeze(-1), + ) + if alpha_thre > 0: + vis_mask = alphas >= alpha_thre + ray_indices, rgbs, weights, t_starts, t_ends = ( + ray_indices[vis_mask], + rgbs[vis_mask], + weights[vis_mask], + t_starts[vis_mask], + t_ends[vis_mask], + ) + + accumulate_along_rays_( + weights, values=rgbs, ray_indices=ray_indices, outputs=rgb, + ) + accumulate_along_rays_( + weights, values=None, ray_indices=ray_indices, outputs=opacity, + ) + accumulate_along_rays_( + weights, values=(t_starts + t_ends)[..., None] / 2.0, ray_indices=ray_indices, outputs=depth, + ) + # update near_planes using termination planes + near_planes = termination_planes + # update rays status + ray_mask = torch.logical_and( + # early stopping + opacity.view(-1) <= opc_thre, + # remove rays that have reached the far plane + packed_info[:, 1] == n_samples, + ) + total_samples += ray_indices.shape[0] + + if render_bkgd is not None: + rgb = rgb + render_bkgd * (1.0 - opacity) + + depth = depth / opacity.clamp_min(torch.finfo(rgbs.dtype).eps) + + return ( + rgb.view((*rays_shape[:-1], -1)), + opacity.view((*rays_shape[:-1], -1)), + depth.view((*rays_shape[:-1], -1)), + weights, + alphas, + total_samples, + ) + + +class NerfaccVolumeBaseRenderer(BaseRenderer): + def __init__( + self, + bound, + grid_resolution, + grid_levels, + render_step_size=1e-3, + near_plane=0.2, + cone_angle=0.004, + alpha_thre=1e-2, + ): + + super().__init__(bound) + + self.grid_resolution = grid_resolution + self.grid_levels = grid_levels + self.render_step_size = render_step_size + self.near_plane = near_plane + self.cone_angle = cone_angle + self.alpha_thre = alpha_thre + self.nerf = None + + self.estimator = OccGridEstimator(roi_aabb=self.aabb, resolution=self.grid_resolution, levels=self.grid_levels) + + @torch.no_grad() # TODO(ahmadki) + def update_step( + self, + epoch: int, + global_step: int, + update_interval: int = 16, + decay: float = 0.95, + occ_thre: float = 0.01, + warmup_steps: int = 256, + **kwargs + ): + def occ_eval_fn(x): + density = self.nerf.forward_density(x) + return density * self.render_step_size + + self.estimator.update_every_n_steps( + step=global_step, + occ_eval_fn=occ_eval_fn, + occ_thre=occ_thre, + ema_decay=decay, + warmup_steps=warmup_steps, + n=update_interval, + ) + + def forward(self, rays_o, rays_d, mvp, h, w, staged=False, max_ray_batch=4096, step=None, **kwargs): + return self._render(rays_o=rays_o, rays_d=rays_d, step=step, **kwargs) + + def _render( + self, + rays_o, + rays_d, + light_d=None, + ambient_ratio=1.0, + shading='albedo', + bg_color=None, + perturb=False, + T_thresh=1e-4, + binarize=False, + step=None, + **kwargs + ): + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # N = B * N, in fact + + rays = Rays(origins=rays_o, viewdirs=rays_d) + + if self.training: + rgb, acc, depth, weights, alphas, n_rendering_samples = render_image_with_occgrid( + nerf=self.nerf, + estimator=self.estimator, + rays=rays, + near_plane=self.near_plane, + render_step_size=self.render_step_size, + render_bkgd=bg_color, + cone_angle=self.cone_angle, + alpha_thre=self.alpha_thre, + ) + else: + rgb, acc, depth, weights, alphas, n_rendering_samples = render_image_with_occgrid_test( + max_samples=1024, + nerf=self.nerf, + estimator=self.estimator, + rays=rays, + near_plane=self.near_plane, + render_step_size=self.render_step_size, + render_bkgd=bg_color, + cone_angle=self.cone_angle, + alpha_thre=self.alpha_thre, + ) + + results = {} + results['weights'] = weights + results['image'] = rgb.view(1, -1, 3) + results['depth'] = depth.view(1, -1) + results['weights_sum'] = acc.view(1, -1) + + return results diff --git a/nemo/collections/multimodal/modules/nerf/renderers/nvdiffrast_renderer.py b/nemo/collections/multimodal/modules/nerf/renderers/nvdiffrast_renderer.py new file mode 100644 index 000000000000..ef8472c9211c --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/renderers/nvdiffrast_renderer.py @@ -0,0 +1,235 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import numpy as np +import nvdiffrast.torch as dr +import torch +import torch.nn.functional as F + +from nemo.collections.multimodal.modules.nerf.geometry.dmtet import DeepMarchingTetrahedra +from nemo.collections.multimodal.modules.nerf.geometry.nerf_base import DensityActivationEnum +from nemo.collections.multimodal.modules.nerf.renderers.base_renderer import BaseRenderer + + +# TODO: self.density_thresh, self.mean_density need a rework, they can be infered at run time +# and shouldn't be loaded from the checkpoint +class NVDiffRastRenderer(BaseRenderer): + def __init__(self, bound, update_interval, grid_resolution, density_thresh, quartet_file): + + super().__init__(bound, update_interval) + + self.grid_resolution = grid_resolution + self.density_thresh = density_thresh + self.quartet_file = quartet_file + + self.cascade = 1 + math.ceil(math.log2(bound)) + density_grid = torch.zeros([self.cascade, self.grid_resolution ** 3]) # [CAS, H * H * H] + density_bitfield = torch.zeros( + self.cascade * self.grid_resolution ** 3 // 8, dtype=torch.uint8 + ) # [CAS * H * H * H // 8] + self.register_buffer('density_grid', density_grid) + self.register_buffer('density_bitfield', density_bitfield) + self.mean_density = 0 + self.iter_density = 0 + + # load dmtet vertices + # TODO(ahmadki): hard coded devices + tets = np.load(quartet_file) + self.verts = -torch.tensor(tets['vertices'], dtype=torch.float32, device='cuda') * 2 # covers [-1, 1] + self.indices = torch.tensor(tets['indices'], dtype=torch.long, device='cuda') + self.tet_scale = torch.tensor([1, 1, 1], dtype=torch.float32, device='cuda') + self.dmtet = DeepMarchingTetrahedra(device='cuda') + + # vert sdf and deform + sdf = torch.nn.Parameter(torch.zeros_like(self.verts[..., 0]), requires_grad=True) + self.register_parameter('sdf', sdf) + deform = torch.nn.Parameter(torch.zeros_like(self.verts), requires_grad=True) + self.register_parameter('deform', deform) + + edges = torch.tensor( + [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device="cuda" + ) # six edges for each tetrahedron. + all_edges = self.indices[:, edges].reshape(-1, 2) # [M * 6, 2] + all_edges_sorted = torch.sort(all_edges, dim=1)[0] + self.all_edges = torch.unique(all_edges_sorted, dim=0) + + self.initialized = False # TODO(ahmadki): not a good approach + + self.glctx = dr.RasterizeCudaContext() + + # TODO(ahmadki): not a good approach + self.nerf = None + self.material = None + self.background = None + + # TODO(ahmkadi): doesn't look good to me !! + @torch.no_grad() + def update_step(self, epoch: int, global_step: int, decay: float = 0.95, S: int = 128, **kwargs): + pass + + @torch.no_grad() + def init_tet(self): + # TODO(ahmadki): a better approach would be to have a global nerf representation (mesh) that + # we can init the tets from. this would work with checkpoints. + + # TODO(ahmadki): a placeholder, but it works for now + self.mean_density = 300 + density_thresh = min(self.mean_density, self.density_thresh) + + if self.nerf.density_activation == DensityActivationEnum.SOFTPLUS: + density_thresh = density_thresh * 25 + + # Get initial sigma + sigma = self.nerf.forward_density(positions=self.verts) + mask = sigma > density_thresh + valid_verts = self.verts[mask] + self.tet_scale = valid_verts.abs().amax(dim=0) + 1e-1 + + # Scale vertices + self.verts = self.verts * self.tet_scale + + # get sigma using the scaled vertices + sigma = self.nerf.forward_density(positions=self.verts) + self.sdf.data += (sigma - density_thresh).clamp(-1, 1) + + def forward( + self, + rays_o, + rays_d, + mvp, + light_d=None, + ambient_ratio=1.0, + shading_type=None, + return_normal_image=False, + return_vertices=False, + return_faces=False, + return_faces_normals=False, + **kwargs + ): + if not self.initialized: + self.init_tet() + self.initialized = True + return self._render( + rays_o=rays_o, + rays_d=rays_d, + mvp=mvp, + light_d=light_d, + ambient_ratio=ambient_ratio, + shading_type=shading_type, + return_normal_image=return_normal_image, + return_vertices=return_vertices, + return_faces=return_faces, + return_faces_normals=return_faces_normals, + **kwargs + ) + + def _render( + self, + rays_o, + rays_d, + mvp, + light_d=None, + ambient_ratio=1.0, + shading_type=None, + return_normal_image=False, + return_vertices=False, + return_faces=False, + return_faces_normals=False, + **kwargs + ): + # mvp: [B, 4, 4] + B, H, W, _ = rays_o.shape + + # TODO(ahmadki): move to dataset + # random sample light_d if not provided + if light_d is None: + # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face) + light_d = rays_o + torch.randn(3, device=rays_o.device) + light_d = F.normalize(light_d) + + results = {} + + # get mesh + deform = torch.tanh(self.deform) / self.grid_resolution + + verts, faces = self.dmtet(self.verts + deform, self.sdf, self.indices) + + # get normals + i0, i1, i2 = faces[:, 0], faces[:, 1], faces[:, 2] + v0, v1, v2 = verts[i0, :], verts[i1, :], verts[i2, :] + + faces = faces.int() + + face_normals = torch.cross(v1 - v0, v2 - v0) + face_normals = F.normalize(face_normals) + + vn = torch.zeros_like(verts) + vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) + vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) + vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) + + vn = torch.where( + torch.sum(vn * vn, -1, keepdim=True) > 1e-20, + vn, + torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device), + ) + + # rasterization + verts_clip = torch.bmm( + F.pad(verts, pad=(0, 1), mode='constant', value=1.0).unsqueeze(0).repeat(mvp.shape[0], 1, 1), + mvp.permute(0, 2, 1), + ).float() # [B, N, 4] + rast, _ = dr.rasterize(self.glctx, verts_clip, faces, (H, W)) + + alpha = (rast[..., 3:] > 0).float() + xyzs, _ = dr.interpolate(verts.unsqueeze(0), rast, faces) # [B, H, W, 3] + normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, faces) + normal = F.normalize(normal) + + xyzs = xyzs.view(-1, 3) + mask = (rast[..., 3:] > 0).view(-1).detach() + + # do the lighting here since we have normal from mesh now. + albedo = torch.zeros_like(xyzs, dtype=torch.float32) + if mask.any(): + masked_albedo = self.nerf.forward_features(positions=xyzs[mask]) + albedo[mask] = masked_albedo.float() + albedo = albedo.view(B, H, W, 3) + fg_color = self.material( + albedo=albedo, normals=normal, light_d=light_d, ambient_ratio=ambient_ratio, shading_type=shading_type + ) + + fg_color = dr.antialias(fg_color, rast, verts_clip, faces).clamp(0, 1) # [B, H, W, 3] + alpha = dr.antialias(alpha, rast, verts_clip, faces).clamp(0, 1) # [B, H, W, 1] + + # mix background color + bg_color = self.background(rays_d=rays_d) # [N, 3] + + depth = rast[:, :, :, [2]] # [B, H, W] + color = fg_color + (1 - alpha) * bg_color + + results['depth'] = depth + results['image'] = color + if return_normal_image: + results['normal_image'] = dr.antialias((normal + 1) / 2, rast, verts_clip, faces).clamp( + 0, 1 + ) # [B, H, W, 3] + if return_vertices: + results['vertices'] = verts + if return_faces: + results['faces'] = faces + if return_faces_normals: + results['face_normals'] = face_normals + return results diff --git a/nemo/collections/multimodal/modules/nerf/renderers/torchngp_volume_renderer.py b/nemo/collections/multimodal/modules/nerf/renderers/torchngp_volume_renderer.py new file mode 100644 index 000000000000..da66f578ed74 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/renderers/torchngp_volume_renderer.py @@ -0,0 +1,288 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import torch +import torch.nn.functional as F + +import nemo.collections.multimodal.modules.nerf.utils.torch_ngp.raymarching as raymarching +from nemo.collections.multimodal.modules.nerf.materials.materials_base import ShadingEnum +from nemo.collections.multimodal.modules.nerf.renderers.base_renderer import BaseRenderer + + +class TorchNGPVolumeRenderer(BaseRenderer): + def __init__(self, bound, update_interval, grid_resolution, density_thresh, max_steps, dt_gamma): + + super().__init__(bound, update_interval) + + self.cascade = 1 + math.ceil(math.log2(bound)) + self.grid_resolution = grid_resolution + self.density_thresh = density_thresh + self.dt_gamma = dt_gamma + self.max_steps = max_steps + + # density grid + # TODO(ahmadki): needs rework + density_grid = torch.zeros([self.cascade, self.grid_resolution ** 3]) # [CAS, H * H * H] + density_bitfield = torch.zeros( + self.cascade * self.grid_resolution ** 3 // 8, dtype=torch.uint8 + ) # [CAS * H * H * H // 8] + self.register_buffer('density_grid', density_grid) + self.register_buffer('density_bitfield', density_bitfield) + self.mean_density = 0 + self.iter_density = 0 + + # TODO(ahmadki): needs rework + self.nerf = None + self.material = None + self.background = None + + @torch.no_grad() + def update_step(self, epoch: int, global_step: int, decay: float = 0.95, S: int = 128, **kwargs): + if global_step % self.update_interval != 0: + return + + ### update density grid + tmp_grid = -torch.ones_like(self.density_grid) + + X = torch.arange(self.grid_resolution, dtype=torch.int32, device=self.aabb.device).split(S) + Y = torch.arange(self.grid_resolution, dtype=torch.int32, device=self.aabb.device).split(S) + Z = torch.arange(self.grid_resolution, dtype=torch.int32, device=self.aabb.device).split(S) + + for xs in X: + for ys in Y: + for zs in Z: + + # construct points + xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij') + coords = torch.cat( + [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1 + ) # [N, 3], in [0, 128) + indices = raymarching.morton3D(coords).long() # [N] + xyzs = 2 * coords.float() / (self.grid_resolution - 1) - 1 # [N, 3] in [-1, 1] + + # cascading + for cas in range(self.cascade): + bound = min(2 ** cas, self.bound) + half_grid_resolution = bound / self.grid_resolution + # scale to current cascade's resolution + cas_xyzs = xyzs * (bound - half_grid_resolution) + # add noise in [-hgs, hgs] + cas_xyzs += (torch.rand_like(cas_xyzs) * 2 - 1) * half_grid_resolution + # query density + density = self.nerf.forward_density(cas_xyzs).reshape(-1).detach() + # assign + tmp_grid[cas, indices] = density + # ema update + valid_mask = self.density_grid >= 0 + self.density_grid[valid_mask] = torch.maximum(self.density_grid[valid_mask] * decay, tmp_grid[valid_mask]) + self.mean_density = torch.mean(self.density_grid[valid_mask]).item() + self.iter_density += 1 + + # convert to bitfield + density_thresh = min(self.mean_density, self.density_thresh) + self.density_bitfield = raymarching.packbits(self.density_grid, density_thresh, self.density_bitfield) + + def forward( + self, + rays_o, + rays_d, + light_d=None, + ambient_ratio=1.0, + shading_type=None, + return_normal_image=False, + return_normal_perturb=False, + **kwargs + ): + return self._render( + rays_o=rays_o, + rays_d=rays_d, + light_d=light_d, + ambient_ratio=ambient_ratio, + shading_type=shading_type, + return_normal_image=return_normal_image, + return_normal_perturb=return_normal_perturb, + **kwargs + ) + + # TODO(ahmadki): return_normal_image is always False ? + def _render( + self, + rays_o, + rays_d, + light_d=None, + ambient_ratio=1.0, + shading_type=None, + return_normal_image=False, + return_normal_perturb=False, + perturb=False, + T_thresh=1e-4, + binarize=False, + **kwargs + ): + # rays_o, rays_d: [B, H, W, 3] + B, H, W, _ = rays_o.shape + + # group all rays into a single batch + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + num_rays = rays_o.shape[0] # num_rays = B * H * W + + # pre-calculate near far + nears, fars = raymarching.near_far_from_aabb(rays_o, rays_d, self.aabb) + + # random sample light_d if not provided + # TODO(ahmadki): move to dataset + if light_d is None: + # gaussian noise around the ray origin, so the light always face the view dir (avoid dark face) + light_d = rays_o + torch.randn(3, device=rays_o.device) + light_d = F.normalize(light_d) + + normal_image = None + normals_perturb = None + weights = None + + if self.training: + positions, dirs, ts, rays = raymarching.march_rays_train( + rays_o, + rays_d, + self.bound, + self.density_bitfield, + self.cascade, + self.grid_resolution, + nears, + fars, + perturb, + self.dt_gamma, + self.max_steps, + ) + dirs = F.normalize(dirs) + + if light_d.shape[0] > 1: + flatten_rays = raymarching.flatten_rays(rays, positions.shape[0]).long() + light_d = light_d[flatten_rays] + + return_normal = (shading_type is not None) or return_normal_image + sigmas, albedo, normals = self.nerf(positions=positions, return_normal=return_normal) + + fg_color = self.material( + albedo=albedo, normals=normals, light_d=light_d, ambient_ratio=ambient_ratio, shading_type=shading_type + ) + + weights, opacity, depth, image = raymarching.composite_rays_train( + sigmas, fg_color, ts, rays, T_thresh, binarize + ) + + if return_normal_image and normals is not None: + _, _, _, normal_image = raymarching.composite_rays_train( + sigmas.detach(), (normals + 1) / 2, ts, rays, T_thresh, binarize + ) + + if return_normal_perturb: + perturb_positions = positions + torch.randn_like(positions) * 1e-2 + normals_perturb = self.normal(positions=perturb_positions) + + else: + # allocate tensors + image = torch.zeros(num_rays, 3, device=rays_o.device) + depth = torch.zeros(num_rays, device=rays_o.device) + opacity = torch.zeros(num_rays, device=rays_o.device) + + n_alive = num_rays + rays_alive = torch.arange(n_alive, dtype=torch.int32, device=rays_o.device) + rays_t = nears.clone() + + step = 0 + + while step < self.max_steps: # hard coded max step + # count alive rays + n_alive = rays_alive.shape[0] + + # exit loop + if n_alive <= 0: + break + + # decide compact_steps + n_step = max(min(num_rays // n_alive, 8), 1) + + positions, dirs, ts = raymarching.march_rays( + n_alive, + n_step, + rays_alive, + rays_t, + rays_o, + rays_d, + self.bound, + self.density_bitfield, + self.cascade, + self.grid_resolution, + nears, + fars, + perturb if step == 0 else False, + self.dt_gamma, + self.max_steps, + ) + dirs = F.normalize(dirs) + + return_normal = shading_type not in [None, ShadingEnum.TEXTURELESS] + sigmas, albedo, normals = self.nerf(positions=positions, return_normal=return_normal) + + fg_color = self.material( + albedo=albedo, + normals=normals, + light_d=light_d, + ambient_ratio=ambient_ratio, + shading_type=shading_type, + ) + raymarching.composite_rays( + n_alive, + n_step, + rays_alive, + rays_t, + sigmas, + fg_color, + ts, + opacity, + depth, + image, + T_thresh, + binarize, + ) + + # TODO(ahmadki): add optoin to return normal_image, like in training + + rays_alive = rays_alive[rays_alive >= 0] + + step += n_step + + # mix background color + bg_color = self.background(rays_d) # [N, 3] + image = image + (1 - opacity).unsqueeze(-1) * bg_color + + results = { + "image": image.view(B, H, W, 3), + "depth": depth.view(B, H, W, 1), + "opacity": opacity.view(B, H, W, 1), + "dirs": dirs, + } + if normals is not None: + results["normals"] = normals + if weights is not None: + results["weights"] = weights + if normal_image is not None: + results["normal_image"] = normal_image.view(B, H, W, 3) + if normals_perturb is not None: + results["normal_perturb"] = normals_perturb + + return results diff --git a/nemo/collections/multimodal/modules/nerf/utils/__init__.py b/nemo/collections/multimodal/modules/nerf/utils/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/modules/nerf/utils/activation.py b/nemo/collections/multimodal/modules/nerf/utils/activation.py new file mode 100644 index 000000000000..1b79676d57c6 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/utils/activation.py @@ -0,0 +1,33 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + + +class _trunc_exp(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float) + def forward(ctx, x): + ctx.save_for_backward(x) + return torch.exp(x) + + @staticmethod + @custom_bwd + def backward(ctx, g): + x = ctx.saved_tensors[0] + return g * torch.exp(x.clamp(max=15)) + + +trunc_exp = _trunc_exp.apply diff --git a/nemo/collections/multimodal/modules/nerf/utils/torch_ngp/__init__.py b/nemo/collections/multimodal/modules/nerf/utils/torch_ngp/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/utils/torch_ngp/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/modules/nerf/utils/torch_ngp/encoding.py b/nemo/collections/multimodal/modules/nerf/utils/torch_ngp/encoding.py new file mode 100644 index 000000000000..59c6caa033c2 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/utils/torch_ngp/encoding.py @@ -0,0 +1,149 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn + + +class FreqEncoder_torch(nn.Module): + def __init__( + self, + input_dim, + max_freq_log2, + N_freqs, + log_sampling=True, + include_input=True, + periodic_fns=(torch.sin, torch.cos), + ): + + super().__init__() + + self.input_dim = input_dim + self.include_input = include_input + self.periodic_fns = periodic_fns + self.N_freqs = N_freqs + + self.output_dim = 0 + if self.include_input: + self.output_dim += self.input_dim + + self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns) + + if log_sampling: + self.freq_bands = 2 ** torch.linspace(0, max_freq_log2, N_freqs) + else: + self.freq_bands = torch.linspace(2 ** 0, 2 ** max_freq_log2, N_freqs) + + self.freq_bands = self.freq_bands.numpy().tolist() + + def forward(self, input, max_level=None, **kwargs): + + if max_level is None: + max_level = self.N_freqs + else: + max_level = int(max_level * self.N_freqs) + + out = [] + if self.include_input: + out.append(input) + + for i in range(max_level): + freq = self.freq_bands[i] + for p_fn in self.periodic_fns: + out.append(p_fn(input * freq)) + + # append 0 + if self.N_freqs - max_level > 0: + out.append( + torch.zeros( + input.shape[0], + (self.N_freqs - max_level) * 2 * input.shape[1], + device=input.device, + dtype=input.dtype, + ) + ) + + out = torch.cat(out, dim=-1) + + return out + + +def get_encoder( + encoder_type, + input_dim=3, + multires=6, + degree=4, + num_levels=16, + level_dim=2, + base_resolution=16, + log2_hashmap_size=19, + desired_resolution=2048, + align_corners=False, + interpolation='linear', + **kwargs +): + + if encoder_type is None: + return lambda x, **kwargs: x, input_dim + + elif encoder_type == 'frequency_torch': + encoder = FreqEncoder_torch( + input_dim=input_dim, max_freq_log2=multires - 1, N_freqs=multires, log_sampling=True + ) + + elif encoder_type == 'frequency': # CUDA implementation, faster than torch. + from nemo.collections.multimodal.modules.nerf.utils.torch_ngp.freqencoder import FreqEncoder + + encoder = FreqEncoder(input_dim=input_dim, degree=multires) + + elif encoder_type == 'sphere_harmonics': + from nemo.collections.multimodal.modules.nerf.utils.torch_ngp.shencoder import SHEncoder + + encoder = SHEncoder(input_dim=input_dim, degree=degree) + + elif encoder_type == 'hashgrid': + from nemo.collections.multimodal.modules.nerf.utils.torch_ngp.gridencoder import GridEncoder + + encoder = GridEncoder( + input_dim=input_dim, + num_levels=num_levels, + level_dim=level_dim, + base_resolution=base_resolution, + log2_hashmap_size=log2_hashmap_size, + desired_resolution=desired_resolution, + gridtype='hash', + align_corners=align_corners, + interpolation=interpolation, + ) + + elif encoder_type == 'tiledgrid': + from nemo.collections.multimodal.modules.nerf.utils.torch_ngp.gridencoder import GridEncoder + + encoder = GridEncoder( + input_dim=input_dim, + num_levels=num_levels, + level_dim=level_dim, + base_resolution=base_resolution, + log2_hashmap_size=log2_hashmap_size, + desired_resolution=desired_resolution, + gridtype='tiled', + align_corners=align_corners, + interpolation=interpolation, + ) + + else: + raise NotImplementedError( + 'Unknown encoder type, choose from [None, frequency, sphere_harmonics, hashgrid, tiledgrid]' + ) + + return encoder, encoder.output_dim diff --git a/nemo/collections/multimodal/modules/nerf/utils/torch_ngp/freqencoder.py b/nemo/collections/multimodal/modules/nerf/utils/torch_ngp/freqencoder.py new file mode 100644 index 000000000000..f426174408f3 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/utils/torch_ngp/freqencoder.py @@ -0,0 +1,84 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import _freqencoder as _backend +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + + +class _freq_encoder(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision + def forward(ctx, inputs, degree, output_dim): + # inputs: [B, input_dim], float + # RETURN: [B, F], float + + if not inputs.is_cuda: + inputs = inputs.cuda() + inputs = inputs.contiguous() + + B, input_dim = inputs.shape # batch size, coord dim + + outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) + + _backend.freq_encode_forward(inputs, B, input_dim, degree, output_dim, outputs) + + ctx.save_for_backward(inputs, outputs) + ctx.dims = [B, input_dim, degree, output_dim] + + return outputs + + @staticmethod + # @once_differentiable + @custom_bwd + def backward(ctx, grad): + # grad: [B, C * C] + + grad = grad.contiguous() + inputs, outputs = ctx.saved_tensors + B, input_dim, degree, output_dim = ctx.dims + + grad_inputs = torch.zeros_like(inputs) + _backend.freq_encode_backward(grad, outputs, B, input_dim, degree, output_dim, grad_inputs) + + return grad_inputs, None, None + + +freq_encode = _freq_encoder.apply + + +class FreqEncoder(nn.Module): + def __init__(self, input_dim=3, degree=4): + super().__init__() + + self.input_dim = input_dim + self.degree = degree + self.output_dim = input_dim + input_dim * 2 * degree + + def __repr__(self): + return f"FreqEncoder: input_dim={self.input_dim} degree={self.degree} output_dim={self.output_dim}" + + def forward(self, inputs, **kwargs): + # inputs: [..., input_dim] + # return: [..., ] + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.reshape(-1, self.input_dim) + + outputs = freq_encode(inputs, self.degree, self.output_dim) + + outputs = outputs.reshape(prefix_shape + [self.output_dim]) + + return outputs diff --git a/nemo/collections/multimodal/modules/nerf/utils/torch_ngp/gridencoder.py b/nemo/collections/multimodal/modules/nerf/utils/torch_ngp/gridencoder.py new file mode 100644 index 000000000000..be173fb1e98f --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/utils/torch_ngp/gridencoder.py @@ -0,0 +1,299 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import _gridencoder as _backend +import numpy as np +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + +_gridtype_to_id = { + 'hash': 0, + 'tiled': 1, +} + +_interp_to_id = { + 'linear': 0, + 'smoothstep': 1, +} + + +class _grid_encode(Function): + @staticmethod + @custom_fwd + def forward( + ctx, + inputs, + embeddings, + offsets, + per_level_scale, + base_resolution, + calc_grad_inputs=False, + gridtype=0, + align_corners=False, + interpolation=0, + max_level=None, + ): + # inputs: [B, D], float in [0, 1] + # embeddings: [sO, C], float + # offsets: [L + 1], int + # RETURN: [B, F], float + + inputs = inputs.contiguous() + + B, D = inputs.shape # batch size, coord dim + L = offsets.shape[0] - 1 # level + C = embeddings.shape[1] # embedding dim for each level + S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f + H = base_resolution # base resolution + + max_level = L if max_level is None else max(min(int(math.ceil(max_level * L)), L), 1) + + # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision) + # if C % 2 != 0, force float, since half for atomicAdd is very slow. + if torch.is_autocast_enabled() and C % 2 == 0: + embeddings = embeddings.to(torch.half) + + # L first, optimize cache for cuda kernel, but needs an extra permute later + outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype) + + # zero init if we only calculate partial levels + if max_level < L: + outputs.zero_() + + if calc_grad_inputs: + dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype) + if max_level < L: + dy_dx.zero_() + else: + dy_dx = None + + _backend.grid_encode_forward( + inputs, + embeddings, + offsets, + outputs, + B, + D, + C, + L, + max_level, + S, + H, + dy_dx, + gridtype, + align_corners, + interpolation, + ) + + # permute back to [B, L * C] + outputs = outputs.permute(1, 0, 2).reshape(B, L * C) + + ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) + ctx.dims = [B, D, C, L, S, H, gridtype, interpolation, max_level] + ctx.align_corners = align_corners + + return outputs + + @staticmethod + # @once_differentiable + @custom_bwd + def backward(ctx, grad): + + inputs, embeddings, offsets, dy_dx = ctx.saved_tensors + B, D, C, L, S, H, gridtype, interpolation, max_level = ctx.dims + align_corners = ctx.align_corners + + # grad: [B, L * C] --> [L, B, C] + grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() + + grad_embeddings = torch.zeros_like(embeddings) + + if dy_dx is not None: + grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) + else: + grad_inputs = None + + _backend.grid_encode_backward( + grad, + inputs, + embeddings, + offsets, + grad_embeddings, + B, + D, + C, + L, + max_level, + S, + H, + dy_dx, + grad_inputs, + gridtype, + align_corners, + interpolation, + ) + + if dy_dx is not None: + grad_inputs = grad_inputs.to(inputs.dtype) + + return grad_inputs, grad_embeddings, None, None, None, None, None, None, None, None + + +grid_encode = _grid_encode.apply + + +class GridEncoder(nn.Module): + def __init__( + self, + input_dim=3, + num_levels=16, + level_dim=2, + per_level_scale=2, + base_resolution=16, + log2_hashmap_size=19, + desired_resolution=None, + gridtype='hash', + align_corners=False, + interpolation='linear', + ): + super().__init__() + + # the finest resolution desired at the last level, if provided, overridee per_level_scale + if desired_resolution is not None: + per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1)) + + self.input_dim = input_dim # coord dims, 2 or 3 + self.num_levels = num_levels # num levels, each level multiply resolution by 2 + self.level_dim = level_dim # encode channels per level + self.per_level_scale = per_level_scale # multiply resolution by this scale at each level. + self.log2_hashmap_size = log2_hashmap_size + self.base_resolution = base_resolution + self.output_dim = num_levels * level_dim + self.gridtype = gridtype + self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" + self.interpolation = interpolation + self.interp_id = _interp_to_id[interpolation] # "linear" or "smoothstep" + self.align_corners = align_corners + + # allocate parameters + offsets = [] + offset = 0 + self.max_params = 2 ** log2_hashmap_size + for i in range(num_levels): + resolution = int(np.ceil(base_resolution * per_level_scale ** i)) + params_in_level = min(self.max_params, (resolution) ** input_dim) # limit max number + params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible + offsets.append(offset) + offset += params_in_level + offsets.append(offset) + offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) + self.register_buffer('offsets', offsets) + + self.n_params = offsets[-1] * level_dim + + # parameters + self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) + + self.reset_parameters() + + def reset_parameters(self): + std = 1e-4 + self.embeddings.data.uniform_(-std, std) + + def __repr__(self): + return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}" + + def forward(self, inputs, bound=1, max_level=None): + # inputs: [..., input_dim], normalized real world positions in [-bound, bound] + # max_level: only calculate first max_level levels (None will use all levels) + # return: [..., num_levels * level_dim] + + inputs = (inputs + bound) / (2 * bound) # map to [0, 1] + + # print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.view(-1, self.input_dim) + + outputs = grid_encode( + inputs, + self.embeddings, + self.offsets, + self.per_level_scale, + self.base_resolution, + inputs.requires_grad, + self.gridtype_id, + self.align_corners, + self.interp_id, + max_level, + ) + outputs = outputs.view(prefix_shape + [self.output_dim]) + + # print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) + + return outputs + + # always run in float precision! + @torch.cuda.amp.autocast(enabled=False) + def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000): + # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss. + + D = self.input_dim + C = self.embeddings.shape[1] # embedding dim for each level + L = self.offsets.shape[0] - 1 # level + S = np.log2(self.per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f + H = self.base_resolution # base resolution + + if inputs is None: + # randomized in [0, 1] + inputs = torch.rand(B, self.input_dim, device=self.embeddings.device) + else: + inputs = (inputs + bound) / (2 * bound) # map to [0, 1] + inputs = inputs.view(-1, self.input_dim) + B = inputs.shape[0] + + if self.embeddings.grad is None: + raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!') + + _backend.grad_total_variation( + inputs, + self.embeddings, + self.embeddings.grad, + self.offsets, + weight, + B, + D, + C, + L, + S, + H, + self.gridtype_id, + self.align_corners, + ) + + @torch.cuda.amp.autocast(enabled=False) + def grad_weight_decay(self, weight=0.1): + # level-wise meaned weight decay (ref: zip-nerf) + + B = self.embeddings.shape[0] # size of embedding + C = self.embeddings.shape[1] # embedding dim for each level + L = self.offsets.shape[0] - 1 # level + + if self.embeddings.grad is None: + raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!') + + _backend.grad_weight_decay(self.embeddings, self.embeddings.grad, self.offsets, weight, B, C, L) diff --git a/nemo/collections/multimodal/modules/nerf/utils/torch_ngp/raymarching.py b/nemo/collections/multimodal/modules/nerf/utils/torch_ngp/raymarching.py new file mode 100644 index 000000000000..2c414c554203 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/utils/torch_ngp/raymarching.py @@ -0,0 +1,561 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + +# lazy building: +# `import raymarching` will not immediately build the extension, only if you actually call any functions. + +BACKEND = None + + +def get_backend(): + global BACKEND + + if BACKEND is None: + try: + import _raymarching as _backend + except ImportError: + from .backend import _backend + + BACKEND = _backend + + return BACKEND + + +# ---------------------------------------- +# utils +# ---------------------------------------- + + +class _near_far_from_aabb(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, aabb, min_near=0.2): + ''' near_far_from_aabb, CUDA implementation + Calculate rays' intersection time (near and far) with aabb + Args: + rays_o: float, [N, 3] + rays_d: float, [N, 3] + aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax) + min_near: float, scalar + Returns: + nears: float, [N] + fars: float, [N] + ''' + if not rays_o.is_cuda: + rays_o = rays_o.cuda() + if not rays_d.is_cuda: + rays_d = rays_d.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # num rays + + nears = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) + fars = torch.empty(N, dtype=rays_o.dtype, device=rays_o.device) + + get_backend().near_far_from_aabb(rays_o, rays_d, aabb, N, min_near, nears, fars) + + return nears, fars + + +near_far_from_aabb = _near_far_from_aabb.apply + + +class _sph_from_ray(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, rays_o, rays_d, radius): + ''' sph_from_ray, CUDA implementation + get spherical coordinate on the background sphere from rays. + Assume rays_o are inside the Sphere(radius). + Args: + rays_o: [N, 3] + rays_d: [N, 3] + radius: scalar, float + Return: + coords: [N, 2], in [-1, 1], theta and phi on a sphere. (further-surface) + ''' + if not rays_o.is_cuda: + rays_o = rays_o.cuda() + if not rays_d.is_cuda: + rays_d = rays_d.cuda() + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # num rays + + coords = torch.empty(N, 2, dtype=rays_o.dtype, device=rays_o.device) + + get_backend().sph_from_ray(rays_o, rays_d, radius, N, coords) + + return coords + + +sph_from_ray = _sph_from_ray.apply + + +class _morton3D(Function): + @staticmethod + def forward(ctx, coords): + ''' morton3D, CUDA implementation + Args: + coords: [N, 3], int32, in [0, 128) (for some reason there is no uint32 tensor in torch...) + TODO: check if the coord range is valid! (current 128 is safe) + Returns: + indices: [N], int32, in [0, 128^3) + + ''' + if not coords.is_cuda: + coords = coords.cuda() + + N = coords.shape[0] + + indices = torch.empty(N, dtype=torch.int32, device=coords.device) + + get_backend().morton3D(coords.int(), N, indices) + + return indices + + +morton3D = _morton3D.apply + + +class _morton3D_invert(Function): + @staticmethod + def forward(ctx, indices): + ''' morton3D_invert, CUDA implementation + Args: + indices: [N], int32, in [0, 128^3) + Returns: + coords: [N, 3], int32, in [0, 128) + + ''' + if not indices.is_cuda: + indices = indices.cuda() + + N = indices.shape[0] + + coords = torch.empty(N, 3, dtype=torch.int32, device=indices.device) + + get_backend().morton3D_invert(indices.int(), N, coords) + + return coords + + +morton3D_invert = _morton3D_invert.apply + + +class _packbits(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, grid, thresh, bitfield=None): + ''' packbits, CUDA implementation + Pack up the density grid into a bit field to accelerate ray marching. + Args: + grid: float, [C, H * H * H], assume H % 2 == 0 + thresh: float, threshold + Returns: + bitfield: uint8, [C, H * H * H / 8] + ''' + if not grid.is_cuda: + grid = grid.cuda() + grid = grid.contiguous() + + C = grid.shape[0] + H3 = grid.shape[1] + N = C * H3 // 8 + + if bitfield is None: + bitfield = torch.empty(N, dtype=torch.uint8, device=grid.device) + + get_backend().packbits(grid, N, thresh, bitfield) + + return bitfield + + +packbits = _packbits.apply + + +class _flatten_rays(Function): + @staticmethod + def forward(ctx, rays, M): + ''' flatten rays + Args: + rays: [N, 2], all rays' (point_offset, point_count), + M: scalar, int, count of points (we cannot get this info from rays unfortunately...) + Returns: + res: [M], flattened ray index. + ''' + if not rays.is_cuda: + rays = rays.cuda() + rays = rays.contiguous() + + N = rays.shape[0] + + res = torch.zeros(M, dtype=torch.int, device=rays.device) + + get_backend().flatten_rays(rays, N, M, res) + + return res + + +flatten_rays = _flatten_rays.apply + +# ---------------------------------------- +# train functions +# ---------------------------------------- + + +class _march_rays_train(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward( + ctx, + rays_o, + rays_d, + bound, + density_bitfield, + C, + H, + nears, + fars, + perturb=False, + dt_gamma=0, + max_steps=1024, + contract=False, + ): + ''' march rays to generate points (forward only) + Args: + rays_o/d: float, [N, 3] + bound: float, scalar + density_bitfield: uint8: [CHHH // 8] + C: int + H: int + nears/fars: float, [N] + step_counter: int32, (2), used to count the actual number of generated points. + mean_count: int32, estimated mean steps to accelerate training. (but will randomly drop rays if the actual point count exceeded this threshold.) + perturb: bool + align: int, pad output so its size is dividable by align, set to -1 to disable. + force_all_rays: bool, ignore step_counter and mean_count, always calculate all rays. Useful if rendering the whole image, instead of some rays. + dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) + max_steps: int, max number of sampled points along each ray, also affect min_stepsize. + Returns: + xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray) + dirs: float, [M, 3], all generated points' view dirs. + ts: float, [M, 2], all generated points' ts. + rays: int32, [N, 2], all rays' (point_offset, point_count), e.g., xyzs[rays[i, 0]:(rays[i, 0] + rays[i, 1])] --> points belonging to rays[i, 0] + ''' + + if not rays_o.is_cuda: + rays_o = rays_o.cuda() + if not rays_d.is_cuda: + rays_d = rays_d.cuda() + if not density_bitfield.is_cuda: + density_bitfield = density_bitfield.cuda() + + rays_o = rays_o.float().contiguous().view(-1, 3) + rays_d = rays_d.float().contiguous().view(-1, 3) + density_bitfield = density_bitfield.contiguous() + + N = rays_o.shape[0] # num rays + + step_counter = torch.zeros(1, dtype=torch.int32, device=rays_o.device) # point counter, ray counter + + if perturb: + noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device) + else: + noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device) + + # first pass: write rays, get total number of points M to render + rays = torch.empty(N, 2, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps + get_backend().march_rays_train( + rays_o, + rays_d, + density_bitfield, + bound, + contract, + dt_gamma, + max_steps, + N, + C, + H, + nears, + fars, + None, + None, + None, + rays, + step_counter, + noises, + ) + + # allocate based on M + M = step_counter.item() + # print(M, N) + # print(rays[:, 0].max()) + + xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + ts = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) + + # second pass: write outputs + get_backend().march_rays_train( + rays_o, + rays_d, + density_bitfield, + bound, + contract, + dt_gamma, + max_steps, + N, + C, + H, + nears, + fars, + xyzs, + dirs, + ts, + rays, + step_counter, + noises, + ) + + return xyzs, dirs, ts, rays + + +march_rays_train = _march_rays_train.apply + + +class _composite_rays_train(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, sigmas, rgbs, ts, rays, T_thresh=1e-4, binarize=False): + ''' composite rays' rgbs, according to the ray marching formula. + Args: + rgbs: float, [M, 3] + sigmas: float, [M,] + ts: float, [M, 2] + rays: int32, [N, 3] + Returns: + weights: float, [M] + weights_sum: float, [N,], the alpha channel + depth: float, [N, ], the Depth + image: float, [N, 3], the RGB channel (after multiplying alpha!) + ''' + + sigmas = sigmas.float().contiguous() + rgbs = rgbs.float().contiguous() + + M = sigmas.shape[0] + N = rays.shape[0] + + weights = torch.zeros(M, dtype=sigmas.dtype, device=sigmas.device) # may leave unmodified, so init with 0 + weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + + depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) + image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device) + + get_backend().composite_rays_train_forward( + sigmas, rgbs, ts, rays, M, N, T_thresh, binarize, weights, weights_sum, depth, image + ) + + ctx.save_for_backward(sigmas, rgbs, ts, rays, weights_sum, depth, image) + ctx.dims = [M, N, T_thresh, binarize] + + return weights, weights_sum, depth, image + + @staticmethod + @custom_bwd + def backward(ctx, grad_weights, grad_weights_sum, grad_depth, grad_image): + + grad_weights = grad_weights.contiguous() + grad_weights_sum = grad_weights_sum.contiguous() + grad_depth = grad_depth.contiguous() + grad_image = grad_image.contiguous() + + sigmas, rgbs, ts, rays, weights_sum, depth, image = ctx.saved_tensors + M, N, T_thresh, binarize = ctx.dims + + grad_sigmas = torch.zeros_like(sigmas) + grad_rgbs = torch.zeros_like(rgbs) + + get_backend().composite_rays_train_backward( + grad_weights, + grad_weights_sum, + grad_depth, + grad_image, + sigmas, + rgbs, + ts, + rays, + weights_sum, + depth, + image, + M, + N, + T_thresh, + binarize, + grad_sigmas, + grad_rgbs, + ) + + return grad_sigmas, grad_rgbs, None, None, None, None + + +composite_rays_train = _composite_rays_train.apply + +# ---------------------------------------- +# infer functions +# ---------------------------------------- + + +class _march_rays(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward( + ctx, + n_alive, + n_step, + rays_alive, + rays_t, + rays_o, + rays_d, + bound, + density_bitfield, + C, + H, + near, + far, + perturb=False, + dt_gamma=0, + max_steps=1024, + contract=False, + ): + ''' march rays to generate points (forward only, for inference) + Args: + n_alive: int, number of alive rays + n_step: int, how many steps we march + rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive) + rays_t: float, [N], the alive rays' time, we only use the first n_alive. + rays_o/d: float, [N, 3] + bound: float, scalar + density_bitfield: uint8: [CHHH // 8] + C: int + H: int + nears/fars: float, [N] + align: int, pad output so its size is dividable by align, set to -1 to disable. + perturb: bool/int, int > 0 is used as the random seed. + dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance) + max_steps: int, max number of sampled points along each ray, also affect min_stepsize. + Returns: + xyzs: float, [n_alive * n_step, 3], all generated points' coords + dirs: float, [n_alive * n_step, 3], all generated points' view dirs. + ts: float, [n_alive * n_step, 2], all generated points' ts + ''' + + if not rays_o.is_cuda: + rays_o = rays_o.cuda() + if not rays_d.is_cuda: + rays_d = rays_d.cuda() + + rays_o = rays_o.float().contiguous().view(-1, 3) + rays_d = rays_d.float().contiguous().view(-1, 3) + + M = n_alive * n_step + + xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + ts = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth + + if perturb: + # torch.manual_seed(perturb) # test_gui uses spp index as seed + noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device) + else: + noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device) + + get_backend().march_rays( + n_alive, + n_step, + rays_alive, + rays_t, + rays_o, + rays_d, + bound, + contract, + dt_gamma, + max_steps, + C, + H, + density_bitfield, + near, + far, + xyzs, + dirs, + ts, + noises, + ) + + return xyzs, dirs, ts + + +march_rays = _march_rays.apply + + +class _composite_rays(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float + def forward( + ctx, + n_alive, + n_step, + rays_alive, + rays_t, + sigmas, + rgbs, + ts, + weights_sum, + depth, + image, + T_thresh=1e-2, + binarize=False, + ): + ''' composite rays' rgbs, according to the ray marching formula. (for inference) + Args: + n_alive: int, number of alive rays + n_step: int, how many steps we march + rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive) + rays_t: float, [N], the alive rays' time + sigmas: float, [n_alive * n_step,] + rgbs: float, [n_alive * n_step, 3] + ts: float, [n_alive * n_step, 2] + In-place Outputs: + weights_sum: float, [N,], the alpha channel + depth: float, [N,], the depth value + image: float, [N, 3], the RGB channel (after multiplying alpha!) + ''' + sigmas = sigmas.float().contiguous() + rgbs = rgbs.float().contiguous() + get_backend().composite_rays( + n_alive, n_step, T_thresh, binarize, rays_alive, rays_t, sigmas, rgbs, ts, weights_sum, depth, image + ) + return tuple() + + +composite_rays = _composite_rays.apply diff --git a/nemo/collections/multimodal/modules/nerf/utils/torch_ngp/shencoder.py b/nemo/collections/multimodal/modules/nerf/utils/torch_ngp/shencoder.py new file mode 100644 index 000000000000..446b58473a3e --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/utils/torch_ngp/shencoder.py @@ -0,0 +1,93 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import _shencoder as _backend +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + + +class _sh_encoder(Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float32) # force float32 for better precision + def forward(ctx, inputs, degree, calc_grad_inputs=False): + # inputs: [B, input_dim], float in [-1, 1] + # RETURN: [B, F], float + + inputs = inputs.contiguous() + B, input_dim = inputs.shape # batch size, coord dim + output_dim = degree ** 2 + + outputs = torch.empty(B, output_dim, dtype=inputs.dtype, device=inputs.device) + + if calc_grad_inputs: + dy_dx = torch.empty(B, input_dim * output_dim, dtype=inputs.dtype, device=inputs.device) + else: + dy_dx = None + + _backend.sh_encode_forward(inputs, outputs, B, input_dim, degree, dy_dx) + + ctx.save_for_backward(inputs, dy_dx) + ctx.dims = [B, input_dim, degree] + + return outputs + + @staticmethod + # @once_differentiable + @custom_bwd + def backward(ctx, grad): + # grad: [B, C * C] + + inputs, dy_dx = ctx.saved_tensors + + if dy_dx is not None: + grad = grad.contiguous() + B, input_dim, degree = ctx.dims + grad_inputs = torch.zeros_like(inputs) + _backend.sh_encode_backward(grad, inputs, B, input_dim, degree, dy_dx, grad_inputs) + return grad_inputs, None, None + else: + return None, None, None + + +sh_encode = _sh_encoder.apply + + +class SHEncoder(nn.Module): + def __init__(self, input_dim=3, degree=4): + super().__init__() + + self.input_dim = input_dim # coord dims, must be 3 + self.degree = degree # 0 ~ 4 + self.output_dim = degree ** 2 + + assert self.input_dim == 3, "SH encoder only support input dim == 3" + assert self.degree > 0 and self.degree <= 8, "SH encoder only supports degree in [1, 8]" + + def __repr__(self): + return f"SHEncoder: input_dim={self.input_dim} degree={self.degree}" + + def forward(self, inputs, size=1): + # inputs: [..., input_dim], normalized real world positions in [-size, size] + # return: [..., degree^2] + + inputs = inputs / size # [-1, 1] + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.reshape(-1, self.input_dim) + + outputs = sh_encode(inputs, self.degree, inputs.requires_grad) + outputs = outputs.reshape(prefix_shape + [self.output_dim]) + + return outputs diff --git a/nemo/collections/multimodal/modules/nerf/utils/trt_engine.py b/nemo/collections/multimodal/modules/nerf/utils/trt_engine.py new file mode 100644 index 000000000000..97fb1dcc2b53 --- /dev/null +++ b/nemo/collections/multimodal/modules/nerf/utils/trt_engine.py @@ -0,0 +1,170 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import OrderedDict +from copy import copy + +import numpy as np +import tensorrt as trt +import torch +from polygraphy import cuda +from polygraphy.backend.common import bytes_from_path +from polygraphy.backend.trt import engine_from_bytes +from polygraphy.backend.trt import util as trt_util + +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) + +# Map of numpy dtype -> torch dtype +numpy_to_torch_dtype_dict = { + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float64, + np.complex64: torch.complex64, + np.complex128: torch.complex128, +} +if np.version.full_version >= "1.24.0": + numpy_to_torch_dtype_dict[np.bool_] = torch.bool +else: + numpy_to_torch_dtype_dict[np.bool] = torch.bool + +# Map of torch dtype -> numpy dtype +torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()} + + +def device_view(t): + return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype]) + + +class Engine: + def __init__( + self, engine_path, + ): + self.engine_path = engine_path + self.engine = None + self.context = None + self.buffers = OrderedDict() + self.tensors = OrderedDict() + + def __del__(self): + [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)] + del self.engine + del self.context + del self.buffers + del self.tensors + + def set_engine(self, stream, shape_dict): + self.load() + self.activate() + self.stream = stream + self.allocate_buffers(shape_dict, device='cuda') + + def load(self): + print(f"Loading TensorRT engine: {self.engine_path}") + self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) + + def activate(self): + self.context = self.engine.create_execution_context() + + def allocate_buffers(self, shape_dict=None, device="cuda"): + for idx in range(trt_util.get_bindings_per_profile(self.engine)): + binding = self.engine[idx] + if shape_dict and binding in shape_dict: + shape = shape_dict[binding] + else: + shape = self.engine.get_binding_shape(binding) + dtype = trt.nptype(self.engine.get_binding_dtype(binding)) + if self.engine.binding_is_input(binding): + self.context.set_binding_shape(idx, shape) + tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device) + self.tensors[binding] = tensor + self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype) + + def infer(self, feed_dict): + stream = self.stream + start_binding, end_binding = trt_util.get_active_profile_bindings(self.context) + # shallow copy of ordered dict + device_buffers = copy(self.buffers) + for name, buf in feed_dict.items(): + assert isinstance(buf, cuda.DeviceView) + device_buffers[name] = buf + bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()] + noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr) + if not noerror: + raise ValueError(f"ERROR: inference failed.") + + return self.tensors + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + + elif schedule == "cosine": + timesteps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print( + f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}' + ) + return sigmas, alphas, alphas_prev + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/nemo/collections/multimodal/modules/stable_diffusion/__init__.py b/nemo/collections/multimodal/modules/stable_diffusion/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/modules/stable_diffusion/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/modules/stable_diffusion/attention.py b/nemo/collections/multimodal/modules/stable_diffusion/attention.py new file mode 100644 index 000000000000..b2f211141065 --- /dev/null +++ b/nemo/collections/multimodal/modules/stable_diffusion/attention.py @@ -0,0 +1,408 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from inspect import isfunction + +import torch +import torch.nn.functional as F +from apex.contrib.group_norm import GroupNorm +from einops import rearrange +from torch import einsum, nn +from torch._dynamo import disable + +from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import checkpoint + + +def check_cuda(): + if not torch.cuda.is_available(): + raise RuntimeError('CUDA is not available') + cur_device = torch.cuda.current_device() + dprops = torch.cuda.get_device_properties(cur_device) + + is_sm75 = dprops.major == 7 and dprops.minor == 5 + is_sm8x = dprops.major == 8 and dprops.minor >= 0 + is_sm90 = dprops.major == 9 and dprops.minor >= 0 + + return is_sm8x or is_sm75 or is_sm90 + + +try: + import torch.nn as nn + from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention + + flash_attn_installed = check_cuda() + print("FlashAttention Installed") + + # Disable TorchDynamo on FlashAttention + FlashSelfAttention.forward = disable(FlashSelfAttention.forward) + FlashCrossAttention.forward = disable(FlashCrossAttention.forward) +except ImportError: + flash_attn_installed = False + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + if isinstance(d, (torch.Tensor, float, int)): + return d + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels, num_groups=32, act=""): + return GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, act=act) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +# b n (h d) -> (b h) n d +def rearrange_heads_outer(t: torch.Tensor, h: int) -> torch.Tensor: + b, n, ch = t.shape + return t.view(b, n, h, -1).transpose(1, 2).reshape(b * h, n, -1) + + +# (b h) n d -> b n (h d) +def rearrange_heads_inner(t: torch.Tensor, h: int) -> torch.Tensor: + b = t.shape[0] // h + n = t.shape[1] + return t.view(b, h, n, -1).transpose(1, 2).reshape(b, n, -1) + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, use_flash_attention=False): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + # make attention part be aware of self-attention/cross-attention + self.context_dim = context_dim + self.query_dim = query_dim + self.dim_head = dim_head + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.use_flash_attention = use_flash_attention + + if dim_head <= 160 and (dim_head % 8) == 0 and flash_attn_installed: + if context_dim == query_dim: + self.flash_attn = FlashSelfAttention(softmax_scale=self.scale) + else: + self.flash_attn = FlashCrossAttention(softmax_scale=self.scale) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + out = self._attention(q, k, v, mask) + + return self.to_out(out) + + def _attention(self, q, k, v, mask=None): + h = self.heads + + if ( + not flash_attn_installed + or not self.use_flash_attention + or q.dtype == torch.float32 + or (self.dim_head > 160 or (self.dim_head % 8) != 0) + or mask is not None + ): + # original implementation + # b n (h d) -> (b h) n d + q = rearrange_heads_outer(q, h) + k = rearrange_heads_outer(k, h) + v = rearrange_heads_outer(v, h) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + # standard stable diffusion does not run into here + mask = mask.view(mask.shape[0], -1) + b, j = mask.shape + mask = mask.unsqueeze(1).expand(b, h, j).reshape(b * h, 1, j) # b j -> (b h) () j + sim.masked_fill_(~mask, self.max_neg[sim.dtype]) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + + # (b h) n d -> b n (h d) + out = rearrange_heads_inner(out, h) + elif self.context_dim == self.query_dim: + # self-attention + qkv = torch.stack([q, k, v], dim=2) + b, s, t, hd = qkv.shape + d = hd // h + qkv = qkv.view(b, s, t, h, d) + + out = self.flash_attn(qkv) + out = out.view(b, s, hd) + else: + # cross-attention + kv = torch.stack([k, v], dim=2) + + s_q = q.shape[1] + b, s_kv, t, hd = kv.shape + d = hd // h + + q = q.view(b, s_q, h, d) + kv = kv.view(b, s_kv, t, h, d) + + out = self.flash_attn(q, kv) + out = out.view(b, s_q, hd) + + return out + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + use_checkpoint=False, + use_flash_attention=False, + disable_self_attn=False, + ): + super().__init__() + self.disable_self_attn = disable_self_attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + use_flash_attention=use_flash_attention, + context_dim=context_dim if self.disable_self_attn else None, + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + use_flash_attention=use_flash_attention, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.use_checkpoint = use_checkpoint + + def forward(self, x, context=None): + if self.use_checkpoint: + return checkpoint(self._forward, (x, context), self.parameters(), self.use_checkpoint) + else: + return self._forward(x, context) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=False, + use_flash_attention=False, + ): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + use_checkpoint=use_checkpoint, + use_flash_attention=use_flash_attention, + disable_self_attn=disable_self_attn, + ) + for d in range(depth) + ] + ) + + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = x.view(b, c, -1).transpose(1, 2) # b c h w -> b (h w) c + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = x.transpose(1, 2).view(b, c, h, w) # b (h w) c -> b c h w + if not self.use_linear: + x = self.proj_out(x) + return x + x_in diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/__init__.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/model.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/model.py new file mode 100644 index 000000000000..7fc5c208004f --- /dev/null +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/model.py @@ -0,0 +1,881 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# pytorch_diffusion + derived encoder decoder +import math + +import numpy as np +import torch +import torch.nn as nn +from apex.contrib.group_norm import GroupNorm +from einops import rearrange + +from nemo.collections.multimodal.modules.stable_diffusion.attention import LinearAttention +from nemo.collections.multimodal.parts.stable_diffusion.utils import instantiate_from_config + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return torch.nn.functional.silu(x) + + +def Normalize(in_channels, num_groups=32, act=""): + return GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, act=act) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(yuya): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if dtype == torch.bfloat16: + x = x.to(dtype) + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, act="silu") + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels, act="silu") + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """ + to match AttnBlock usage + """ + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [torch.nn.Linear(self.ch, self.temb_ch), torch.nn.Linear(self.temb_ch, self.temb_ch),] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout + ) + ) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) + for _ in range(depth) + ] + ) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) + for _ in range(depth) + ] + ) + + self.conv_out = nn.Conv2d(mid_channels, out_channels, kernel_size=1,) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate( + x, size=(int(round(x.shape[2] * self.factor)), int(round(x.shape[3] * self.factor))) + ) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1.0 + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler( + factor=factor_up, in_channels=in_channels, mid_channels=2 * in_channels, out_channels=in_channels + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + + +class FirstStagePostProcessor(nn.Module): + def __init__( + self, + ch_mult: list, + in_channels, + pretrained_model: nn.Module = None, + reshape=False, + n_channels=None, + dropout=0.0, + pretrained_config=None, + ): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) + self.proj = nn.Conv2d(in_channels, n_channels, kernel_size=3, stride=1, padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in, out_channels=m * n_channels, dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def encode_with_pretrained(self, x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self, x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model, self.downsampler): + z = submodel(z, temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z, 'b c h w -> b (h w) c') + return z diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py new file mode 100644 index 000000000000..7f762343c4e9 --- /dev/null +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py @@ -0,0 +1,1175 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from abc import abstractmethod + +import numpy as np +import torch +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from nemo.collections.multimodal.modules.stable_diffusion.attention import SpatialTransformer +from nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.util import ( + avg_pool_nd, + checkpoint, + conv_nd, + linear, + normalization, + timestep_embedding, + zero_module, +) + + +def convert_module_to_dtype(module, dtype): + # Convert module parameters to dtype + if isinstance(module, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Linear)): + module.weight.data = module.weight.data.to(dtype) + if module.bias is not None: + module.bias.data = module.bias.data.to(dtype) + + +def convert_module_to_fp16(module): + convert_module_to_dtype(module, torch.float16) + + +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, spacial_dim: int, embed_dim: int, num_heads_channels: int, output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + This layer performs upsampling on the given input with the option to apply a convolution operation. + The upsampling can be applied to 1D, 2D, or 3D signals, depending on the specified dimensions. + + Parameters: + channels (int): The number of channels in both the inputs and outputs. + use_conv (bool): A bool determining if a convolution is applied. + dims (int): Specifies the dimensionality of the signal. + It can be 1, 2, or 3. If set to 3, upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(yuya): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if dtype == torch.bfloat16: + x = x.to(dtype) + + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + + This layer performs downsampling on the given input and optionally applies a convolution operation. + The downsampling can be applied to 1D, 2D, or 3D signals, with specific behavior for 3D signals. + + Parameters: + channels (int): The number of channels in both the inputs and outputs. + use_conv (bool): Determines whether a convolution is applied. + True to apply convolution, False otherwise. + dims (int): Specifies the dimensionality of the signal. + It can be 1, 2, or 3. For 3D signals, downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that optionally changes the number of channels. + + Parameters: + channels (int): The number of input channels. + emb_channels (int): The number of timestep embedding channels. + dropout (float): The rate of dropout to apply. + out_channels (int, optional): The number of output channels. If not specified, the output channels + will be the same as the input channels. + use_conv (bool): If True and out_channels is specified, a spatial convolution is used instead of a + smaller 1x1 convolution to change the channels in the skip connection. + dims (int): Determines if the signal is 1D, 2D, or 3D. + use_checkpoint (bool): If True, gradient checkpointing is used on this module. This can save memory + at the cost of additional compute. + up (bool): If True, the block is used for upsampling. + down (bool): If True, the block is used for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + resblock_gn_groups=32, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels, act="silu", gn_groups=resblock_gn_groups), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), linear(emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels,), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels, act="silu", gn_groups=resblock_gn_groups), + nn.Dropout(p=dropout), + zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + Parameters: + x (Tensor): An input Tensor of shape [N x C x ...], where N is the batch size, C is the number of channels, + and '...' represents additional dimensions. + emb (Tensor): A Tensor of timestep embeddings of shape [N x emb_channels], where emb_channels is the number + of embedding channels. + + Returns: + Tensor: An output Tensor of shape [N x C x ...], representing the processed features. + """ + if self.use_checkpoint: + return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint) + else: + return self._forward(x, emb) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV (Query-Key-Value) attention. + + Parameters: + qkv (Tensor): An input tensor of shape [N x (3 * H * C) x T], where N is the batch size, + H is the number of attention heads, C is the channel size, and T is the sequence length. + This tensor includes queries, keys, and values concatenated together. + + Returns: + Tensor: An output tensor of shape [N x (H * C) x T] after applying attention. This tensor + contains the processed information with the same sequence length but with modified features. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + + Parameters: + in_channels (int): The number of channels in the input Tensor. + model_channels (int): The base channel count for the model. + out_channels (int): The number of channels in the output Tensor. + num_res_blocks (int): The number of residual blocks per downsample. + attention_resolutions (set/list/tuple): The downsampling rates at which attention is applied. + For example, if this includes 4, attention is used at 4x downsampling. + dropout (float): The dropout probability. + channel_mult (list/tuple): A channel multiplier for each level of the UNet. + conv_resample (bool): If True, use learned convolutions for upsampling and downsampling. + dims (int): Determines if the signal is 1D, 2D, or 3D. + num_classes (int, optional): If specified, the model becomes class-conditional with the given number of classes. + use_checkpoint (bool): If True, use gradient checkpointing to reduce memory usage. + num_heads (int): The number of attention heads in each attention layer. + num_heads_channels (int, optional): If specified, overrides num_heads and uses a fixed channel width per attention head. + num_heads_upsample (int, optional): Sets a different number of heads for upsampling. Deprecated. + use_scale_shift_norm (bool): If True, use a FiLM-like conditioning mechanism. + resblock_updown (bool): If True, use residual blocks for up/downsampling. + use_new_attention_order (bool): If True, use a different attention pattern for potentially increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + resblock_gn_groups=32, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + use_linear_in_transformer=False, + from_pretrained: str = None, + from_NeMo=False, + # It must be specified when from pretrained is not None. It indicates loading unet from NeMo trained ckpt or HF + use_flash_attention: bool = False, + enable_amp_o2_fp16: bool = False, + ): + super().__init__() + if use_spatial_transformer: + assert ( + context_dim is not None + ), 'You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert ( + use_spatial_transformer + ), 'You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + resblock_gn_groups=resblock_gn_groups, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + use_flash_attention=use_flash_attention, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + resblock_gn_groups=resblock_gn_groups, + ) + if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + resblock_gn_groups=resblock_gn_groups, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + use_flash_attention=use_flash_attention, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + resblock_gn_groups=resblock_gn_groups, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + use_flash_attention=use_flash_attention, + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + resblock_gn_groups=resblock_gn_groups, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch, act="silu", gn_groups=resblock_gn_groups), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + if from_pretrained is not None: + state_dict = torch.load(from_pretrained, map_location='cpu') + if 'state_dict' in state_dict.keys(): + state_dict = state_dict['state_dict'] + missing_key, unexpected_keys, _, _ = self._load_pretrained_model(state_dict, from_NeMo=from_NeMo) + if len(missing_key) > 0: + print( + 'Following keys are missing during loading unet weights, which may lead to compromised image quality for a resumed training. Please check the checkpoint you provided.' + ) + print(f"Missing keys: {missing_key}") + print(f"Unexpected keys: {unexpected_keys}") + + if enable_amp_o2_fp16: + self.convert_to_fp16() + + def _input_blocks_mapping(self, input_dict): + res_dict = {} + for key_, value_ in input_dict.items(): + id_0 = int(key_[13]) + if "resnets" in key_: + id_1 = int(key_[23]) + target_id = 3 * id_0 + 1 + id_1 + post_fix = ( + key_[25:] + .replace('time_emb_proj', 'emb_layers.1') + .replace('norm1', 'in_layers.0') + .replace('norm2', 'out_layers.0') + .replace('conv1', 'in_layers.2') + .replace('conv2', 'out_layers.3') + .replace('conv_shortcut', 'skip_connection') + ) + res_dict["input_blocks." + str(target_id) + '.0.' + post_fix] = value_ + elif "attentions" in key_: + id_1 = int(key_[26]) + target_id = 3 * id_0 + 1 + id_1 + post_fix = key_[28:] + res_dict["input_blocks." + str(target_id) + '.1.' + post_fix] = value_ + elif "downsamplers" in key_: + post_fix = key_[35:] + target_id = 3 * (id_0 + 1) + res_dict["input_blocks." + str(target_id) + '.0.op.' + post_fix] = value_ + return res_dict + + def _mid_blocks_mapping(self, mid_dict): + res_dict = {} + for key_, value_ in mid_dict.items(): + if "resnets" in key_: + temp_key_ = ( + key_.replace('time_emb_proj', 'emb_layers.1') + .replace('norm1', 'in_layers.0') + .replace('norm2', 'out_layers.0') + .replace('conv1', 'in_layers.2') + .replace('conv2', 'out_layers.3') + .replace('conv_shortcut', 'skip_connection') + .replace('middle_block.resnets.0', 'middle_block.0') + .replace('middle_block.resnets.1', 'middle_block.2') + ) + res_dict[temp_key_] = value_ + elif "attentions" in key_: + res_dict[key_.replace('attentions.0', '1')] = value_ + return res_dict + + def _other_blocks_mapping(self, other_dict): + res_dict = {} + for key_, value_ in other_dict.items(): + tmp_key = ( + key_.replace('conv_in', 'input_blocks.0.0') + .replace('time_embedding.linear_1', 'time_embed.0') + .replace('time_embedding.linear_2', 'time_embed.2') + .replace('conv_norm_out', 'out.0') + .replace('conv_out', 'out.2') + ) + res_dict[tmp_key] = value_ + return res_dict + + def _output_blocks_mapping(self, output_dict): + res_dict = {} + for key_, value_ in output_dict.items(): + id_0 = int(key_[14]) + if "resnets" in key_: + id_1 = int(key_[24]) + target_id = 3 * id_0 + id_1 + post_fix = ( + key_[26:] + .replace('time_emb_proj', 'emb_layers.1') + .replace('norm1', 'in_layers.0') + .replace('norm2', 'out_layers.0') + .replace('conv1', 'in_layers.2') + .replace('conv2', 'out_layers.3') + .replace('conv_shortcut', 'skip_connection') + ) + res_dict["output_blocks." + str(target_id) + '.0.' + post_fix] = value_ + elif "attentions" in key_: + id_1 = int(key_[27]) + target_id = 3 * id_0 + id_1 + post_fix = key_[29:] + res_dict["output_blocks." + str(target_id) + '.1.' + post_fix] = value_ + elif "upsamplers" in key_: + post_fix = key_[34:] + target_id = 3 * (id_0 + 1) - 1 + mid_str = '.2.conv.' if target_id != 2 else '.1.conv.' + res_dict["output_blocks." + str(target_id) + mid_str + post_fix] = value_ + return res_dict + + def _state_key_mapping(self, state_dict: dict): + + res_dict = {} + input_dict = {} + mid_dict = {} + output_dict = {} + other_dict = {} + for key_, value_ in state_dict.items(): + if "down_blocks" in key_: + input_dict[key_.replace('down_blocks', 'input_blocks')] = value_ + elif "up_blocks" in key_: + output_dict[key_.replace('up_blocks', 'output_blocks')] = value_ + elif "mid_block" in key_: + mid_dict[key_.replace('mid_block', 'middle_block')] = value_ + else: + other_dict[key_] = value_ + + input_dict = self._input_blocks_mapping(input_dict) + output_dict = self._output_blocks_mapping(output_dict) + mid_dict = self._mid_blocks_mapping(mid_dict) + other_dict = self._other_blocks_mapping(other_dict) + res_dict.update(input_dict) + res_dict.update(output_dict) + res_dict.update(mid_dict) + res_dict.update(other_dict) + + return res_dict + + def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from_NeMo=False): + if from_NeMo: + state_dict = self._strip_unet_key_prefix(state_dict) + else: + state_dict = self._state_key_mapping(state_dict) + model_state_dict = self.state_dict() + loaded_keys = [k for k in state_dict.keys()] + expected_keys = list(model_state_dict.keys()) + original_loaded_keys = loaded_keys + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + if ( + 'input_blocks.1.0.in_layers.2.weight' in loaded_keys + and 'input_blocks.1.0.in_layers.1.weight' in expected_keys + ): + # GroupNormOpt fuses activation function to one layer, thus the indexing of weights are shifted for following + for key_ in missing_keys: + s = key_.split('.') + idx = int(s[-2]) + new_key_ = ".".join(s[:-2] + [str(int(idx + 1))] + [s[-1]]) + state_dict[key_] = state_dict[new_key_] + + loaded_keys = list(state_dict.keys()) + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = list(set(loaded_keys) - set(expected_keys)) + + def _find_mismatched_keys( + state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + model_key = checkpoint_key + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, model_state_dict, original_loaded_keys, ignore_mismatched_sizes, + ) + error_msgs = self._load_state_dict_into_model(state_dict) + return missing_keys, unexpected_keys, mismatched_keys, error_msgs + + # TODO MMY maybe combine these cases of key prefix + def _strip_unet_key_prefix(self, state_dict): + re_state_dict = {} + for key_, value_ in state_dict.items(): + if key_.startswith('model.diffusion_model'): + re_state_dict[key_.replace('model.diffusion_model.', '')] = value_ + if key_.startswith('model.model.diffusion_model'): + re_state_dict[key_.replace('model.model.diffusion_model.', '')] = value_ + if key_.startswith('model._orig_mod.diffusion_model.'): + re_state_dict[key_.replace('model._orig_mod.diffusion_model.', '')] = value_ + if key_.startswith('model.model._orig_mod.diffusion_model.'): + re_state_dict[key_.replace('model.model._orig_mod.diffusion_model.', '')] = value_ + if key_.startswith('model.model.diffusion_model._orig_mod.'): + re_state_dict[key_.replace('model.model.diffusion_model._orig_mod.', '')] = value_ + return re_state_dict + + def _load_state_dict_into_model(self, state_dict): + # Convert old format to new format if needed from a PyTorch state_dict + # copy state_dict so _load_from_state_dict can modify it + state_dict = state_dict.copy() + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: torch.nn.Module, prefix=""): + args = (state_dict, prefix, {}, True, [], [], error_msgs) + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(self) + + return error_msgs + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.apply(convert_module_to_fp16) + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + """ + Apply the model to an input batch. + + Parameters: + x (Tensor): An input tensor of shape [N x C x ...], where N is the batch size, C is the number of channels, + and '...' represents additional dimensions. + timesteps (Tensor): A 1-D tensor representing a batch of timesteps. + context (Tensor, optional): An optional tensor for additional conditioning, used via cross-attention. + y (Tensor, optional): An optional 1-D tensor of labels of shape [N], used if the model is class-conditional. + + Returns: + Tensor: An output tensor of shape [N x C x ...], representing the processed batch. + """ + + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(emb.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + resblock_gn_groups=32, + *args, + **kwargs, + ): + super().__init__() + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + resblock_gn_groups=resblock_gn_groups, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + resblock_gn_groups=resblock_gn_groups, + ) + if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + resblock_gn_groups=resblock_gn_groups, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + resblock_gn_groups=resblock_gn_groups, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), nn.SiLU(), AttentionPool2d((image_size // ds), ch, num_head_channels, out_channels), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), nn.ReLU(), nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_fp16) + self.middle_block.apply(convert_module_to_fp16) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels), use_fp16=self.use_fp16) + + # future support + if self.dtype == th.float32: + self.dtype == x.dtype + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py new file mode 100644 index 000000000000..695333edc649 --- /dev/null +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/util.py @@ -0,0 +1,319 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +''' +adopted from +https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +and +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +and +https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py + +thanks! +''' + +import math + +import numpy as np +import torch +import torch.nn as nn +from apex.contrib.group_norm import GroupNorm +from einops import repeat +from torch._dynamo import disable +from torch.cuda.amp import custom_bwd, custom_fwd + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + + elif schedule == "cosine": + timesteps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == "uniform": + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == "quad": + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f"Selected timesteps for ddim sampler: {steps_out}") + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + variance = (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + sigmas = eta * np.sqrt(variance) + if verbose: + print(f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}") + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) + return sigmas, alphas, alphas_prev, variance + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule based on a discretized alpha_t_bar function. + + Parameters: + num_diffusion_timesteps (int): The number of beta values to produce, corresponding to the number of timesteps in the diffusion process. + alpha_bar (function): A lambda function that accepts a time value t ranging from 0 to 1 and returns the cumulative product of (1-beta) up to that point in the diffusion process. + max_beta (float): The maximum allowable value for beta. Setting this to a value lower than 1 helps in preventing singularities in the diffusion process. + + Returns: + list: A list of beta values that correspond to each timestep in the diffusion process. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations. + + Parameters: + func (function): The function to be evaluated. This should be a callable object. + inputs (sequence): The arguments to pass to `func`. This is a sequence of inputs that `func` will be called with. + params (sequence): A sequence of parameters that `func` depends on but does not explicitly take as arguments. + These are additional parameters required by `func`. + flag (bool): If set to False, disables gradient checkpointing. If True, enables gradient checkpointing which + allows for memory savings at the cost of extra compute during the backward pass. + + Returns: + The result of evaluating `func` with the given inputs and parameters, with reduced memory usage during the forward pass. + """ + + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + @custom_bwd + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +# Temporary hack to get rid of TorchDynamo issue with DDP +# TODO: remove this if https://github.com/pytorch/pytorch/issues/94574 fixed +@disable +def get_idx(end, device): + return torch.arange(start=0, end=end, dtype=torch.float32, device=device) + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + + Parameters: + timesteps (Tensor): A 1-D tensor of N indices, one per batch element. These indices may be fractional and + represent the timesteps for which embeddings are to be created. + dim (int): The dimension of the output embeddings. Each timestep will be represented as a vector of this dimension. + max_period (float): Controls the minimum frequency of the embeddings. Higher values result in higher frequency + components in the embedding. + + Returns: + Tensor: An [N x dim] tensor of positional embeddings, where each row corresponds to the embedding for a timestep. + """ + + if not repeat_only: + half = dim // 2 + idx = get_idx(half, timesteps.device) + freqs = torch.exp(-math.log(max_period) / half * idx) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(in_channels, act="", gn_groups=32): + return GroupNorm(num_groups=gn_groups, num_channels=in_channels, eps=1e-5, affine=True, act=act) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where(torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where(torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + """ + return v[(...,) + (None,) * (dims - 1)] diff --git a/nemo/collections/multimodal/modules/stable_diffusion/distributions/__init__.py b/nemo/collections/multimodal/modules/stable_diffusion/distributions/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/modules/stable_diffusion/distributions/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/modules/stable_diffusion/distributions/distributions.py b/nemo/collections/multimodal/modules/stable_diffusion/distributions/distributions.py new file mode 100644 index 000000000000..81d79ac5801a --- /dev/null +++ b/nemo/collections/multimodal/modules/stable_diffusion/distributions/distributions.py @@ -0,0 +1,98 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import torch + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] + + return 0.5 * ( + -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/nemo/collections/multimodal/modules/stable_diffusion/encoders/__init__.py b/nemo/collections/multimodal/modules/stable_diffusion/encoders/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/modules/stable_diffusion/encoders/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/modules/stable_diffusion/encoders/modules.py b/nemo/collections/multimodal/modules/stable_diffusion/encoders/modules.py new file mode 100644 index 000000000000..ec3bd82ba137 --- /dev/null +++ b/nemo/collections/multimodal/modules/stable_diffusion/encoders/modules.py @@ -0,0 +1,470 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import tempfile +from functools import partial + +import open_clip +import torch +import torch.nn as nn +from omegaconf import OmegaConf +from torch.utils.checkpoint import checkpoint +from transformers import CLIPTextModel, CLIPTokenizer + +from nemo.collections.multimodal.data.clip.clip_dataset import get_preprocess_fns +from nemo.collections.multimodal.models.vision_language_foundation.clip.megatron_clip_models import CLIPModel +from nemo.collections.multimodal.modules.stable_diffusion.encoders.x_transformer import ( + TransformerWrapper, # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test +) +from nemo.collections.multimodal.modules.stable_diffusion.encoders.x_transformer import Encoder +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.utils import logging + +try: + from megatron.core import ModelParallelConfig, parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + ModelParallelConfig = ApexGuardDefaults + + HAVE_MEGATRON_CORE = False + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class'): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + + def forward(self, batch, key=None): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + c = self.embedding(c) + return c + + +class TransformerEmbedder(AbstractEncoder): + """Some transformer encoder layers""" + + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + super().__init__() + self.device = device + self.transformer = TransformerWrapper( + num_tokens=vocab_size, max_seq_len=max_seq_len, attn_layers=Encoder(dim=n_embed, depth=n_layer) + ) + + def forward(self, tokens): + tokens = tokens.to(self.device) # meh + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, x): + return self(x) + + +class BERTTokenizer(AbstractEncoder): + """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" + + def __init__(self, device="cuda", vq_interface=True, max_length=77): + super().__init__() + from transformers import BertTokenizerFast # TODO: add to reuquirements + + self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + self.device = device + self.vq_interface = vq_interface + self.max_length = max_length + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + return tokens + + @torch.no_grad() + def encode(self, text): + tokens = self(text) + if not self.vq_interface: + return tokens + return None, None, [None, None, tokens] + + def decode(self, text): + return text + + +class BERTEmbedder(AbstractEncoder): + """Uses the BERT tokenizr model and add some transformer encoder layers""" + + def __init__( + self, + n_embed, + n_layer, + vocab_size=30522, + max_seq_len=77, + device="cuda", + use_tokenizer=True, + embedding_dropout=0.0, + ): + super().__init__() + self.use_tknz_fn = use_tokenizer + if self.use_tknz_fn: + self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) + self.device = device + self.transformer = TransformerWrapper( + num_tokens=vocab_size, + max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + emb_dropout=embedding_dropout, + ) + + def forward(self, text): + if self.use_tknz_fn: + tokens = self.tknz_fn(text) # .to(self.device) + else: + tokens = text + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, text): + # output of length 77 + return self(text) + + +class SpatialRescaler(nn.Module): + def __init__(self, n_stages=1, method='bilinear', multiplier=0.5, in_channels=3, out_channels=None, bias=False): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in ['nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area'] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.remap_output = out_channels is not None + if self.remap_output: + print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') + self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias) + + def forward(self, x): + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + + def __init__( + self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, capture_cudagraph_iters: int = -1 + ): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + self.freeze() + + # CUDA graph captured sub-modules + self.capture_cudagraph_iters = capture_cudagraph_iters + self.iterations = 0 + self.stream = torch.cuda.Stream() + self.transformer_graph = torch.cuda.CUDAGraph() + self.static_tokens = None + self.static_outputs = None + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + if self.capture_cudagraph_iters < 0: + tokens = batch_encoding["input_ids"].to(self.device, non_blocking=True) + outputs = self.transformer(input_ids=tokens) + z = outputs.last_hidden_state + + else: + if self.static_tokens is None: + self.static_tokens = batch_encoding["input_ids"].to(device=self.device, non_blocking=True) + self.static_tokens.copy_(batch_encoding["input_ids"], non_blocking=True) + + if self.iterations == self.capture_cudagraph_iters: + # cuda graph capture + logging.info("Capturing CUDA graph for module: %s", self.transformer.__class__.__name__) + with torch.cuda.graph(self.transformer_graph): + self.static_outputs = self.transformer(input_ids=self.static_tokens) + + if 0 <= self.capture_cudagraph_iters <= self.iterations: + # cuda graph replay + self.transformer_graph.replay() + else: + # warmup + self.stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.stream): + self.static_outputs = self.transformer(input_ids=self.static_tokens) + torch.cuda.current_stream().wait_stream(self.stream) + self.iterations += 1 + z = self.static_outputs.last_hidden_state + + # # Pad the seq length to multiple of 8 + seq_len = (z.shape[1] + 8 - 1) // 8 * 8 + z = torch.nn.functional.pad(z, (0, 0, 0, seq_len - z.shape[1]), value=0.0) + return z + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP transformer encoder for text + """ + + LAYERS = [ + # "pooled", + "last", + "penultimate", + ] + + def __init__( + self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + layer="last", + use_fp16=False, + ): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenMegatronCLIPEmbedder(AbstractEncoder): + def __init__(self, restore_from_path, device="cuda", layer="last", freeze=True, cfg=None, use_fp16=False): + super().__init__() + if restore_from_path is not None: + cfg, state_dict = self.load_config_and_state_from_nemo(restore_from_path) + elif cfg is not None: + state_dict = None + else: + raise ValueError("Either restore_from_path or cfg should not be None") + + self.cfg = cfg + self.build_tokenizer(cfg) + self.load_model(cfg, state_dict) + + self.device = device + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def load_config_and_state_from_nemo(self, nemo_path): + if torch.cuda.is_available(): + map_location = torch.device('cuda') + else: + map_location = torch.device('cpu') + save_restore_connector = NLPSaveRestoreConnector() + cwd = os.getcwd() + + with tempfile.TemporaryDirectory() as tmpdir: + try: + save_restore_connector._unpack_nemo_file(path2file=nemo_path, out_folder=tmpdir) + + # Change current working directory to + os.chdir(tmpdir) + config_yaml = os.path.join(tmpdir, save_restore_connector.model_config_yaml) + cfg = OmegaConf.load(config_yaml) + + model_weights = os.path.join(tmpdir, save_restore_connector.model_weights_ckpt) + state_dict = save_restore_connector._load_state_dict_from_disk( + model_weights, map_location=map_location + ) + finally: + os.chdir(cwd) + + return cfg, state_dict + + def build_tokenizer(self, cfg): + legacy = cfg.tokenizer.sentencepiece_legacy + self.tokenizer = get_nmt_tokenizer( + library=cfg.tokenizer.library, + model_name=cfg.tokenizer.type, + tokenizer_model=cfg.tokenizer.model, + vocab_file=cfg.tokenizer.vocab_file, + merges_file=cfg.tokenizer.merge_file, + delimiter=cfg.tokenizer.get('delimiter', None), + legacy=legacy, + ) + + _, self.text_transform = get_preprocess_fns(cfg, self.tokenizer, is_train=False,) + self.max_length = cfg.text.get("max_position_embeddings") + + def load_model(self, cfg, state_dict): + padded_vocab_size = self._vocab_size_with_padding( + orig_vocab_size=self.tokenizer.vocab_size, + make_vocab_size_divisible_by=cfg.get('make_vocab_size_divisible_by', 128), + tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1), + ) + model = CLIPModel( + model_cfg=cfg, + model_parallel_config=ModelParallelConfig(), + padded_vocab_size=padded_vocab_size, + pre_process=cfg.text.pre_process, + post_process=cfg.text.post_process, + ) + + if state_dict is not None: + clip_state_dict = {} + for key, value in state_dict.items(): + key = key[6:] + clip_state_dict[key] = value + model.load_state_dict(clip_state_dict) + + del model.vision_encoder + self.model = model.text_encoder + + def _vocab_size_with_padding(self, orig_vocab_size, make_vocab_size_divisible_by, tensor_model_parallel_size): + after = orig_vocab_size + multiple = make_vocab_size_divisible_by * tensor_model_parallel_size + while (after % multiple) != 0: + after += 1 + return after + + def forward(self, text): + ''' + Get embeddings from input text + ''' + texts = self.text_transform(text) + z = self.encode_with_transformer(texts.to(self.device)) + # # Pad the seq length to multiple of 8 + seq_len = (z.shape[1] + 8 - 1) // 8 * 8 + z = torch.nn.functional.pad(z, (0, 0, 0, seq_len - z.shape[1]), value=0.0) + return z + + def encode_with_transformer(self, text): + x = self.model.language_model.embedding.word_embeddings(text) + x += self.model.language_model.embedding.position_embeddings + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = self.model.language_model.encoder.final_layernorm(x) + x = x.permute(1, 0, 2) # LND -> NLD + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.language_model.encoder.layers): + if i == len(self.model.language_model.encoder.layers) - self.layer_idx: + break + x = r(x, attn_mask) + return x + + def encode(self, text): + return self(text) + + +if __name__ == "__main__": + from ldm.util import count_params + + model = FrozenCLIPEmbedder() + count_params(model, verbose=True) diff --git a/nemo/collections/multimodal/modules/stable_diffusion/encoders/x_transformer.py b/nemo/collections/multimodal/modules/stable_diffusion/encoders/x_transformer.py new file mode 100644 index 000000000000..938352817190 --- /dev/null +++ b/nemo/collections/multimodal/modules/stable_diffusion/encoders/x_transformer.py @@ -0,0 +1,630 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""adopted from https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +from collections import namedtuple +from functools import partial +from inspect import isfunction + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from torch import einsum, nn + +# constants + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', ['pre_softmax_attn', 'post_softmax_attn']) + +LayerIntermediates = namedtuple('Intermediates', ['hiddens', 'attn_intermediates']) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + + return inner + + +def not_equals(val): + def inner(x): + return x != val + + return inner + + +def equals(val): + def inner(x): + return x == val + + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru(rearrange(x, 'b n d -> (b n) d'), rearrange(residual, 'b n d -> (b n) d')) + + return gated_output.reshape_as(x) + + +# feedforward + + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0.0, + on_attn=False, + ): + super().__init__() + if use_entmax15: + raise NotImplementedError("Check out entmax activation instead of softmax activation!") + self.scale = dim_head ** -0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + # self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None, + ): + b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + intermediates = Intermediates(pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs, + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.rotary_pos_emb = always(None) + + assert ( + rel_pos_num_buckets <= rel_pos_max_distance + ), 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn])) + + def forward(self, x, context=None, mask=None, context_mask=None, mems=None, return_hiddens=False): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == 'a': + out, inter = block( + x, + mask=mask, + sinusoidal_emb=self.pia_pos_emb, + rel_pos=self.rel_pos, + prev_attn=prev_attn, + mem=layer_mem, + ) + elif layer_type == 'c': + out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates(hiddens=hiddens, attn_intermediates=intermediates) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0.0, + emb_dropout=0.0, + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True, + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = ( + AbsolutePositionalEmbedding(emb_dim, max_seq_len) + if (use_pos_emb and not attn_layers.has_pos_emb) + else always(0) + ) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, 'num_memory_tokens'): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, x, return_embeddings=False, mask=None, return_mems=False, return_attn=False, mems=None, **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens + new_mems = list(map(lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems)) + return out, new_mems + + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + return out, attn_maps + + return out diff --git a/nemo/collections/multimodal/parts/imagen/__init__.py b/nemo/collections/multimodal/parts/imagen/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/parts/imagen/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/parts/imagen/utils.py b/nemo/collections/multimodal/parts/imagen/utils.py new file mode 100644 index 000000000000..565b1ed6a2b4 --- /dev/null +++ b/nemo/collections/multimodal/parts/imagen/utils.py @@ -0,0 +1,29 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + + +def random_dropout(embeddings, drop_rate): + r""" + Function to perform random dropout for embeddings. + When we drop embeddings, we zero them out. + Args: + embeddings (tensor): Input embeddings + drop_rate (float): Rate of dropping the embedding. + """ + nsamples = embeddings.shape[0] + zero_flag = torch.ones(nsamples, 1, 1).to(embeddings.dtype) * (1 - drop_rate) + zero_flag = torch.bernoulli(zero_flag).cuda() + embeddings = embeddings * zero_flag + return embeddings diff --git a/nemo/collections/multimodal/parts/stable_diffusion/__init__.py b/nemo/collections/multimodal/parts/stable_diffusion/__init__.py new file mode 100644 index 000000000000..4fc50543f1d2 --- /dev/null +++ b/nemo/collections/multimodal/parts/stable_diffusion/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/multimodal/parts/stable_diffusion/pipeline.py b/nemo/collections/multimodal/parts/stable_diffusion/pipeline.py new file mode 100644 index 000000000000..e9de61d6025a --- /dev/null +++ b/nemo/collections/multimodal/parts/stable_diffusion/pipeline.py @@ -0,0 +1,202 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import pickle +import time + +import torch +from PIL import Image + +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.samplers.ddim import DDIMSampler +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.samplers.para_ddim import ParaDDIMSampler +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.samplers.plms import PLMSSampler +from nemo.collections.multimodal.models.text_to_image.stable_diffusion.samplers.sampler_dpm import DPMSolverSampler +from nemo.collections.multimodal.parts.stable_diffusion.utils import DataParallelWrapper + + +def encode_prompt(cond_stage_model, prompt, unconditional_guidance_scale, batch_size): + c = cond_stage_model.encode(batch_size * [prompt]) + if unconditional_guidance_scale != 1.0: + uc = cond_stage_model.encode(batch_size * [""]) + else: + uc = None + return c, uc + + +def initialize_sampler(model, sampler_type): + if sampler_type == 'DDIM': + sampler = DDIMSampler(model) + elif sampler_type == 'PLMS': + sampler = PLMSSampler(model) + elif sampler_type == 'DPM': + sampler = DPMSolverSampler(model) + elif sampler_type == 'PARA_DDIM': + sampler = ParaDDIMSampler(model) + else: + raise ValueError(f'Sampler {sampler_type} is not supported.') + return sampler + + +def decode_images(model, samples): + images = model.decode_first_stage(samples) + + images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0) + + return images + + +def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + +def torch_to_numpy(images): + numpy_images = [x.float().cpu().permute(0, 2, 3, 1).numpy() for x in images] + return numpy_images + + +def pipeline(model, cfg, verbose=True, rng=None): + # setup default values for inference configs + unconditional_guidance_scale = cfg.infer.get("unconditional_guidance_scale", 7.5) + batch_size = cfg.infer.get('num_images_per_prompt', 1) + prompts = cfg.infer.get('prompts', []) + height = cfg.infer.get('height', 512) + width = cfg.infer.get('width', 512) + downsampling_factor = cfg.infer.get('down_factor', 8) + sampler_type = cfg.infer.get('sampler_type', 'DDIM') + sampler_parallelism = cfg.infer.get('sampler_parallelism', 1) + sampler_tolerance = cfg.infer.get('sampler_tolerance', 0.1) + inference_steps = cfg.infer.get('inference_steps', 50) + output_type = cfg.infer.get('output_type', 'pil') + save_to_file = cfg.infer.get('save_to_file', True) + out_path = cfg.infer.get('out_path', '') + eta = cfg.infer.get('eta', 0) + num_devices = cfg.infer.get('devices', 1) + + if sampler_parallelism > 1: + if not sampler_type.startswith('PARA'): + raise ValueError('Parallel sampler is required when parallelism > 1') + if not num_devices > 1: + print("It is recommended to run parallel sampler with multiple GPUs") + + if num_devices > 1: + print(f"Running DataParallel model with {num_devices} GPUs.") + model.model.diffusion_model = DataParallelWrapper( + model.model.diffusion_model, device_ids=list(range(num_devices)) + ) + + # get autocast_dtype + if cfg.trainer.precision in ['bf16', 'bf16-mixed']: + autocast_dtype = torch.bfloat16 + elif cfg.trainer.precision in [32, '32', '32-true']: + autocast_dtype = torch.float + elif cfg.trainer.precision in [16, '16', '16-mixed']: + autocast_dtype = torch.half + else: + raise ValueError('precision must be in [32, 16, "bf16"]') + + with torch.no_grad(), torch.cuda.amp.autocast( + enabled=autocast_dtype in (torch.half, torch.bfloat16), dtype=autocast_dtype, + ): + + in_channels = model.model.diffusion_model.in_channels + + sampler = initialize_sampler(model, sampler_type.upper()) + + output = [] + throughput = [] + + if isinstance(prompts, str): + prompts = [prompts] + + for prompt in prompts: + tic = time.perf_counter() + tic_total = tic + cond, u_cond = encode_prompt(model.cond_stage_model, prompt, unconditional_guidance_scale, batch_size) + toc = time.perf_counter() + conditioning_time = toc - tic + + latent_shape = [in_channels, height // downsampling_factor, width // downsampling_factor] + latents = torch.randn( + [batch_size, in_channels, height // downsampling_factor, width // downsampling_factor], generator=rng + ).to(torch.cuda.current_device()) + + tic = time.perf_counter() + samples, intermediates = sampler.sample( + S=inference_steps, + conditioning=cond, + batch_size=batch_size, + shape=latent_shape, + verbose=False, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=u_cond, + eta=eta, + x_T=latents, + parallelism=sampler_parallelism, + tolerance=sampler_tolerance, + ) + toc = time.perf_counter() + sampling_time = toc - tic + + tic = time.perf_counter() + images = decode_images(model, samples) + toc = time.perf_counter() + decode_time = toc - tic + + toc_total = time.perf_counter() + total_time = toc_total - tic_total + output.append(images) + + throughput.append( + { + 'text-conditioning-time': conditioning_time, + 'sampling-time': sampling_time, + 'decode-time': decode_time, + 'total-time': total_time, + 'sampling-steps': inference_steps, + } + ) + + # Convert output type and save to disk + if output_type == 'torch': + output = torch.cat(output, dim=0) + else: + output = torch_to_numpy(output) + if output_type == 'pil': + output = [numpy_to_pil(x) for x in output] + + if save_to_file: + os.makedirs(out_path, exist_ok=True) + if output_type == 'pil': + for text_prompt, pils in zip(prompts, output): + for idx, image in enumerate(pils): + image.save(os.path.join(out_path, f'{text_prompt[:50]}_{idx}.png')) + else: + with open(os.path.join(out_path, 'output.pkl'), 'wb') as f: + pickle.dump(output, f) + else: + return output + + ave_metrics = {} + for key in throughput[0].keys(): + ave_metrics[f'avg-{key}'] = sum([dicts[key] for dicts in throughput]) / len(throughput) + if verbose: + print(ave_metrics) diff --git a/nemo/collections/multimodal/parts/stable_diffusion/utils.py b/nemo/collections/multimodal/parts/stable_diffusion/utils.py new file mode 100644 index 000000000000..3e6697747c13 --- /dev/null +++ b/nemo/collections/multimodal/parts/stable_diffusion/utils.py @@ -0,0 +1,208 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import multiprocessing as mp +from collections import abc +from inspect import isfunction +from queue import Queue +from threading import Thread + +import numpy as np +import torch +from PIL import Image, ImageDraw +from nemo.utils import logging + + +class DataParallelWrapper(torch.nn.DataParallel): + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.module, name) + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black") + except UnicodeEncodeError: + logging("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + logging(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + logging(f'Getting module=<{module}>, cls=<{cls}>') + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False +): + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + logging( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." + ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [[func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc))] + else: + step = int(len(data) / n_proc + 1) if len(data) % n_proc != 0 else int(len(data) / n_proc) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate([data[i : i + step] for i in range(0, len(data), step)]) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + logging(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + logging("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + logging(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == 'ndarray': + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == 'list': + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res diff --git a/nemo/collections/tts/models/spectrogram_enhancer.py b/nemo/collections/tts/models/spectrogram_enhancer.py index ca2fe6122230..7115360e7125 100644 --- a/nemo/collections/tts/models/spectrogram_enhancer.py +++ b/nemo/collections/tts/models/spectrogram_enhancer.py @@ -41,7 +41,6 @@ import torch import torch.nn.functional as F -import torchvision from einops import rearrange from hydra.utils import instantiate from omegaconf import DictConfig @@ -61,6 +60,13 @@ from nemo.core.neural_types.elements import BoolType from nemo.utils import logging +try: + import torchvision + + TORCHVISION_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + TORCHVISION_AVAILABLE = False + class SpectrogramEnhancerModel(ModelPT, Exportable): """ @@ -318,6 +324,7 @@ def log_illustration(self, target_spectrograms, input_spectrograms, enhanced_spe dim=0, ).cpu()[:, idx, :, :, :length] + assert TORCHVISION_AVAILABLE, "Torchvision imports failed but they are required." grid = torchvision.utils.make_grid(tensor, nrow=1).clamp(0.0, 1.0) for logger in self.loggers: diff --git a/nemo/collections/vision/data/megatron/image_folder.py b/nemo/collections/vision/data/megatron/image_folder.py index 8cd30a74dd56..f6dbe7dde513 100644 --- a/nemo/collections/vision/data/megatron/image_folder.py +++ b/nemo/collections/vision/data/megatron/image_folder.py @@ -20,7 +20,14 @@ import numpy as np from PIL import Image -from torchvision.datasets import VisionDataset +from nemo.utils import logging + +try: + from torchvision.datasets import VisionDataset + + TORCHVISION_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + TORCHVISION_AVAILABLE = False def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool: @@ -96,117 +103,129 @@ def is_valid_file(x: str) -> bool: return instances -class DatasetFolder(VisionDataset): - """A generic data loader where the samples are arranged in this way: :: - root/class_x/xxx.ext - root/class_x/xxy.ext - root/class_x/[...]/xxz.ext - root/class_y/123.ext - root/class_y/nsdf3.ext - root/class_y/[...]/asd932_.ext - Args: - root (string): Root directory path. - loader (callable): A function to load a sample given its path. - extensions (tuple[string]): A list of allowed extensions. - both extensions and is_valid_file should not be passed. - transform (callable, optional): A function/transform that takes in - a sample and returns a transformed version. - E.g, ``transforms.RandomCrop`` for images. - target_transform (callable, optional): A function/transform that takes - in the target and transforms it. - is_valid_file (callable, optional): A function that takes path of a file - and check if the file is a valid file (used to check of corrupt files) - both extensions and is_valid_file should not be passed. - Attributes: - classes (list): List of the class names sorted alphabetically. - class_to_idx (dict): Dict with items (class_name, class_index). - samples (list): List of (sample path, class_index) tuples - targets (list): The class_index value for each image in the dataset - """ - - def __init__( - self, - root: str, - loader: Callable[[str], Any], - extensions: Optional[Tuple[str, ...]] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - classes_fraction=1.0, - data_per_class_fraction=1.0, - is_valid_file: Optional[Callable[[str], bool]] = None, - ) -> None: - super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform) - self.classes_fraction = classes_fraction - self.data_per_class_fraction = data_per_class_fraction - classes, class_to_idx = self._find_classes(self.root) - samples = self.make_dataset(self.root, class_to_idx, self.data_per_class_fraction, extensions, is_valid_file) - if len(samples) == 0: - msg = "Found 0 files in subfolders of: {}\n".format(self.root) - if extensions is not None: - msg += "Supported extensions are: {}".format(",".join(extensions)) - raise RuntimeError(msg) - - self.loader = loader - self.extensions = extensions - self.total = len(samples) - self.classes = classes - self.class_to_idx = class_to_idx - self.samples = samples - self.targets = [s[1] for s in samples] - - @staticmethod - def make_dataset( - directory: str, - class_to_idx: Dict[str, int], - data_per_class_fraction: float, - extensions: Optional[Tuple[str, ...]] = None, - is_valid_file: Optional[Callable[[str], bool]] = None, - ) -> List[Tuple[str, int]]: - return make_dataset( - directory, class_to_idx, data_per_class_fraction, extensions=extensions, is_valid_file=is_valid_file - ) +if TORCHVISION_AVAILABLE: - def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: - """ - Finds the class folders in a dataset. + class DatasetFolder(VisionDataset): + """A generic data loader where the samples are arranged in this way: :: + root/class_x/xxx.ext + root/class_x/xxy.ext + root/class_x/[...]/xxz.ext + root/class_y/123.ext + root/class_y/nsdf3.ext + root/class_y/[...]/asd932_.ext Args: - dir (string): Root directory path. - Returns: - tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. - Ensures: - No class is a subdirectory of another. + root (string): Root directory path. + loader (callable): A function to load a sample given its path. + extensions (tuple[string]): A list of allowed extensions. + both extensions and is_valid_file should not be passed. + transform (callable, optional): A function/transform that takes in + a sample and returns a transformed version. + E.g, ``transforms.RandomCrop`` for images. + target_transform (callable, optional): A function/transform that takes + in the target and transforms it. + is_valid_file (callable, optional): A function that takes path of a file + and check if the file is a valid file (used to check of corrupt files) + both extensions and is_valid_file should not be passed. + Attributes: + classes (list): List of the class names sorted alphabetically. + class_to_idx (dict): Dict with items (class_name, class_index). + samples (list): List of (sample path, class_index) tuples + targets (list): The class_index value for each image in the dataset """ - all_classes = [d.name for d in os.scandir(dir) if d.is_dir()] - classes = all_classes[0 : int(len(all_classes) * self.classes_fraction)] - classes.sort() - class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} - return classes, class_to_idx - def __getitem__(self, index: int) -> Tuple[Any, Any]: - """ - Args: - index (int): Index - Returns: - tuple: (sample, target) where target is class_index of the target class. - """ - curr_index = index - for x in range(self.total): - try: - path, target = self.samples[curr_index] - sample = self.loader(path) - break - except Exception as e: - curr_index = np.random.randint(0, self.total) - - if self.transform is not None: - sample = self.transform(sample) - if self.target_transform is not None: - target = self.target_transform(target) - - return sample, target - - def __len__(self) -> int: - return len(self.samples) + def __init__( + self, + root: str, + loader: Callable[[str], Any], + extensions: Optional[Tuple[str, ...]] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + classes_fraction=1.0, + data_per_class_fraction=1.0, + is_valid_file: Optional[Callable[[str], bool]] = None, + ) -> None: + super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform) + self.classes_fraction = classes_fraction + self.data_per_class_fraction = data_per_class_fraction + classes, class_to_idx = self._find_classes(self.root) + samples = self.make_dataset( + self.root, class_to_idx, self.data_per_class_fraction, extensions, is_valid_file + ) + if len(samples) == 0: + msg = "Found 0 files in subfolders of: {}\n".format(self.root) + if extensions is not None: + msg += "Supported extensions are: {}".format(",".join(extensions)) + raise RuntimeError(msg) + + self.loader = loader + self.extensions = extensions + self.total = len(samples) + self.classes = classes + self.class_to_idx = class_to_idx + self.samples = samples + self.targets = [s[1] for s in samples] + + @staticmethod + def make_dataset( + directory: str, + class_to_idx: Dict[str, int], + data_per_class_fraction: float, + extensions: Optional[Tuple[str, ...]] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, + ) -> List[Tuple[str, int]]: + return make_dataset( + directory, class_to_idx, data_per_class_fraction, extensions=extensions, is_valid_file=is_valid_file + ) + + def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]: + """ + Finds the class folders in a dataset. + Args: + dir (string): Root directory path. + Returns: + tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. + Ensures: + No class is a subdirectory of another. + """ + all_classes = [d.name for d in os.scandir(dir) if d.is_dir()] + classes = all_classes[0 : int(len(all_classes) * self.classes_fraction)] + classes.sort() + class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} + return classes, class_to_idx + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """ + Args: + index (int): Index + Returns: + tuple: (sample, target) where target is class_index of the target class. + """ + curr_index = index + for x in range(self.total): + try: + path, target = self.samples[curr_index] + sample = self.loader(path) + break + except Exception as e: + curr_index = np.random.randint(0, self.total) + + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target + + def __len__(self) -> int: + return len(self.samples) + + +else: + + class DatasetFolder: + def __init__(self): + super().__init__() + logging.error("Torchvision not found but required.") IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') diff --git a/nemo/collections/vision/data/megatron/vit_dataset.py b/nemo/collections/vision/data/megatron/vit_dataset.py index 5ba711dd0b28..44f09619c094 100644 --- a/nemo/collections/vision/data/megatron/vit_dataset.py +++ b/nemo/collections/vision/data/megatron/vit_dataset.py @@ -16,7 +16,6 @@ import numpy as np import torch -import torchvision.transforms as T from PIL import Image, ImageFilter, ImageOps from torch.utils.data import Dataset @@ -24,6 +23,13 @@ from nemo.collections.vision.data.megatron.autoaugment import ImageNetPolicy from nemo.collections.vision.data.megatron.image_folder import ImageFolder +try: + import torchvision.transforms as T + + TORCHVISION_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + TORCHVISION_AVAILABLE = False + def _to_torch_data_type(precision): if precision in ['bf16', 'bf16-mixed']: @@ -92,6 +98,7 @@ def __call__(self, img): class ClassificationTransform: def __init__(self, model_cfg, image_size, train=True): self.data_type = _to_torch_data_type(model_cfg.precision) + assert TORCHVISION_AVAILABLE, "Torchvision imports failed but they are required." if train: self.transform = T.Compose( [