Bug fix in LTXImageToVideoPipeline.prepare_latents() when latents is already set #10918

Open · wants to merge 3 commits into base: main

Conversation

kakukakujirori

What does this PR do?

There is a small bug in LTXImageToVideoPipeline.prepare_latents() when latents is already set.
The code assumes latents is a five-dimensional tensor (batch, channel, num_frames, height, width), as we can see from the line

num_frames = (
    (num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2)
)

However, when latents is passed as an argument, the code skips applying self._pack_latents().

Also, the shape of conditioning_mask is wrong in that branch.

This PR addresses these two issues.

"""Code snippet to see the error
"""

import torch
from diffusers import LTXImageToVideoPipeline

device = "cuda:0"

# instantiate a pipeline
pipe = LTXImageToVideoPipeline.from_pretrained(
    "a-r-r-o-w/LTX-Video-0.9.1-diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload(device=device)

# create a dummy latents tensor
num_frames = 49
height = 352
width = 640

latent_num_frames = (num_frames - 1) // pipe.vae_temporal_compression_ratio + 1
latent_height = height // pipe.vae_spatial_compression_ratio
latent_width = width // pipe.vae_spatial_compression_ratio

latents = torch.randn((1, 128, latent_num_frames, latent_height, latent_width), device=device)

# run
pipe(
    height=height,
    width=width,
    num_frames=num_frames,
    prompt="test_test",
    latents=latents,
)
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 28
     25 latents = torch.randn((1, 128, latent_num_frames, latent_height, latent_width), device=device)
     27 # run
---> 28 pipe(
     29     height=height,
     30     width=width,
     31     num_frames=num_frames,
     32     prompt="test_test",
     33     latents=latents,
     34 )

File ~/miniconda3/envs/py311/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/miniconda3/envs/py311/lib/python3.11/site-packages/diffusers/pipelines/ltx/pipeline_ltx_image2video.py:779, in LTXImageToVideoPipeline.__call__(self, image, prompt, negative_prompt, height, width, num_frames, frame_rate, num_inference_steps, timesteps, guidance_scale, num_videos_per_prompt, generator, latents, prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask, decode_timestep, decode_noise_scale, output_type, return_dict, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length)
    777 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    778 timestep = t.expand(latent_model_input.shape[0])
--> 779 timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
    781 noise_pred = self.transformer(
    782     hidden_states=latent_model_input,
    783     encoder_hidden_states=prompt_embeds,
   (...)
    791     return_dict=False,
    792 )[0]
    793 noise_pred = noise_pred.float()

RuntimeError: The size of tensor a (2) must match the size of tensor b (1540) at non-singleton dimension 1


Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@a-r-r-o-w
Member

I don't think there's a mistake with handling latents here. When a user calls prepare_latents with their own latents, it is assumed to already be "prepared" (in this case, packed into an ndim=3 tensor), and the only operation we wish to perform on the latents is device and dtype casting.

Regarding the mask shape, I believe that might be an actual mistake. Could you try running inference with only the mask_shape-related change and passing an ndim=3 latent?
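
For reference, a minimal sketch (not the library code) of what a "prepared", i.e. packed, latent looks like for the shapes in the repro above, assuming spatial/temporal patch sizes of 1 for this checkpoint (consistent with the 1540-token count in the traceback):

import torch

# Repro shapes: 49 frames -> 7 latent frames, 352x640 -> 11x20 latent grid, 128 channels
latents_5d = torch.randn(1, 128, 7, 11, 20)  # (B, C, F, H, W)

# With patch sizes of 1, packing amounts to flattening the video grid into tokens:
# (B, C, F, H, W) -> (B, F*H*W, C)
packed = latents_5d.flatten(2).transpose(1, 2)
print(packed.shape)  # torch.Size([1, 1540, 128]) -- ndim=3, one row per token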

@kakukakujirori
Author

When user calls prepare_latents with their own latents, it is assumed to already be "prepared" (in this case, packed into ndim=3 tensor)

This case also fails. Since the packed latents tensor has shape (batch, num_patches, num_channels), the line

num_frames = (
    (num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2)
)

sets num_frames to the channel count rather than the latent frame count, which is not what is intended.
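
(As a concrete illustration with the shapes from the snippet above, not the pipeline code:)

import torch

packed_latents = torch.randn(1, 1540, 128)  # (batch, num_patches, num_channels)
num_frames = packed_latents.size(2)         # -> 128 (the channel count), not the 7 latent frames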

The following is the result when:

  • The input has been packed before being fed to the pipeline
  • The latents packing step is removed from prepare_latents()
"""Code snippet to see the error
"""

import torch
from diffusers import LTXImageToVideoPipeline

device = "cuda:0"

# instantiate a pipeline
pipe = LTXImageToVideoPipeline.from_pretrained(
    "a-r-r-o-w/LTX-Video-0.9.1-diffusers",
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload(device=device)

# create a dummy latents tensor
num_frames = 49
height = 352
width = 640

latent_num_frames = (num_frames - 1) // pipe.vae_temporal_compression_ratio + 1
latent_height = height // pipe.vae_spatial_compression_ratio
latent_width = width // pipe.vae_spatial_compression_ratio

latents = torch.randn((1, 128, latent_num_frames, latent_height, latent_width), device=device)

latents = pipe._pack_latents(latents, pipe.transformer_spatial_patch_size, pipe.transformer_temporal_patch_size)

# run
pipe(
    height=height,
    width=width,
    num_frames=num_frames,
    prompt="test_test",
    latents=latents,
)
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 30
     27 latents = pipe._pack_latents(latents, pipe.transformer_spatial_patch_size, pipe.transformer_temporal_patch_size)
     29 # run
---> 30 pipe(
     31     height=height,
     32     width=width,
     33     num_frames=num_frames,
     34     prompt="test_test",
     35     latents=latents,
     36 )

File ~/miniconda3/envs/py312/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/miniconda3/envs/py312/lib/python3.12/site-packages/diffusers/pipelines/ltx/pipeline_ltx_image2video.py:784, in LTXImageToVideoPipeline.__call__(self, image, prompt, negative_prompt, height, width, num_frames, frame_rate, num_inference_steps, timesteps, guidance_scale, num_videos_per_prompt, generator, latents, prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask, decode_timestep, decode_noise_scale, output_type, return_dict, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length)
    781 timestep = t.expand(latent_model_input.shape[0])
    782 timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
--> 784 noise_pred = self.transformer(
    785     hidden_states=latent_model_input,
    786     encoder_hidden_states=prompt_embeds,
    787     timestep=timestep,
    788     encoder_attention_mask=prompt_attention_mask,
    789     num_frames=latent_num_frames,
    790     height=latent_height,
    791     width=latent_width,
    792     rope_interpolation_scale=rope_interpolation_scale,
    793     attention_kwargs=attention_kwargs,
    794     return_dict=False,
    795 )[0]
    796 noise_pred = noise_pred.float()
    798 if self.do_classifier_free_guidance:

File ~/miniconda3/envs/py312/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/py312/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/miniconda3/envs/py312/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/miniconda3/envs/py312/lib/python3.12/site-packages/diffusers/models/transformers/transformer_ltx.py:440, in LTXVideoTransformer3DModel.forward(self, hidden_states, encoder_hidden_states, timestep, encoder_attention_mask, num_frames, height, width, rope_interpolation_scale, attention_kwargs, return_dict)
    430         hidden_states = torch.utils.checkpoint.checkpoint(
    431             create_custom_forward(block),
    432             hidden_states,
   (...)
    437             **ckpt_kwargs,
    438         )
    439     else:
--> 440         hidden_states = block(
    441             hidden_states=hidden_states,
    442             encoder_hidden_states=encoder_hidden_states,
    443             temb=temb,
    444             image_rotary_emb=image_rotary_emb,
    445             encoder_attention_mask=encoder_attention_mask,
    446         )
    448 scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
    449 shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

File ~/miniconda3/envs/py312/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/py312/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/miniconda3/envs/py312/lib/python3.12/site-packages/diffusers/models/transformers/transformer_ltx.py:245, in LTXVideoTransformerBlock.forward(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb, encoder_attention_mask)
    243 ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
    244 shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
--> 245 norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
    247 attn_hidden_states = self.attn1(
    248     hidden_states=norm_hidden_states,
    249     encoder_hidden_states=None,
    250     image_rotary_emb=image_rotary_emb,
    251 )
    252 hidden_states = hidden_states + attn_hidden_states * gate_msa

RuntimeError: The size of tensor a (1540) must match the size of tensor b (28160) at non-singleton dimension 1

@a-r-r-o-w
Member

Ohh okay, I see! nice catch 🔥

cc @yiyixuxu What do we want to do here? Accept fully prepared latents from the user (ndim=3) and do a fix for that, or accept an ndim=5 tensor and prepare it?

@yiyixuxu
Collaborator

yiyixuxu commented Mar 4, 2025

cc @yiyixuxu What do we want to do here? Accept fully prepared latents from the user (ndim=3) and do a fix for that, or accept ndim=5 tensor and prepare it

I think it should be fully prepared latents (the output of prepare_latents), and we should do a fix for that.

@kakukakujirori
Author

Fixed. We can check the validity using the same code snippet above.


  shape = (batch_size, num_channels_latents, num_frames, height, width)
  mask_shape = (batch_size, 1, num_frames, height, width)

  if latents is not None:
-     conditioning_mask = latents.new_zeros(shape)
+     conditioning_mask = latents.new_zeros(mask_shape)
Collaborator


why do we need this change?

Author


The normal route of prepare_latents() outputs conditioning_mask with that shape, so it is natural to align with it (here).
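
(A rough sketch of that shape relationship with the repro shapes, assuming the normal route marks the first latent frame as the conditioning frame and then packs the mask to one value per token; this is an illustration based on reading the pipeline, not the actual code:)

import torch

# Mask built with mask_shape = (batch_size, 1, num_frames, height, width)
conditioning_mask = torch.zeros(1, 1, 7, 11, 20)
conditioning_mask[:, :, 0] = 1.0  # first latent frame carries the image condition

# Packed down to one value per token, it lines up with the packed latents:
packed_mask = conditioning_mask.flatten(2).transpose(1, 2).squeeze(-1)
print(packed_mask.shape)  # torch.Size([1, 1540]) -- broadcastable with timestep.unsqueeze(-1)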
