Skip to content

Fix Group offloading behaviour when using streams #11097

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 18, 2025

Conversation

a-r-r-o-w
Copy link
Member

@a-r-r-o-w a-r-r-o-w commented Mar 18, 2025

Fixes #11041, #11086

Requires #11094 to be merged first (also see it for more details).

code
import torch
from diffusers import AutoencoderKLWan, WanPipeline, UniPCMultistepScheduler
from diffusers.hooks import apply_group_offloading, ModelHook, HookRegistry
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug

set_verbosity_debug()

class NanDetectorHook(ModelHook):
    def __init__(self, name):
        super().__init__()
        self.name = name
    
    def post_forward(self, module, output):
        if torch.is_tensor(output):
            if torch.isnan(output).any():
                print(f"NaN detected in {self.name}")
                # breakpoint()
        elif isinstance(output, (list, tuple)):
            for i, o in enumerate(output):
                if torch.is_tensor(o) and torch.isnan(o).any():
                    print(f"NaN detected in {self.name}[{i}]")
                    # breakpoint()
        return output

# Available models: Wan-AI/Wan2.1-T2V-14B, Wan-AI/Wan2.1-T2V-1.3B
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
scheduler = UniPCMultistepScheduler(num_train_timesteps=1000, prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

pipe.text_encoder.to(onload_device)
pipe.vae.to(onload_device)
apply_group_offloading(
    pipe.transformer,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
)

for name, module in pipe.transformer.named_modules():
    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = NanDetectorHook(name)
    registry.register_hook(hook, "nan_detector")

prompt = "A cat walks on the grass, realistic"
negative_prompt = negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=832,
    num_frames=49,
    guidance_scale=5.0,
    num_inference_steps=30,
    generator=torch.Generator().manual_seed(42),
).frames[0]
export_to_video(output, "output3.mp4", fps=15)

🐛 explanation:

  • When using streams, we need to find out the layer execution order for prefetching weights on separate cuda stream
  • To find the layer execution order, we use LazyPrefetchGroupOffloadingHook, which is active during the first forward pass only
  • When using streams, we always set non_blocking parameter to True
  • For the first forward pass, every module group needs to onload itself
  • So, if non_blocking is True and every module is onloading itself (for the first forward pass), we need to either cuda sync after weight transfer, or do the weight transfer in blocking manner
  • If we don't wait for weight transfer to complete and computation kernel launches, the result computed is garbage
  • In this PR, we make sure to do the first forward pass in blocking manner, and only do non-blocking alternate-stream transfers once we know the prefetching order.

As a result, using group offloading alongside streams is slightly slower (depends on model size). For Wan, there seems to be an extra 3-5s overhead due to first forward pass being done in blocking manner.

Revealed with good old absolute rockstar pytorch profiler!

Note: this appears to be random and not reproducible on all GPUs for all height/width/num_frames and for all models 🫠

@Passenger12138 LMK if this fixes the issue you were facing.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w
Copy link
Member Author

Failing tests are unrelated

@a-r-r-o-w a-r-r-o-w merged commit 3be6706 into main Mar 18, 2025
14 of 15 checks passed
@a-r-r-o-w a-r-r-o-w deleted the fix-group-offloading-leaf-stream branch March 18, 2025 09:14
sayakpaul added a commit that referenced this pull request Mar 20, 2025
Co-authored-by: SunMarc <marc.sun@hotmail.fr>

condition better.

support mapping.

improvements.

[Quantization] Add Quanto backend (#10756)

* update

* updaet

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update docs/source/en/quantization/quanto.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update src/diffusers/quantizers/quanto/utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

[Single File] Add single file loading for SANA Transformer (#10947)

* added support for from_single_file

* added diffusers mapping script

* added testcase

* bug fix

* updated tests

* corrected code quality

* corrected code quality

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

[LoRA] Improve warning messages when LoRA loading becomes a no-op (#10187)

* updates

* updates

* updates

* updates

* notebooks revert

* fix-copies.

* seeing

* fix

* revert

* fixes

* fixes

* fixes

* remove print

* fix

* conflicts ii.

* updates

* fixes

* better filtering of prefix.

---------

Co-authored-by: hlky <hlky@hlky.ac>

[LoRA] CogView4 (#10981)

* update

* make fix-copies

* update

[Tests] improve quantization tests by additionally measuring the inference memory savings (#11021)

* memory usage tests

* fixes

* gguf

[`Research Project`] Add AnyText: Multilingual Visual Text Generation And Editing (#8998)

* Add initial template

* Second template

* feat: Add TextEmbeddingModule to AnyTextPipeline

* feat: Add AuxiliaryLatentModule template to AnyTextPipeline

* Add bert tokenizer from the anytext repo for now

* feat: Update AnyTextPipeline's modify_prompt method

This commit adds improvements to the modify_prompt method in the AnyTextPipeline class. The method now handles special characters and replaces selected string prompts with a placeholder. Additionally, it includes a check for Chinese text and translation using the trans_pipe.

* Fill in the `forward` pass of `AuxiliaryLatentModule`

* `make style && make quality`

* `chore: Update bert_tokenizer.py with a TODO comment suggesting the use of the transformers library`

* Update error handling to raise and logging

* Add `create_glyph_lines` function into `TextEmbeddingModule`

* make style

* Up

* Up

* Up

* Up

* Remove several comments

* refactor: Remove ControlNetConditioningEmbedding and update code accordingly

* Up

* Up

* up

* refactor: Update AnyTextPipeline to include new optional parameters

* up

* feat: Add OCR model and its components

* chore: Update `TextEmbeddingModule` to include OCR model components and dependencies

* chore: Update `AuxiliaryLatentModule` to include VAE model and its dependencies for masked image in the editing task

* `make style`

* refactor: Update `AnyTextPipeline`'s docstring

* Update `AuxiliaryLatentModule` to include info dictionary so that text processing is done once

* simplify

* `make style`

* Converting `TextEmbeddingModule` to ordinary `encode_prompt()` function

* Simplify for now

* `make style`

* Up

* feat: Add scripts to convert AnyText controlnet to diffusers

* `make style`

* Fix: Move glyph rendering to `TextEmbeddingModule` from `AuxiliaryLatentModule`

* make style

* Up

* Simplify

* Up

* feat: Add safetensors module for loading model file

* Fix device issues

* Up

* Up

* refactor: Simplify

* refactor: Simplify code for loading models and handling data types

* `make style`

* refactor: Update to() method in FrozenCLIPEmbedderT3 and TextEmbeddingModule

* refactor: Update dtype in embedding_manager.py to match proj.weight

* Up

* Add attribution and adaptation information to pipeline_anytext.py

* Update usage example

* Will refactor `controlnet_cond_embedding` initialization

* Add `AnyTextControlNetConditioningEmbedding` template

* Refactor organization

* style

* style

* Move custom blocks from `AuxiliaryLatentModule` to `AnyTextControlNetConditioningEmbedding`

* Follow one-file policy

* style

* [Docs] Update README and pipeline_anytext.py to use AnyTextControlNetModel

* [Docs] Update import statement for AnyTextControlNetModel in pipeline_anytext.py

* [Fix] Update import path for ControlNetModel, ControlNetOutput in anytext_controlnet.py

* Refactor AnyTextControlNet to use configurable conditioning embedding channels

* Complete control net conditioning embedding in AnyTextControlNetModel

* up

* [FIX] Ensure embeddings use correct device in AnyTextControlNetModel

* up

* up

* style

* [UPDATE] Revise README and example code for AnyTextPipeline integration with DiffusionPipeline

* [UPDATE] Update example code in anytext.py to use correct font file and improve clarity

* down

* [UPDATE] Refactor BasicTokenizer usage to a new Checker class for text processing

* update pillow

* [UPDATE] Remove commented-out code and unnecessary docstring in anytext.py and anytext_controlnet.py for improved clarity

* [REMOVE] Delete frozen_clip_embedder_t3.py as it is in the anytext.py file

* [UPDATE] Replace edict with dict for configuration in anytext.py and RecModel.py for consistency

* 🆙

* style

* [UPDATE] Revise README.md for clarity, remove unused imports in anytext.py, and add author credits in anytext_controlnet.py

* style

* Update examples/research_projects/anytext/README.md

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Remove commented-out image preparation code in AnyTextPipeline

* Remove unnecessary blank line in README.md

[Quantization] Allow loading TorchAO serialized Tensor objects with torch>=2.6  (#11018)

* update

* update

* update

* update

* update

* update

* update

* update

* update

fix: mixture tiling sdxl pipeline - adjust gerating time_ids & embeddings  (#11012)

small fix on generating time_ids & embeddings

[LoRA] support wan i2v loras from the world. (#11025)

* support wan i2v loras from the world.

* remove copied from.

* upates

* add lora.

Fix SD3 IPAdapter feature extractor (#11027)

chore: fix help messages in advanced diffusion examples (#10923)

Fix missing **kwargs in lora_pipeline.py (#11011)

* Update lora_pipeline.py

* Apply style fixes

* fix-copies

---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

Fix for multi-GPU WAN inference (#10997)

Ensure that hidden_state and shift/scale are on the same device when running with multiple GPUs

Co-authored-by: Jimmy <39@🇺🇸.com>

[Refactor] Clean up import utils boilerplate (#11026)

* update

* update

* update

Use `output_size` in `repeat_interleave` (#11030)

[hybrid inference 🍯🐝] Add VAE encode (#11017)

* [hybrid inference 🍯🐝] Add VAE encode

* _toctree: add vae encode

* Add endpoints, tests

* vae_encode docs

* vae encode benchmarks

* api reference

* changelog

* Update docs/source/en/hybrid_inference/overview.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

Wan Pipeline scaling fix, type hint warning, multi generator fix (#11007)

* Wan Pipeline scaling fix, type hint warning, multi generator fix

* Apply suggestions from code review

[LoRA] change to warning from info when notifying the users about a LoRA no-op (#11044)

* move to warning.

* test related changes.

Rename Lumina(2)Text2ImgPipeline -> Lumina(2)Pipeline (#10827)

* Rename Lumina(2)Text2ImgPipeline -> Lumina(2)Pipeline

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>

making ```formatted_images``` initialization compact (#10801)

compact writing

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>

Fix aclnnRepeatInterleaveIntWithDim error on NPU for get_1d_rotary_pos_embed (#10820)

* get_1d_rotary_pos_embed support npu

* Update src/diffusers/models/embeddings.py

---------

Co-authored-by: Kai zheng <kaizheng@KaideMacBook-Pro.local>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>

[Tests] restrict memory tests for quanto for certain schemes. (#11052)

* restrict memory tests for quanto for certain schemes.

* Apply suggestions from code review

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* fixes

* style

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

[LoRA] feat: support non-diffusers wan t2v loras. (#11059)

feat: support non-diffusers wan t2v loras.

[examples/controlnet/train_controlnet_sd3.py] Fixes #11050 - Cast prompt_embeds and pooled_prompt_embeds to weight_dtype to prevent dtype mismatch (#11051)

Fix: dtype mismatch of prompt embeddings in sd3 controlnet training

Co-authored-by: Andreas Jörg <andreasjoerg@MacBook-Pro-von-Andreas-2.fritz.box>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

reverts accidental change that removes attn_mask in attn. Improves fl… (#11065)

reverts accidental change that removes attn_mask in attn. Improves flux ptxla by using flash block sizes. Moves encoding outside the for loop.

Co-authored-by: Juan Acevedo <jfacevedo@google.com>

Fix deterministic issue when getting pipeline dtype and device (#10696)

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

[Tests] add requires peft decorator. (#11037)

* add requires peft decorator.

* install peft conditionally.

* conditional deps.

Co-authored-by: DN6 <dhruv.nair@gmail.com>

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>

CogView4 Control Block (#10809)

* cogview4 control training

---------

Co-authored-by: OleehyO <leehy0357@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>

[CI] pin transformers version for benchmarking. (#11067)

pin transformers version for benchmarking.

updates

Fix Wan I2V Quality (#11087)

* fix_wan_i2v_quality

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update pipeline_wan_i2v.py

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>

LTX 0.9.5 (#10968)

* update

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>

make PR GPU tests conditioned on styling. (#11099)

Group offloading improvements (#11094)

update

Fix pipeline_flux_controlnet.py (#11095)

* Fix pipeline_flux_controlnet.py

* Fix style

update readme instructions. (#11096)

Co-authored-by: Juan Acevedo <jfacevedo@google.com>

Resolve stride mismatch in UNet's ResNet to support Torch DDP (#11098)

Modify UNet's ResNet implementation to resolve stride mismatch in Torch's DDP

Fix Group offloading behaviour when using streams (#11097)

* update

* update

Quality options in `export_to_video` (#11090)

* Quality options in `export_to_video`

* make style

improve more.

add placeholders for docstrings.

formatting.

smol fix.

solidify validation and annotation
sayakpaul added a commit that referenced this pull request Apr 8, 2025
… offloading (#11081)

* implement record_stream for better performance.

* fix

* style.

* merge #11097

* Update src/diffusers/hooks/group_offloading.py

Co-authored-by: Aryan <aryan@huggingface.co>

* fixes

* docstring.

* remaining todos in low_cpu_mem_usage

* tests

* updates to docs.

---------

Co-authored-by: Aryan <aryan@huggingface.co>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

WAN2.1 apply_group_offloading **ERROR** result
3 participants