WAN2.1 apply_group_offloading **ERROR** result #11041

Closed
Passenger12138 opened this issue Mar 12, 2025 · 6 comments · Fixed by #11097
Assignees: a-r-r-o-w
Labels: bug (Something isn't working)

Comments

@Passenger12138

Describe the bug

I am attempting to use the WAN 2.1 model from the diffusers library to complete an image-to-video task on an NVIDIA RTX 4090. To optimize memory usage, I chose the group offload method and intended to compare resource consumption across different configurations. However, during testing, I encountered two main issues:

  1. When using the group_offload_leaf_stream method:
    I received warnings that some layers were not executed during the forward pass:
It seems like some layers were not executed during the forward pass. This may lead to problems when applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please make sure that all layers are executed during the forward pass. The following layers were not executed:
unexecuted_layers=['blocks.25.attn2.norm_added_q', 'blocks.10.attn2.norm_added_q', 'blocks.13.attn2.norm_added_q', 'blocks.11.attn2.norm_added_q', 'blocks.34.attn2.norm_added_q', 'blocks.0.attn2.norm_added_q', 'blocks.35.attn2.norm_added_q', 'blocks.33.attn2.norm_added_q', 'blocks.21.attn2.norm_added_q', 'blocks.20.attn2.norm_added_q', 'blocks.3.attn2.norm_added_q', 'blocks.7.attn2.norm_added_q', 'blocks.22.attn2.norm_added_q', 'blocks.14.attn2.norm_added_q', 'blocks.29.attn2.norm_added_q', 'blocks.9.attn2.norm_added_q', 'blocks.1.attn2.norm_added_q', 'blocks.37.attn2.norm_added_q', 'blocks.18.attn2.norm_added_q', 'blocks.30.attn2.norm_added_q', 'blocks.4.attn2.norm_added_q', 'blocks.32.attn2.norm_added_q', 'blocks.36.attn2.norm_added_q', 'blocks.26.attn2.norm_added_q', 'blocks.6.attn2.norm_added_q', 'blocks.38.attn2.norm_added_q', 'blocks.17.attn2.norm_added_q', 'blocks.12.attn2.norm_added_q', 'blocks.19.attn2.norm_added_q', 'blocks.16.attn2.norm_added_q', 'blocks.15.attn2.norm_added_q', 'blocks.28.attn2.norm_added_q', 'blocks.24.attn2.norm_added_q', 'blocks.31.attn2.norm_added_q', 'blocks.8.attn2.norm_added_q', 'blocks.5.attn2.norm_added_q', 'blocks.27.attn2.norm_added_q', 'blocks.2.attn2.norm_added_q', 'blocks.39.attn2.norm_added_q', 'blocks.23.attn2.norm_added_q']


This issue resulted in severe degradation of the generated output.
This is the input image I selected:
[input image]
I got an incorrect video:
https://github.com/user-attachments/assets/7a8b55a2-6a71-493a-b7ae-64566b321954
When I use the default pipe, i.e., without group_offload_leaf_stream, I get the correct result:
https://github.com/user-attachments/assets/9b54c2f2-fa93-422f-b3df-619ee96bb3c8

  2. When using the group_offload_block_1_stream method:
    I encountered a runtime error: "RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same". It appears that the VAE module was not correctly placed on the GPU device.

Traceback (most recent call last):
  File "/maindata/data/shared/public/haobang.geng/code/video-generate/i2v-baseline/wanx-all-profile.py", line 171, in <module>
    main(args)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/code/video-generate/i2v-baseline/wanx-all-profile.py", line 143, in main
    run_inference()
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/memory_profiler.py", line 1188, in wrapper
    val = prof(func)(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/memory_profiler.py", line 761, in f
    return func(*args, **kwds)
  File "/maindata/data/shared/public/haobang.geng/code/video-generate/i2v-baseline/wanx-all-profile.py", line 130, in run_inference
    output = pipe(
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 587, in __call__
    latents, condition = self.prepare_latents(
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 392, in prepare_latents
    latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 795, in encode
    h = self._encode(x)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 762, in _encode
    out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 439, in forward
    x = self.conv_in(x, feat_cache[idx])
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 78, in forward
    return super().forward(x)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 725, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 720, in _conv_forward
    return F.conv3d(
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
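
As a temporary workaround, I considered leaving the VAE out of group offloading entirely and keeping it on the GPU. This is only a sketch under the assumption that the remaining components can still be offloaded; I have not confirmed it fixes the issue:

# Hypothetical workaround sketch: group-offload everything except the VAE,
# which stays resident on the GPU so its conv weights match the CUDA inputs.
offloader_fn = functools.partial(
    apply_group_offloading,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
)
list(map(offloader_fn, [pipe.text_encoder, pipe.transformer, pipe.image_encoder]))
pipe.vae.to("cuda")  # keep the VAE on the GPU instead of offloading it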

Request for Help:

Are there recommended approaches to ensure all layers are properly executed, especially for the group_offload_leaf_stream method?
How can I resolve the device mismatch issue related to the VAE?
Any suggestions or guidance would be greatly appreciated!

Reproduction

Here is my code:

import argparse
import functools
import json
import os
import pathlib
import time

import numpy as np
import psutil
import torch
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanImageToVideoPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from memory_profiler import profile
from transformers import CLIPVisionModel


def get_memory_usage():
    process = psutil.Process(os.getpid())
    mem_bytes = process.memory_info().rss
    return mem_bytes


@profile(precision=2)
def apply_offload(pipe: WanImageToVideoPipeline, method: str) -> None:
    if method == "full_cuda":
        pipe.to("cuda")
    
    elif method == "model_offload":
        pipe.enable_model_cpu_offload()
    
    elif method == "sequential_offload":
        pipe.enable_sequential_cpu_offload()
    
    elif method == "group_offload_block_1":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="block_level",
            num_blocks_per_group=1,
            use_stream=False,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.transformer, pipe.vae, pipe.image_encoder]))

    elif method == "group_offload_leaf":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="leaf_level",
            use_stream=False,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.transformer, pipe.vae, pipe.image_encoder]))

    
    elif method == "group_offload_block_1_stream":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="block_level",
            num_blocks_per_group=1,
            use_stream=True,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.transformer, pipe.vae, pipe.image_encoder]))
    
    elif method == "group_offload_leaf_stream":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="leaf_level",
            use_stream=True,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.transformer, pipe.vae, pipe.image_encoder]))


@profile(precision=2)
def load_pipeline():
    model_id = "Wan2.1-I2V-14B-480P-Diffusers"
    image_encoder = CLIPVisionModel.from_pretrained(
        model_id, subfolder="image_encoder", torch_dtype=torch.float32
    )
    vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
    scheduler_b = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)

    pipe = WanImageToVideoPipeline.from_pretrained(
        model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16, scheduler=scheduler_b
    )
    return pipe


@torch.no_grad()
def main(args):
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(f"./results/check-wanmulti-framework/{args.method}/", exist_ok=True)
    pipe = load_pipeline()
    apply_offload(pipe, args.method)
    apply_offload_memory_usage = get_memory_usage()

    torch.cuda.reset_peak_memory_stats()
    cuda_model_memory = torch.cuda.max_memory_reserved()

    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    run_inference_memory_usage_list = []
    
    def cpu_mem_callback():
        nonlocal run_inference_memory_usage_list
        run_inference_memory_usage_list.append(get_memory_usage())

    @profile(precision=2)
    def run_inference():
        image = load_image("./dataset/character-img/imgs3/1.jpeg")
        max_area = 480 * 832
        aspect_ratio = image.height / image.width
        mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
        height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
        width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
        prompt = (
            "A person smile."
        )
        negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
        generator = torch.Generator("cuda").manual_seed(100)
        output = pipe(
            image=image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_frames=81,
            guidance_scale=5.0,
            generator=generator,
        ).frames[0]
        export_to_video(output, f"./results/check-wanmulti-framework/{args.method}/wanx_diffusers.mp4", fps=16)

    t1 = time.time()
    run_inference()
    torch.cuda.synchronize()
    t2 = time.time()
    cuda_inference_memory = torch.cuda.max_memory_reserved()
    time_required = t2 - t1
    # run_inference_memory_usage = sum(run_inference_memory_usage_list) / len(run_inference_memory_usage_list)
    # print(f"Run inference memory usage list: {run_inference_memory_usage_list}")

    info = {
        "time": round(time_required, 2),
        "cuda_model_memory": round(cuda_model_memory / 1024**3, 2),
        "cuda_inference_memory": round(cuda_inference_memory / 1024**3, 2),
        "cpu_offload_memory": round(apply_offload_memory_usage / 1024**3, 2),
    }
    with open(output_dir / f"memory_usage_{args.method}.json", "w") as f:
        json.dump(info, f, indent=4)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", type=str, default="full_cuda", choices=["full_cuda", "model_offload", "sequential_offload", "group_offload_block_1", "group_offload_leaf", "group_offload_block_1_stream", "group_offload_leaf_stream"])
    parser.add_argument("--output_dir", type=str, default="./results/offload_profiling")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args)

Here is my environment:

Package                           Version
--------------------------------- --------------------
absl-py                           2.1.0
accelerate                        1.4.0
addict                            2.4.0
aiofiles                          23.2.1
aiohappyeyeballs                  2.4.3
aiohttp                           3.10.10
aiosignal                         1.3.1
airportsdata                      20241001
albucore                          0.0.17
albumentations                    1.4.18
aliyun-python-sdk-core            2.16.0
aliyun-python-sdk-kms             2.16.5
altair                            5.4.1
annotated-types                   0.7.0
antlr4-python3-runtime            4.9.3
anyio                             4.6.2.post1
astor                             0.8.1
asttokens                         2.4.1
astunparse                        1.6.3
async-timeout                     4.0.3
attrs                             24.2.0
av                                13.1.0
beautifulsoup4                    4.12.3
blake3                            1.0.4
blinker                           1.9.0
boto3                             1.35.60
botocore                          1.35.60
braceexpand                       0.1.7
certifi                           2024.8.30
cffi                              1.17.1
charset-normalizer                3.4.0
click                             8.1.7
clip                              0.2.0
cloudpickle                       3.1.0
coloredlogs                       15.0.1
comm                              0.2.2
compressed-tensors                0.8.0
ConfigArgParse                    1.7
contourpy                         1.3.0
controlnet_aux                    0.0.7
cpm-kernels                       1.0.11
crcmod                            1.7
cryptography                      44.0.1
cupy-cuda12x                      13.3.0
cycler                            0.12.1
Cython                            3.0.12
dash                              2.18.2
dash-core-components              2.0.0
dash-html-components              2.0.0
dash-table                        5.0.0
dashscope                         1.22.2
datasets                          3.0.1
debugpy                           1.8.10
decorator                         4.4.2
decord                            0.6.0
deepspeed                         0.15.2
depyf                             0.18.0
diffsynth                         1.1.2
diffusers                         0.33.0.dev0
dill                              0.3.8
diskcache                         5.6.3
distro                            1.9.0
dnspython                         2.7.0
docker-pycreds                    0.4.0
easydict                          1.13
einops                            0.8.0
email_validator                   2.2.0
eval_type_backport                0.2.0
exceptiongroup                    1.2.2
executing                         2.1.0
facexlib                          0.3.0
fairscale                         0.4.13
fastapi                           0.115.2
fastjsonschema                    2.20.0
fastrlock                         0.8.3
ffmpy                             0.4.0
filelock                          3.16.1
filterpy                          1.4.5
flash-attn                        2.6.3
Flask                             3.0.3
flatbuffers                       24.3.25
fonttools                         4.54.1
frozenlist                        1.4.1
fsspec                            2024.6.1
ftfy                              6.3.0
func_timeout                      4.3.5
future                            1.0.0
fvcore                            0.1.5.post20221221
gast                              0.6.0
gguf                              0.10.0
gitdb                             4.0.11
GitPython                         3.1.43
google-pasta                      0.2.0
gradio                            5.5.0
gradio_client                     1.4.2
grpcio                            1.66.2
h11                               0.14.0
h5py                              3.12.1
hjson                             3.1.0
httpcore                          1.0.6
httptools                         0.6.4
httpx                             0.27.2
huggingface-hub                   0.29.1
humanfriendly                     10.0
idna                              3.10
imageio                           2.36.0
imageio-ffmpeg                    0.5.1
imgaug                            0.4.0
importlib_metadata                8.5.0
iniconfig                         2.0.0
interegular                       0.3.3
iopath                            0.1.10
ipykernel                         6.29.5
ipython                           8.29.0
ipywidgets                        8.1.5
itsdangerous                      2.2.0
jaxtyping                         0.2.34
jedi                              0.19.1
Jinja2                            3.1.4
jiter                             0.7.0
jmespath                          0.10.0
joblib                            1.4.2
jsonschema                        4.23.0
jsonschema-specifications         2024.10.1
jupyter_client                    8.6.3
jupyter_core                      5.7.2
jupyterlab_widgets                3.0.13
keras                             3.7.0
kiwisolver                        1.4.7
lark                              1.2.2
lazy_loader                       0.4
libclang                          18.1.1
libigl                            2.5.1
linkify-it-py                     2.0.3
llvmlite                          0.43.0
lm-format-enforcer                0.10.9
lmdb                              1.6.2
loguru                            0.7.3
lvis                              0.5.3
Markdown                          3.7
markdown-it-py                    2.2.0
MarkupSafe                        2.1.5
matplotlib                        3.9.2
matplotlib-inline                 0.1.7
mdit-py-plugins                   0.3.3
mdurl                             0.1.2
memory-profiler                   0.61.0
mistral_common                    1.5.1
ml-dtypes                         0.4.1
modelscope                        1.23.2
moviepy                           1.0.3
mpmath                            1.3.0
msgpack                           1.1.0
msgspec                           0.18.6
multidict                         6.1.0
multiprocess                      0.70.16
namex                             0.0.8
narwhals                          1.10.0
natsort                           8.4.0
nbformat                          5.10.4
nest-asyncio                      1.6.0
networkx                          3.4.1
ninja                             1.11.1.3
numba                             0.60.0
numpy                             1.26.4
nvdiffrast                        0.3.3
nvidia-cublas-cu12                12.4.5.8
nvidia-cuda-cupti-cu12            12.4.127
nvidia-cuda-nvrtc-cu12            12.4.127
nvidia-cuda-runtime-cu12          12.4.127
nvidia-cudnn-cu12                 9.1.0.70
nvidia-cufft-cu12                 11.2.1.3
nvidia-curand-cu12                10.3.5.147
nvidia-cusolver-cu12              11.6.1.9
nvidia-cusparse-cu12              12.3.1.170
nvidia-cusparselt-cu12            0.6.2
nvidia-ml-py                      12.560.30
nvidia-nccl-cu12                  2.21.5
nvidia-nvjitlink-cu12             12.4.127
nvidia-nvtx-cu12                  12.4.127
omegaconf                         2.3.0
onnxruntime                       1.20.0
open3d                            0.18.0
openai                            1.54.4
openai-clip                       1.0.1
opencv-python                     4.10.0.84
opencv-python-headless            4.10.0.84
opt_einsum                        3.4.0
optree                            0.13.1
orjson                            3.10.7
oss2                              2.19.1
outlines                          0.0.46
packaging                         24.1
pandas                            2.2.3
parso                             0.8.4
partial-json-parser               0.2.1.1.post4
peft                              0.13.2
pexpect                           4.9.0
pillow                            10.4.0
pip                               24.2
platformdirs                      4.3.6
plotly                            5.24.1
pluggy                            1.5.0
pooch                             1.8.2
portalocker                       2.10.1
proglog                           0.1.10
prometheus_client                 0.21.0
prometheus-fastapi-instrumentator 7.0.0
prompt_toolkit                    3.0.48
propcache                         0.2.0
protobuf                          5.28.2
psutil                            6.0.0
ptyprocess                        0.7.0
pudb                              2024.1.2
pure_eval                         0.2.3
py-cpuinfo                        9.0.0
pyairports                        2.1.1
pyarrow                           17.0.0
pybind11                          2.13.6
pycocoevalcap                     1.2
pycocotools                       2.0.8
pycountry                         24.6.1
pycparser                         2.22
pycryptodome                      3.21.0
pydantic                          2.9.2
pydantic_core                     2.23.4
pydub                             0.25.1
Pygments                          2.18.0
pyiqa                             0.1.10
PyMatting                         1.1.12
PyMCubes                          0.1.6
pyparsing                         3.2.0
pyquaternion                      0.9.9
pytest                            8.3.4
python-dateutil                   2.9.0.post0
python-dotenv                     1.0.1
python-multipart                  0.0.12
pytorch3d                         0.7.8
pytz                              2024.2
PyYAML                            6.0.2
pyzmq                             26.2.0
qwen-vl-utils                     0.0.10
ray                               2.37.0
referencing                       0.35.1
regex                             2024.9.11
rembg                             2.0.59
requests                          2.32.3
requests-toolbelt                 1.0.0
retrying                          1.3.4
rich                              13.9.2
rpds-py                           0.20.0
ruff                              0.6.9
s3transfer                        0.10.3
safehttpx                         0.1.1
safetensors                       0.4.5
scikit-image                      0.24.0
scikit-learn                      1.5.2
scikit-video                      1.1.11
scipy                             1.14.1
semantic-version                  2.10.0
sentencepiece                     0.2.0
sentry-sdk                        2.18.0
setproctitle                      1.3.3
setuptools                        75.2.0
shapely                           2.0.7
shellingham                       1.5.4
six                               1.16.0
sk-video                          1.1.10
smmap                             5.0.1
sniffio                           1.3.1
soupsieve                         2.6
stack-data                        0.6.3
starlette                         0.40.0
SwissArmyTransformer              0.4.12
sympy                             1.13.1
tabulate                          0.9.0
tenacity                          9.0.0
tensorboard                       2.18.0
tensorboard-data-server           0.7.2
tensorboardX                      2.6.2.2
tensorflow-io-gcs-filesystem      0.37.1
termcolor                         2.5.0
thop                              0.1.1.post2209072238
threadpoolctl                     3.5.0
tifffile                          2024.9.20
tiktoken                          0.7.0
timm                              1.0.11
tokenizers                        0.20.3
tomesd                            0.1.3
tomli                             2.2.1
tomlkit                           0.12.0
torch                             2.6.0
torchaudio                        2.6.0
torchdiffeq                       0.2.4
torchsde                          0.2.6
torchvision                       0.21.0
tornado                           6.4.2
tqdm                              4.66.5
traitlets                         5.14.3
trampoline                        0.1.2
transformers                      4.46.2
transformers-stream-generator     0.0.4
trimesh                           4.5.2
triton                            3.2.0
typeguard                         2.13.3
typer                             0.12.5
typing_extensions                 4.12.2
tzdata                            2024.2
uc-micro-py                       1.0.3
urllib3                           2.2.3
urwid                             2.6.16
urwid_readline                    0.15.1
uvicorn                           0.32.0
uvloop                            0.21.0
wandb                             0.18.7
watchfiles                        0.24.0
wcwidth                           0.2.13
webdataset                        0.2.100
websocket-client                  1.8.0
websockets                        12.0
Werkzeug                          3.0.4
wheel                             0.44.0
widgetsnbextension                4.0.13
wrapt                             1.17.0
xatlas                            0.0.9
xxhash                            3.5.0
yacs                              0.1.8
yapf                              0.43.0
yarl                              1.15.3
zipp                              3.20.2

Logs

System Info

  • 🤗 Diffusers version: 0.33.0.dev0
  • Platform: Linux-3.10.0-1160.el7.x86_64-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.15
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.29.1
  • Transformers version: 4.46.2
  • Accelerate version: 1.4.0
  • PEFT version: 0.13.2
  • Bitsandbytes version: not installed
  • Safetensors version: 0.4.5
  • xFormers version: not installed
  • Accelerator: NVIDIA A800-SXM4-80GB, 81251 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@DN6 @a-r-r-o-w

Passenger12138 added the bug (Something isn't working) label on Mar 12, 2025
@DN6
Collaborator

DN6 commented Mar 12, 2025

cc @a-r-r-o-w

a-r-r-o-w self-assigned this on Mar 12, 2025
@a-r-r-o-w
Member

Thanks for the detailed issue!

When using the group_offload_leaf_stream method: I received warnings that some layers were not executed during the forward pass:

This warning can be ignored for now. The model weights contain additional norm_added_q layers that were a mistake on our part when doing the original->diffusers model-format conversion.

The reason for quality degradation seems to be coming from applying group offloading to the text encoder of Wan. Maybe the UMT5EncoderModel implementation has a layer invocation order that is not compatible with streamed group offloading -- I will have to look more deeply.

Could you try applying it only to the transformer and report your results?
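
Something along these lines (just a sketch of what I mean; it assumes the remaining components fit on the GPU without offloading):

# Sketch: group-offload only the transformer, leaving the other
# components fully resident on the GPU.
apply_group_offloading(
    pipe.transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
)
pipe.text_encoder.to("cuda")
pipe.image_encoder.to("cuda")
pipe.vae.to("cuda")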

I don't fully understand the VAE issue here yet but can take a look soon.

Are there recommended approaches to ensure all layers are properly executed, especially for the group_offload_leaf_stream method?

The layer invocation order is automatically detected, so if there are any problems with our transformer implementation, we will have to either improve the group offloading code to detect this or rewrite parts of the modeling as necessary. My guess is that the text encoder is the cause of the poor results, since I was able to run just the transformer with group offloading a few days ago and it produced good results.

@Passenger12138
Author

Passenger12138 commented Mar 17, 2025

cc @a-r-r-o-w
Thank you for your suggestion, but I still got the wrong result. Here is the video I received:

wanx_diffusers.mp4

and my command line output

[screenshot of command-line output]

I did not receive the warning that "some layers were not executed during the forward pass"; perhaps you have updated diffusers.

In order to run on the 4090, I manually placed the image_encoder, text_encoder, and transformer on the CPU during VAE decoding.
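
Roughly like this (a sketch of what I did; the exact code differs slightly):

# Before VAE decoding: move the large modules to CPU and keep only
# the VAE on the GPU to fit within the 4090's memory.
pipe.image_encoder.to("cpu")
pipe.text_encoder.to("cpu")
pipe.transformer.to("cpu")
torch.cuda.empty_cache()
pipe.vae.to("cuda")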

@a-r-r-o-w
Member

Thank you for testing! I'm looking into it now, and apologies for the delay/issues faced.

@Passenger12138
Author

I have completed testing all the group offload methods.

I compared the following combinations:

  1. apply_group_offload --- group_offload_block_1
  2. apply_group_offload --- group_offload_leaf
  3. apply_group_offload --- group_offload_block_1_stream
  4. apply_group_offload --- group_offload_leaf_stream

Out of these, I observed that only apply_group_offload --- group_offload_leaf_stream produced incorrect results. You can find the error here:

423298751-ea61d0cf-9cbb-477c-8dc0-510ce5e377b4.mp4

The other three methods worked correctly on my machine, as shown in the following output:

wanx_diffusers.mp4

Could you kindly focus on fixing the group_offload_leaf_stream method?

Additionally, I noticed that when using the diffusers code, the results are significantly worse than those from the official Wan2.1 repository. I have set the scheduler as follows:

flow_shift = 3.0  # 5.0 for 720P, 3.0 for 480P
scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16, scheduler=scheduler
)

After reviewing the code, I suspect the issue might stem from the transformer3d weights being loaded in bf16, which could lead to a loss of precision. I would like to keep the original precision to achieve better results. If you have any suggestions on how to modify the code to address this, I would greatly appreciate it.
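
For example, something along these lines (a sketch; I am assuming the transformer can be loaded separately in float32 the same way the VAE and image encoder are loaded above, and note this roughly doubles the memory footprint of the 14B transformer):

from diffusers import WanTransformer3DModel

# Hypothetical sketch: load the transformer in float32 instead of bf16
# and pass it to the pipeline explicitly.
transformer = WanTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.float32
)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id, vae=vae, image_encoder=image_encoder, transformer=transformer,
    torch_dtype=torch.bfloat16, scheduler=scheduler
)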

Thank you for your attention!

@a-r-r-o-w
Member

Could you kindly focus on fixing the group_offload_leaf_stream method?

Additionally, I noticed that when using the diffusers code, the results are significantly worse than those from the official Wan2.1 repository. I have set the scheduler by

Please take a look at #11097 when you get some time. The bug is not reproducible on all GPUs, for all height/width/num_frames, or for all models, and only seems to occur in certain cases. This made it extremely hard to debug, but I believe it should be fixed now.
