From 4dd739fe558304a1708f9d039fd6965e943edcba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 7 May 2025 21:53:46 +0300 Subject: [PATCH 001/264] Add SkyReels-V2 pipelines for text-to-video, image-to-video, and diffusion forcing - Introduced the drafts of `SkyReelsV2TextToVideoPipeline`, `SkyReelsV2ImageToVideoPipeline`, `SkyReelsV2DiffusionForcingPipeline`, and `FlowUniPCMultistepScheduler`. --- .../pipelines/skyreels_v2/__init__.py | 53 ++ .../pipeline_skyreels_v2_diffusion_forcing.py | 710 +++++++++++++++++ .../pipeline_skyreels_v2_image_to_video.py | 589 ++++++++++++++ .../pipeline_skyreels_v2_text_to_video.py | 537 +++++++++++++ src/diffusers/schedulers/__init__.py | 2 + .../scheduling_flow_unipc_multistep.py | 721 ++++++++++++++++++ 6 files changed, 2612 insertions(+) create mode 100644 src/diffusers/pipelines/skyreels_v2/__init__.py create mode 100644 src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py create mode 100644 src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py create mode 100644 src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_text_to_video.py create mode 100644 src/diffusers/schedulers/scheduling_flow_unipc_multistep.py diff --git a/src/diffusers/pipelines/skyreels_v2/__init__.py b/src/diffusers/pipelines/skyreels_v2/__init__.py new file mode 100644 index 000000000000..62a8f98feeb9 --- /dev/null +++ b/src/diffusers/pipelines/skyreels_v2/__init__.py @@ -0,0 +1,53 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_skyreels_v2_text_to_video"] = ["SkyReelsV2TextToVideoPipeline"] + _import_structure["pipeline_skyreels_v2_image_to_video"] = ["SkyReelsV2ImageToVideoPipeline"] + _import_structure["pipeline_skyreels_v2_diffusion_forcing"] = ["SkyReelsV2DiffusionForcingPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_skyreels_v2_diffusion_forcing import SkyReelsV2DiffusionForcingPipeline + from .pipeline_skyreels_v2_image_to_video import SkyReelsV2ImageToVideoPipeline + from .pipeline_skyreels_v2_text_to_video import SkyReelsV2TextToVideoPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + del _dummy_objects diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py new file mode 100644 index 000000000000..8757e4fba47c --- /dev/null +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -0,0 +1,710 @@ +# Copyright 2024 The SkyReels-V2 Authors 
and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...models import AutoencoderKLWan, WanTransformer3DModel +from ...schedulers import FlowUniPCMultistepScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_skyreels_v2_text_to_video import SkyReelsV2PipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> import PIL.Image + >>> from diffusers import SkyReelsV2DiffusionForcingPipeline + + >>> # Load the pipeline + >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained( + ... "SkyworkAI/SkyReels-V2-DiffusionForcing-4.0B", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> # Prepare conditioning frames (list of PIL Images or a tensor of shape [frames, height, width, channels]) + >>> frames = [PIL.Image.open(f"frame_{i}.jpg").convert("RGB") for i in range(5)] + >>> # Create mask: 1 for conditioning frames, 0 for frames to generate + >>> mask = [1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0] + >>> prompt = "A person walking in the park" + >>> video = pipe(prompt, conditioning_frames=frames, conditioning_frame_mask=mask, num_frames=16).frames[0] + ``` +""" + + +class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline): + """ + Pipeline for video generation with diffusion forcing (conditioning on specific frames) using SkyReels-V2. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a specific device, etc.). + + Args: + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + transformer ([`SkyReelsV2TransformerModel`]): + A SkyReels-V2 transformer model for diffusion with diffusion forcing capability. + scheduler ([`FlowUniPCMultistepScheduler`]): + A scheduler to be used in combination with the transformer to denoise the encoded video latents. 
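+
+    A minimal usage sketch (the checkpoint name is the one used in the module-level example and is only
+    illustrative, not a confirmed repository id; `frames` stands for a list of PIL images):
+
+    ```py
+    >>> import torch
+    >>> from diffusers import SkyReelsV2DiffusionForcingPipeline
+
+    >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
+    ...     "SkyworkAI/SkyReels-V2-DiffusionForcing-4.0B", torch_dtype=torch.float16
+    ... ).to("cuda")
+    >>> # `conditioning_frame_mask` has one entry per output frame: 1 keeps that frame fixed to the next
+    >>> # conditioning frame, 0 lets the model generate it. The number of 1s must equal
+    >>> # len(conditioning_frames) and len(mask) must equal num_frames.
+    >>> mask = [1, 0, 0, 0, 1, 0, 0, 0]
+    >>> video = pipe(
+    ...     "A person walking in the park",
+    ...     conditioning_frames=frames,  # two PIL images, one per 1 in the mask
+    ...     conditioning_frame_mask=mask,
+    ...     num_frames=8,
+    ... ).frames[0]
+    ```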
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKLWan, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + transformer: WanTransformer3DModel, + scheduler: FlowUniPCMultistepScheduler, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: torch.device, + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide image generation. + device: (`torch.device`): + The torch device to place the resulting embeddings on. + num_videos_per_prompt (`int`): + The number of videos that should be generated per prompt. + do_classifier_free_guidance (`bool`): + Whether to use classifier-free guidance or not. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + provide `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than 1). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + max_sequence_length (`int`, *optional*): + Maximum sequence length for input text when generating embeddings. If not provided, defaults to 77. 
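+
+        A short sketch of calling this method directly (the prompt text is illustrative):
+
+        ```py
+        >>> prompt_embeds = pipe.encode_prompt(
+        ...     prompt="A person walking in the park",
+        ...     device=pipe._execution_device,
+        ...     num_videos_per_prompt=1,
+        ...     do_classifier_free_guidance=True,
+        ... )
+        >>> # With classifier-free guidance enabled, the returned tensor stacks the negative and positive
+        >>> # embeddings along the batch dimension, negative first.
+        ```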
+ """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizer parameters + if max_sequence_length is None: + max_sequence_length = self.tokenizer.model_max_length + + # Get prompt text embeddings + if prompt_embeds is None: + # Text encoder expects tokens to be of shape (batch_size, context_length) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask.to(device), + ) + prompt_embeds = prompt_embeds[0] + + # Duplicate prompt embeddings for each generation per prompt + if prompt_embeds.shape[0] < batch_size * num_videos_per_prompt: + prompt_embeds = prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + + # Get negative prompt embeddings + if do_classifier_free_guidance and negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"Batch size of `negative_prompt` should be {batch_size}, but is {len(negative_prompt)}" + ) + + negative_text_inputs = self.tokenizer( + negative_prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + negative_input_ids = negative_text_inputs.input_ids + negative_attention_mask = negative_text_inputs.attention_mask + + negative_prompt_embeds = self.text_encoder( + negative_input_ids.to(device), + attention_mask=negative_attention_mask.to(device), + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + # Duplicate negative prompt embeddings for each generation per prompt + if negative_prompt_embeds.shape[0] < batch_size * num_videos_per_prompt: + negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + + # For classifier-free guidance, combine embeddings + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + """ + Decode the generated latent sample using the VAE to produce video frames. + + Args: + latents (`torch.Tensor`): Generated latent samples from the diffusion process. + + Returns: + `torch.Tensor`: Decoded video frames. 
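+
+        A shape sketch, assuming `latents` already matches the VAE's device and dtype (channel and spatial
+        sizes below are illustrative, not the actual model configuration):
+
+        ```py
+        >>> latents = torch.randn(1, 16, 8, 28, 28)  # [batch, channels, frames, latent_h, latent_w]
+        >>> video = pipe.decode_latents(latents)
+        >>> video.dtype, video.shape  # torch.uint8, [batch, frames, height, width, channels], values in [0, 255]
+        ```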
+ """ + video_length = latents.shape[2] + + latents = 1 / self.vae.config.scaling_factor * latents # scale latents + + # Reshape latents from [batch, channels, frames, height, width] to [batch*frames, channels, height, width] + latents = latents.permute(0, 2, 1, 3, 4).reshape( + latents.shape[0] * latents.shape[2], latents.shape[1], latents.shape[3], latents.shape[4] + ) + + # Decode all frames + video = self.vae.decode(latents).sample + + # Reshape back to [batch, frames, channels, height, width] + video = video.reshape(-1, video_length, video.shape[1], video.shape[2], video.shape[3]) + + # Rescale video from [-1, 1] to [0, 1] + video = (video / 2 + 0.5).clamp(0, 1) + + # Rescale to pixel values + video = (video * 255).to(torch.uint8) + + # Permute channels to [batch, frames, height, width, channels] + return video.permute(0, 1, 3, 4, 2) + + def encode_frames(self, frames: Union[List[PIL.Image.Image], torch.Tensor, np.ndarray]) -> torch.Tensor: + """ + Encode the conditioning frames to latent space using VAE. + + Args: + frames (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`): + List of frames or tensor/array containing frames to encode. + + Returns: + `torch.Tensor`: Latent representation of the input frames. + """ + device = self._execution_device + dtype = self.vae.dtype + + if isinstance(frames, list): + # Convert list of PIL Images to tensor [frames, channels, height, width] + processed_frames = [] + for frame in frames: + if isinstance(frame, PIL.Image.Image): + frame = np.array(frame).astype(np.float32) / 127.5 - 1.0 + frame = torch.from_numpy(frame).permute(2, 0, 1) + processed_frames.append(frame) + frames_tensor = torch.stack(processed_frames) + + elif isinstance(frames, np.ndarray): + # Convert numpy array to tensor + if frames.ndim == 4: # [frames, height, width, channels] + frames = frames.astype(np.float32) / 127.5 - 1.0 + frames_tensor = torch.from_numpy(frames).permute(0, 3, 1, 2) # [frames, channels, height, width] + else: + raise ValueError( + f"Unexpected numpy array shape: {frames.shape}, expected [frames, height, width, channels]" + ) + + elif isinstance(frames, torch.Tensor): + if frames.ndim == 4: + if frames.shape[1] == 3: # [frames, channels, height, width] + frames_tensor = frames + elif frames.shape[3] == 3: # [frames, height, width, channels] + frames_tensor = frames.permute(0, 3, 1, 2) + else: + raise ValueError(f"Unexpected tensor shape: {frames.shape}, cannot determine channel dimension") + else: + raise ValueError(f"Unexpected tensor shape: {frames.shape}, expected 4D tensor") + + # Ensure pixel values are in range [-1, 1] + if frames_tensor.min() >= 0 and frames_tensor.max() <= 1: + frames_tensor = 2.0 * frames_tensor - 1.0 + elif frames_tensor.min() >= 0 and frames_tensor.max() <= 255: + frames_tensor = frames_tensor / 127.5 - 1.0 + else: + raise ValueError(f"Unsupported frame input type: {type(frames)}") + + # Move to device and correct dtype + frames_tensor = frames_tensor.to(device=device, dtype=dtype) + + # Process in batches if there are many frames, to avoid OOM + batch_size = 8 # reasonable batch size for VAE encoding + latents = [] + + for i in range(0, frames_tensor.shape[0], batch_size): + batch = frames_tensor[i : i + batch_size] + with torch.no_grad(): + batch_latents = self.vae.encode(batch).latent_dist.sample() + batch_latents = batch_latents * self.vae.config.scaling_factor + latents.append(batch_latents) + + # Concatenate all batches + latents = torch.cat(latents, dim=0) + + return latents + + def prepare_latents_with_forcing( + 
self, + batch_size: int, + num_channels_latents: int, + num_frames: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + conditioning_latents: Optional[torch.Tensor] = None, + conditioning_frame_mask: Optional[List[int]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Prepare latent variables for diffusion forcing. + + Args: + batch_size (`int`): Number of samples to generate. + num_channels_latents (`int`): Number of channels in the latent space. + num_frames (`int`): Number of video frames to generate. + height (`int`): Height of the generated images in pixels. + width (`int`): Width of the generated images in pixels. + dtype (`torch.dtype`): Data type of the latent variables. + device (`torch.device`): Device to generate the latents on. + generator (`torch.Generator` or List[`torch.Generator`], *optional*): One or a list of generators. + latents (`torch.Tensor`, *optional*): Pre-generated noisy latent variables. + conditioning_latents (`torch.Tensor`, *optional*): Latent representations of conditioning frames. + conditioning_frame_mask (`List[int]`, *optional*): + Binary mask indicating which frames are conditioning frames. + + Returns: + `Tuple[torch.Tensor, torch.Tensor]`: Prepared initial latent variables and forcing frame mask. + """ + # Check if we have all required inputs for diffusion forcing + if conditioning_frame_mask is None: + raise ValueError("conditioning_frame_mask is required for diffusion forcing") + + if conditioning_latents is None: + raise ValueError("conditioning_latents are required for diffusion forcing") + + # Ensure mask has the right length + if len(conditioning_frame_mask) != num_frames: + raise ValueError( + f"conditioning_frame_mask length ({len(conditioning_frame_mask)}) must match num_frames ({num_frames})" + ) + + # Count conditioning frames in the mask + num_cond_frames = sum(conditioning_frame_mask) + + # Check if conditioning_latents has correct number of frames + if conditioning_latents.shape[0] != num_cond_frames: + raise ValueError( + f"Number of conditioning frames ({conditioning_latents.shape[0]}) must match " + f"number of 1s in conditioning_frame_mask ({num_cond_frames})" + ) + + # Shape for full video latents + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"Must provide a list of generators of length {batch_size}, but list of length {len(generator)} was provided." 
+ ) + + # Generate or use provided latents + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # Scale initial noise by the standard deviation + latents = latents * self.scheduler.init_noise_sigma + + # Create forcing mask tensor [batch, 1, frames, 1, 1] + forcing_mask = torch.tensor(conditioning_frame_mask, device=device, dtype=dtype) + forcing_mask = forcing_mask.view(1, 1, num_frames, 1, 1).expand(batch_size, 1, -1, 1, 1) + + # Insert conditioning latents at the correct positions based on mask + cond_idx = 0 + for frame_idx, is_cond in enumerate(conditioning_frame_mask): + if is_cond: + # Replace the random noise with the encoded conditioning frame + latents[:, :, frame_idx : frame_idx + 1] = ( + conditioning_latents[cond_idx : cond_idx + 1].unsqueeze(0).expand(batch_size, -1, -1, -1, -1) + ) + cond_idx += 1 + + return latents, forcing_mask + + def check_conditioning_inputs( + self, + conditioning_frames: Union[List[PIL.Image.Image], torch.Tensor, np.ndarray], + conditioning_frame_mask: List[int], + num_frames: int, + ): + """Check validity of conditioning inputs.""" + # Validate mask length + if len(conditioning_frame_mask) != num_frames: + raise ValueError( + f"conditioning_frame_mask length ({len(conditioning_frame_mask)}) must match num_frames ({num_frames})" + ) + + # Validate mask values + if not all(x in [0, 1] for x in conditioning_frame_mask): + raise ValueError("conditioning_frame_mask must only contain 0s and 1s") + + # Count conditioning frames + num_conditioning_frames = sum(conditioning_frame_mask) + + # Validate number of conditioning frames + if isinstance(conditioning_frames, list): + if len(conditioning_frames) != num_conditioning_frames: + raise ValueError( + f"Number of conditioning frames ({len(conditioning_frames)}) must match " + f"number of 1s in conditioning_frame_mask ({num_conditioning_frames})" + ) + elif isinstance(conditioning_frames, (torch.Tensor, np.ndarray)): + if conditioning_frames.shape[0] != num_conditioning_frames: + raise ValueError( + f"Number of conditioning frames ({conditioning_frames.shape[0]}) must match " + f"number of 1s in conditioning_frame_mask ({num_conditioning_frames})" + ) + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + conditioning_frames: Optional[Union[List[PIL.Image.Image], torch.Tensor, np.ndarray]] = None, + conditioning_frame_mask: Optional[List[int]] = None, + num_frames: int = 16, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 25, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: Optional[int] = None, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + custom_shift: Optional[float] = None, + ) -> Union[SkyReelsV2PipelineOutput, Tuple]: + """ + The call function to the pipeline for generation with diffusion 
forcing. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + conditioning_frames (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`, *optional*): + Frames to use as conditioning points during video generation. Should be provided as a list of PIL + images, or as a tensor/array of shape [num_cond_frames, height, width, channels] or [num_cond_frames, + channels, height, width]. + conditioning_frame_mask (`List[int]`, *optional*): + Binary mask indicating which frames are conditioning frames (1) and which are to be generated (0). Must + have the same length as num_frames and the same number of 1s as the number of conditioning_frames. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames to generate. + height (`int`, *optional*, defaults to None): + The height in pixels of the generated video frames. If not provided, height is automatically determined + from the model configuration. + width (`int`, *optional*, defaults to None): + The width in pixels of the generated video frames. If not provided, width is automatically determined + from the model configuration. + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + max_sequence_length (`int`, *optional*): + Maximum sequence length for input text when generating embeddings. If not provided, defaults to 77. + output_type (`str`, *optional*, defaults to `"tensor"`): + The output format of the generated video. Choose between `tensor` and `np` for `torch.Tensor` or + `numpy.array` output respectively. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.SkyReelsV2PipelineOutput`] instead of a plain tuple. 
+ callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + custom_shift (`float`, *optional*): + Custom shifting factor to use in the flow matching framework. + + Examples: + ```py + >>> import torch + >>> import PIL.Image + >>> from diffusers import SkyReelsV2DiffusionForcingPipeline + + >>> # Load the pipeline + >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained( + ... "SkyworkAI/SkyReels-V2-DiffusionForcing-4.0B", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> # Prepare conditioning frames (list of PIL Images or a tensor of shape [frames, height, width, channels]) + >>> frames = [PIL.Image.open(f"frame_{i}.jpg").convert("RGB") for i in range(5)] + >>> # Create mask: 1 for conditioning frames, 0 for frames to generate + >>> mask = [1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0] + >>> prompt = "A person walking in the park" + >>> video = pipe(prompt, conditioning_frames=frames, conditioning_frame_mask=mask, num_frames=16).frames[0] + ``` + + Returns: + [`~pipelines.SkyReelsV2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.SkyReelsV2PipelineOutput`] is returned, otherwise a tuple is + returned where the first element is a list with the generated frames. + """ + # 0. Default height and width to transformer dimensions + height = height or self.transformer.config.patch_size[1] * 112 # Default from SkyReels-V2: 224 + width = width or self.transformer.config.patch_size[2] * 112 # Default from SkyReels-V2: 224 + + # 1. Check inputs + self.check_inputs( + prompt, + num_frames, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # Check diffusion forcing inputs + if conditioning_frames is None or conditioning_frame_mask is None: + raise ValueError("For diffusion forcing, conditioning_frames and conditioning_frame_mask must be provided") + + self.check_conditioning_inputs(conditioning_frames, conditioning_frame_mask, num_frames) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device + + # 3. Determine whether to apply classifier-free guidance + do_classifier_free_guidance = guidance_scale > 1.0 + + # 4. Encode input prompt + prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + # 5. Encode conditioning frames + conditioning_latents = self.encode_frames(conditioning_frames) + + # 6. 
Prepare timesteps + timestep_shift = None if custom_shift is None else {"shift": custom_shift} + self.scheduler.set_timesteps(num_inference_steps, device=device, **timestep_shift if timestep_shift else {}) + timesteps = self.scheduler.timesteps + + # 7. Prepare latent variables with forcing + num_channels_latents = self.vae.config.latent_channels + latents, forcing_mask = self.prepare_latents_with_forcing( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + conditioning_latents, + conditioning_frame_mask, + ) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Expand latents for classifier-free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # Scale model input + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Predict the noise residual using Diffusion Forcing + # Use standard forward pass; forcing logic is applied outside the model + noise_pred = self.transformer( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # Perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # Update latents with the scheduler step + latents_input = latents + latents_updated = self.scheduler.step(noise_pred, t, latents_input).prev_sample + + # Apply forcing: use original latents for conditioning frames, updated latents for frames to generate + # forcing_mask is 1 for conditioning frames, 0 for frames to generate + latents = torch.where(forcing_mask, latents_input, latents_updated) + + # Call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing: decode latents + video = self.decode_latents(latents) + + # 10. Convert output format + if output_type == "np": + video = video.cpu().numpy() + elif output_type == "tensor": + video = video.cpu() + + # 11. Offload all models + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (video,) + + return SkyReelsV2PipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py new file mode 100644 index 000000000000..8844cc795ca1 --- /dev/null +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py @@ -0,0 +1,589 @@ +# Copyright 2024 The SkyReels-V2 Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...models import AutoencoderKLWan, WanTransformer3DModel +from ...schedulers import FlowUniPCMultistepScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_skyreels_v2_text_to_video import SkyReelsV2PipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> import PIL.Image + >>> from diffusers import SkyReelsV2ImageToVideoPipeline + + >>> pipe = SkyReelsV2ImageToVideoPipeline.from_pretrained( + ... "SkyworkAI/SkyReels-V2-DiffusionForcing-4.0B", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> image = PIL.Image.open("input_image.jpg").convert("RGB") + >>> prompt = "A beautiful view of mountains" + >>> video_frames = pipe(prompt, image=image, num_frames=16).frames[0] + ``` +""" + + +class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline): + """ + Pipeline for image-to-video generation using SkyReels-V2. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a specific device, etc.). + + Args: + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + transformer ([`SkyReelsV2TransformerModel`]): + A SkyReels-V2 transformer model for diffusion. + scheduler ([`FlowUniPCMultistepScheduler`]): + A scheduler to be used in combination with the transformer to denoise the encoded video latents. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKLWan, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + transformer: WanTransformer3DModel, + scheduler: FlowUniPCMultistepScheduler, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. 
This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: torch.device, + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide image generation. + device: (`torch.device`): + The torch device to place the resulting embeddings on. + num_videos_per_prompt (`int`): + The number of videos that should be generated per prompt. + do_classifier_free_guidance (`bool`): + Whether to use classifier-free guidance or not. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + provide `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than 1). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + max_sequence_length (`int`, *optional*): + Maximum sequence length for input text when generating embeddings. If not provided, defaults to 77. 
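+
+        A sketch of precomputing embeddings once and reusing them across calls (the prompt text is
+        illustrative):
+
+        ```py
+        >>> embeds = pipe.encode_prompt(
+        ...     "A beautiful view of mountains",
+        ...     pipe._execution_device,
+        ...     num_videos_per_prompt=1,
+        ...     do_classifier_free_guidance=False,
+        ... )
+        >>> # With guidance disabled only the positive embeddings are returned; they can be passed back via
+        >>> # `prompt_embeds=` on later calls to skip re-tokenization.
+        ```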
+ """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizer parameters + if max_sequence_length is None: + max_sequence_length = self.tokenizer.model_max_length + + # Get prompt text embeddings + if prompt_embeds is None: + # Text encoder expects tokens to be of shape (batch_size, context_length) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask.to(device), + ) + prompt_embeds = prompt_embeds[0] + + # Duplicate prompt embeddings for each generation per prompt + if prompt_embeds.shape[0] < batch_size * num_videos_per_prompt: + prompt_embeds = prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + + # Get negative prompt embeddings + if do_classifier_free_guidance and negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"Batch size of `negative_prompt` should be {batch_size}, but is {len(negative_prompt)}" + ) + + negative_text_inputs = self.tokenizer( + negative_prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + negative_input_ids = negative_text_inputs.input_ids + negative_attention_mask = negative_text_inputs.attention_mask + + negative_prompt_embeds = self.text_encoder( + negative_input_ids.to(device), + attention_mask=negative_attention_mask.to(device), + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + # Duplicate negative prompt embeddings for each generation per prompt + if negative_prompt_embeds.shape[0] < batch_size * num_videos_per_prompt: + negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + + # For classifier-free guidance, combine embeddings + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + """ + Decode the generated latent sample using the VAE to produce video frames. + + Args: + latents (`torch.Tensor`): Generated latent samples from the diffusion process. + + Returns: + `torch.Tensor`: Decoded video frames. 
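+
+        If decoding a long video runs out of memory, VAE slicing and tiling can be enabled on the pipeline
+        before generation (a usage sketch):
+
+        ```py
+        >>> pipe.enable_vae_slicing()
+        >>> pipe.enable_vae_tiling()
+        >>> video = pipe(prompt, image=image, num_frames=16).frames[0]
+        ```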
+ """ + video_length = latents.shape[2] + + latents = 1 / self.vae.config.scaling_factor * latents # scale latents + + # Reshape latents from [batch, channels, frames, height, width] to [batch*frames, channels, height, width] + latents = latents.permute(0, 2, 1, 3, 4).reshape( + latents.shape[0] * latents.shape[2], latents.shape[1], latents.shape[3], latents.shape[4] + ) + + # Decode all frames + video = self.vae.decode(latents).sample + + # Reshape back to [batch, frames, channels, height, width] + video = video.reshape(-1, video_length, video.shape[1], video.shape[2], video.shape[3]) + + # Rescale video from [-1, 1] to [0, 1] + video = (video / 2 + 0.5).clamp(0, 1) + + # Rescale to pixel values + video = (video * 255).to(torch.uint8) + + # Permute channels to [batch, frames, height, width, channels] + return video.permute(0, 1, 3, 4, 2) + + def encode_image(self, image: Union[torch.Tensor, PIL.Image.Image]) -> torch.Tensor: + """ + Encode the input image to latent space using VAE. + + Args: + image (`torch.Tensor` or `PIL.Image.Image`): Input image to encode. + + Returns: + `torch.Tensor`: Latent representation of the input image. + """ + if isinstance(image, PIL.Image.Image): + # Convert PIL image to tensor + image = torch.from_numpy(np.array(image)).float() / 127.5 - 1.0 + image = image.permute(2, 0, 1).unsqueeze(0) + elif isinstance(image, torch.Tensor) and image.ndim == 3: + # Add batch dimension for single image tensor + image = image.unsqueeze(0) + elif isinstance(image, torch.Tensor) and image.ndim == 4: + # Ensure input is in -1 to 1 range + if image.min() >= 0 and image.max() <= 1: + image = 2.0 * image - 1.0 + else: + raise ValueError(f"Invalid image input type: {type(image)}") + + image = image.to(device=self._execution_device, dtype=self.vae.dtype) + + # Encode the image using VAE + latents = self.vae.encode(image).latent_dist.sample() + latents = latents * self.vae.config.scaling_factor + + return latents + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + num_frames: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + image_latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Prepare latent variables from noise for the diffusion process, optionally incorporating image conditioning. + + Args: + batch_size (`int`): Number of samples to generate. + num_channels_latents (`int`): Number of channels in the latent space. + num_frames (`int`): Number of video frames to generate. + height (`int`): Height of the generated images in pixels. + width (`int`): Width of the generated images in pixels. + dtype (`torch.dtype`): Data type of the latent variables. + device (`torch.device`): Device to generate the latents on. + generator (`torch.Generator` or List[`torch.Generator`], *optional*): One or a list of generators. + latents (`torch.Tensor`, *optional*): Pre-generated noisy latent variables. + image_latents (`torch.Tensor`, *optional*): Latent representation of the conditioning image. + + Returns: + `torch.Tensor`: Prepared initial latent variables. + """ + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"Must provide a list of generators of length {batch_size}, but list of length {len(generator)} was provided." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # Scale initial noise by the standard deviation + latents = latents * self.scheduler.init_noise_sigma + + # If we have image conditioning, incorporate it into the first frame + if image_latents is not None: + # Expand image latents to match the number of frames by repeating along frame dimension + # This helps provide a stronger image signal throughout the video + image_latents = image_latents.unsqueeze(2) + first_frame_latents = image_latents.expand(-1, -1, 1, -1, -1) + + # Create a stronger conditioning for the first frame + # This helps ensure the video starts with the input image + alpha = 0.8 # Higher alpha means stronger image conditioning + latents[:, :, 0:1] = alpha * first_frame_latents + (1 - alpha) * latents[:, :, 0:1] + + return latents + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, + num_frames: int = 16, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 25, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: Optional[int] = None, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + custom_shift: Optional[float] = None, + ) -> Union[SkyReelsV2PipelineOutput, Tuple]: + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*): + The image to use as the starting point for the video generation. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames to generate. + height (`int`, *optional*, defaults to None): + The height in pixels of the generated video frames. If not provided, height is automatically determined + from the model configuration. + width (`int`, *optional*, defaults to None): + The width in pixels of the generated video frames. If not provided, width is automatically determined + from the model configuration. + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). 
+ num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + max_sequence_length (`int`, *optional*): + Maximum sequence length for input text when generating embeddings. If not provided, defaults to 77. + output_type (`str`, *optional*, defaults to `"tensor"`): + The output format of the generated video. Choose between `tensor` and `np` for `torch.Tensor` or + `numpy.array` output respectively. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + custom_shift (`float`, *optional*): + Custom shifting factor to use in the flow matching framework. + + Examples: + ```py + >>> import torch + >>> import PIL.Image + >>> from diffusers import SkyReelsV2ImageToVideoPipeline + + >>> pipe = SkyReelsV2ImageToVideoPipeline.from_pretrained( + ... "SkyworkAI/SkyReels-V2-DiffusionForcing-4.0B", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> image = PIL.Image.open("input_image.jpg").convert("RGB") + >>> prompt = "A beautiful view of mountains" + >>> video_frames = pipe(prompt, image=image, num_frames=16).frames[0] + ``` + + Returns: + [`~pipelines.SkyReelsV2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.SkyReelsV2PipelineOutput`] is returned, otherwise a tuple is + returned where the first element is a list with the generated frames. + """ + # 0. Default height and width to transformer dimensions + height = height or self.transformer.config.patch_size[1] * 112 # Default from SkyReels-V2: 224 + width = width or self.transformer.config.patch_size[2] * 112 # Default from SkyReels-V2: 224 + + # 1. 
Check inputs + self.check_inputs( + prompt, + num_frames, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if image is None: + raise ValueError("For image-to-video generation, an input image is required.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device + + # 3. Determine whether to apply classifier-free guidance + do_classifier_free_guidance = guidance_scale > 1.0 + + # 4. Encode input prompt + prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + # 5. Encode the input image + image_latents = self.encode_image(image) + + # Duplicate image latents for each batch and prompt + if isinstance(image, PIL.Image.Image) or (isinstance(image, torch.Tensor) and image.ndim < 4): + # For a single image to be duplicated + image_latents = image_latents.repeat(batch_size * num_videos_per_prompt, 1, 1, 1) + + # 6. Prepare timesteps + timestep_shift = None if custom_shift is None else {"shift": custom_shift} + self.scheduler.set_timesteps(num_inference_steps, device=device, **timestep_shift if timestep_shift else {}) + timesteps = self.scheduler.timesteps + + # 7. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image_latents, + ) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Expand latents for classifier-free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # Scale model input + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=None, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # Perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # Update latents with the scheduler step + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + # Call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing: decode latents + video = self.decode_latents(latents) + + # 10. Convert output format + if output_type == "np": + video = video.cpu().numpy() + elif output_type == "tensor": + video = video.cpu() + + # 11. 
Offload all models + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (video,) + + return SkyReelsV2PipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_text_to_video.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_text_to_video.py new file mode 100644 index 000000000000..03dddccdd43b --- /dev/null +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_text_to_video.py @@ -0,0 +1,537 @@ +# Copyright 2024 The SkyReels-V2 Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer + +from ...models import AutoencoderKLWan, WanTransformer3DModel +from ...schedulers import FlowUniPCMultistepScheduler +from ...utils import ( + BaseOutput, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SkyReelsV2TextToVideoPipeline + + >>> pipe = SkyReelsV2TextToVideoPipeline.from_pretrained( + ... "SkyworkAI/SkyReels-V2-DiffusionForcing-4.0B", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "A panda eating bamboo on a rock, 4k, detailed" + >>> video_frames = pipe(prompt, num_frames=16).frames[0] + ``` +""" + + +@dataclass +class SkyReelsV2PipelineOutput(BaseOutput): + """ + Output class for SkyReels-V2 pipelines. + + Args: + frames (`List[np.ndarray]` or `torch.Tensor`): + List of video frames generated by the pipeline. If the pipeline is running on GPU, the output is a + `torch.Tensor` of shape `(n_output_videos, n_frames, height, width, num_channels)`. Otherwise, the output + is a list of numpy arrays:`(n_output_videos, n_frames, height, width, num_channels)` with values in [0, + 255]. + """ + + frames: Union[List[np.ndarray], torch.Tensor] + + +class SkyReelsV2TextToVideoPipeline(DiffusionPipeline): + """ + Pipeline for text-to-video generation using SkyReels-V2. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a specific device, etc.). + + Args: + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. 
+ transformer ([`WanTransformer3DModel`]): + A SkyReels-V2 transformer model for diffusion. + scheduler ([`FlowUniPCMultistepScheduler`]): + A scheduler to be used in combination with the transformer to denoise the encoded video latents. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKLWan, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + transformer: WanTransformer3DModel, + scheduler: FlowUniPCMultistepScheduler, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: torch.device, + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide image generation. + device: (`torch.device`): + The torch device to place the resulting embeddings on. + num_videos_per_prompt (`int`): + The number of videos that should be generated per prompt. + do_classifier_free_guidance (`bool`): + Whether to use classifier-free guidance or not. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + provide `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than 1). + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + max_sequence_length (`int`, *optional*): + Maximum sequence length for input text when generating embeddings. 
If not provided, defaults to 77. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizer parameters + if max_sequence_length is None: + max_sequence_length = self.tokenizer.model_max_length + + # Get prompt text embeddings + if prompt_embeds is None: + # Text encoder expects tokens to be of shape (batch_size, context_length) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask.to(device), + ) + prompt_embeds = prompt_embeds[0] + + # Duplicate prompt embeddings for each generation per prompt + if prompt_embeds.shape[0] < batch_size * num_videos_per_prompt: + prompt_embeds = prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + + # Get negative prompt embeddings + if do_classifier_free_guidance and negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + elif batch_size != len(negative_prompt): + raise ValueError( + f"Batch size of `negative_prompt` should be {batch_size}, but is {len(negative_prompt)}" + ) + + negative_text_inputs = self.tokenizer( + negative_prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + negative_input_ids = negative_text_inputs.input_ids + negative_attention_mask = negative_text_inputs.attention_mask + + negative_prompt_embeds = self.text_encoder( + negative_input_ids.to(device), + attention_mask=negative_attention_mask.to(device), + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + # Duplicate negative prompt embeddings for each generation per prompt + if negative_prompt_embeds.shape[0] < batch_size * num_videos_per_prompt: + negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + + # For classifier-free guidance, combine embeddings + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + """ + Decode the generated latent sample using the VAE to produce video frames. + + Args: + latents (`torch.Tensor`): Generated latent samples from the diffusion process. + + Returns: + `torch.Tensor`: Decoded video frames. 
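+
+        Example:
+            A shape-only sketch of the reshaping performed below before the VAE call; all sizes are
+            illustrative placeholders:
+
+            ```py
+            >>> import torch
+
+            >>> latents = torch.randn(1, 16, 8, 60, 104)  # (batch, channels, frames, latent_height, latent_width)
+            >>> flat = latents.permute(0, 2, 1, 3, 4).reshape(-1, 16, 60, 104)  # fold frames into the batch dim
+            >>> flat.shape  # what is handed to `self.vae.decode`
+            torch.Size([8, 16, 60, 104])
+            ```
+
+            The decoded frames are reshaped back to `(batch, frames, channels, height, width)`, rescaled to
+            `[0, 255]`, and returned as `uint8` tensors of shape `(batch, frames, height, width, channels)`.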
+ """ + video_length = latents.shape[2] + + latents = 1 / self.vae.config.scaling_factor * latents # scale latents + + # Reshape latents from [batch, channels, frames, height, width] to [batch*frames, channels, height, width] + latents = latents.permute(0, 2, 1, 3, 4).reshape( + latents.shape[0] * latents.shape[2], latents.shape[1], latents.shape[3], latents.shape[4] + ) + + # Decode all frames + video = self.vae.decode(latents).sample + + # Reshape back to [batch, frames, channels, height, width] + video = video.reshape(-1, video_length, video.shape[1], video.shape[2], video.shape[3]) + + # Rescale video from [-1, 1] to [0, 1] + video = (video / 2 + 0.5).clamp(0, 1) + + # Rescale to pixel values + video = (video * 255).to(torch.uint8) + + # Permute channels to [batch, frames, height, width, channels] + return video.permute(0, 1, 3, 4, 2) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + num_frames: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Prepare latent variables from noise for the diffusion process. + + Args: + batch_size (`int`): Number of samples to generate. + num_channels_latents (`int`): Number of channels in the latent space. + num_frames (`int`): Number of video frames to generate. + height (`int`): Height of the generated images in pixels. + width (`int`): Width of the generated images in pixels. + dtype (`torch.dtype`): Data type of the latent variables. + device (`torch.device`): Device to generate the latents on. + generator (`torch.Generator` or List[`torch.Generator`], *optional*): One or a list of generators. + latents (`torch.Tensor`, *optional*): Pre-generated noisy latent variables. + + Returns: + `torch.Tensor`: Prepared initial latent variables. + """ + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"Must provide a list of generators of length {batch_size}, but list of length {len(generator)} was provided." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # Scale initial noise by the standard deviation + latents = latents * self.scheduler.init_noise_sigma + return latents + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_frames: int = 16, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 25, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: Optional[int] = None, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + custom_shift: Optional[float] = None, + ) -> Union[SkyReelsV2PipelineOutput, Tuple]: + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + num_frames (`int`, *optional*, defaults to 16): + The number of video frames to generate. + height (`int`, *optional*, defaults to None): + The height in pixels of the generated video frames. If not provided, height is automatically determined + from the model configuration. + width (`int`, *optional*, defaults to None): + The width in pixels of the generated video frames. If not provided, width is automatically determined + from the model configuration. + num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps. More denoising steps usually lead to a higher quality videos at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). 
If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + max_sequence_length (`int`, *optional*): + Maximum sequence length for input text when generating embeddings. If not provided, defaults to 77. + output_type (`str`, *optional*, defaults to `"tensor"`): + The output format of the generated video. Choose between `tensor` and `np` for `torch.Tensor` or + `numpy.array` output respectively. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + custom_shift (`float`, *optional*): + Custom shifting factor to use in the flow matching framework. + + Examples: + ```py + >>> import torch + >>> from diffusers import SkyReelsV2TextToVideoPipeline + + >>> pipe = SkyReelsV2TextToVideoPipeline.from_pretrained( + ... "SkyworkAI/SkyReels-V2-DiffusionForcing-4.0B", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "A panda eating bamboo on a rock, 4k, detailed" + >>> video_frames = pipe(prompt, num_frames=16).frames[0] + ``` + + Returns: + [`~pipelines.SkyReelsV2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.SkyReelsV2PipelineOutput`] is returned, otherwise a tuple is + returned where the first element is a list with the generated frames. + """ + # 0. Default height and width to transformer dimensions + height = height or self.transformer.config.patch_size[1] * 112 # Default from SkyReels-V2: 224 + width = width or self.transformer.config.patch_size[2] * 112 # Default from SkyReels-V2: 224 + + # 1. Check inputs + self.check_inputs( + prompt, + num_frames, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device + + # 3. Determine whether to apply classifier-free guidance + do_classifier_free_guidance = guidance_scale > 1.0 + + # 4. Encode input prompt + prompt_embeds = self.encode_prompt( + prompt, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + # 5. 
Prepare timesteps + timestep_shift = None if custom_shift is None else {"shift": custom_shift} + self.scheduler.set_timesteps(num_inference_steps, device=device, **timestep_shift if timestep_shift else {}) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Expand latents for classifier-free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # Scale model input + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # Perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # Update latents with the scheduler step + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + # Call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 8. Post-processing: decode latents + video = self.decode_latents(latents) + + # 9. Convert output format + if output_type == "np": + video = video.cpu().numpy() + elif output_type == "tensor": + video = video.cpu() + + # 10. 
Offload all models + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (video,) + + return SkyReelsV2PipelineOutput(frames=video) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 4ca47f19bc83..7be31bd0821d 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -61,6 +61,7 @@ _import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"] _import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"] _import_structure["scheduling_flow_match_lcm"] = ["FlowMatchLCMScheduler"] + _import_structure["scheduling_flow_unipc_multistep"] = ["FlowUniPCMultistepScheduler"] _import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"] _import_structure["scheduling_ipndm"] = ["IPNDMScheduler"] _import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"] @@ -163,6 +164,7 @@ from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler from .scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler from .scheduling_flow_match_lcm import FlowMatchLCMScheduler + from .scheduling_flow_unipc_multistep import FlowUniPCMultistepScheduler from .scheduling_heun_discrete import HeunDiscreteScheduler from .scheduling_ipndm import IPNDMScheduler from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler diff --git a/src/diffusers/schedulers/scheduling_flow_unipc_multistep.py b/src/diffusers/schedulers/scheduling_flow_unipc_multistep.py new file mode 100644 index 000000000000..3f38ec2de7e2 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_flow_unipc_multistep.py @@ -0,0 +1,721 @@ +# Copyright 2024 The SkyReels-V2 Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `FlowUniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models, + adapted for flow matching. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, default `2`): + The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` + due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for + unconditional sampling. 
+ prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts the + flow of the diffusion process. + shift (`float`, defaults to 1.0): + Scaling factor for time shifting in flow matching. + use_dynamic_shifting (`bool`, defaults to False): + Whether to use dynamic time shifting based on image resolution. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. + solver_type (`str`, default `bh2`): + Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` + and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is + usually disabled during the first few steps. + solver_p (`SchedulerMixin`, default `None`): + Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. 
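+
+    Example:
+        A minimal construction/inspection sketch; the configuration values are illustrative rather than tuned
+        settings, and the import relies on the `diffusers.schedulers` export registered above:
+
+        ```py
+        >>> from diffusers.schedulers import FlowUniPCMultistepScheduler
+
+        >>> scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=1000, solver_order=2, shift=8.0)
+        >>> scheduler.set_timesteps(num_inference_steps=30)
+        >>> scheduler.timesteps.shape
+        torch.Size([30])
+        ```
+
+        The SkyReels-V2 pipelines above drive the scheduler by calling `scale_model_input` and `step` inside
+        their denoising loops.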
+ """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for sampling. If `None`, default values are used. + mu (`float`, *optional*): + Value for dynamic shifting based on image resolution. Required when `use_dynamic_shifting=True`. + shift (`float`, *optional*): + Scaling factor for time shifting. Overrides config value if provided. 
+ """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError("you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def _sigma_to_alpha_sigma_t(self, sigma): + # Compute alpha, sigma_t from sigma + alpha = torch.sigmoid(-sigma) + sigma_t = torch.sqrt((1 - alpha**2) / alpha**2) + return alpha, sigma_t + + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return mu * t / (mu + (sigma - mu) * t) + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the flow matching framework expects. + + Args: + model_output (`torch.Tensor`): direct output from the model. + sample (`torch.Tensor`, *optional*): current instance of sample being created. + + Returns: + `torch.Tensor`: converted model output for the flow matching framework. + """ + # We dynamically set the scheduler to the correct inference steps + if self.config.prediction_type == "flow_prediction": + sigma = self.sigmas[self._step_index] + t = self.timesteps[self._step_index].to(model_output.device, dtype=model_output.dtype) + t = t / self.config.num_train_timesteps + + # Compute alpha, sigma_t from sigma + alpha, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + alpha = alpha.to(model_output.device, dtype=model_output.dtype) + sigma_t = sigma_t.to(model_output.device, dtype=model_output.dtype) + + if self.predict_x0: + if self.config.thresholding: + model_output = self._threshold_sample(model_output) + x0_pred = model_output + derivative = (sample - alpha * x0_pred) / sigma_t + else: + derivative = model_output + return derivative + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be flow_prediction for {self.__class__}" + ) + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + Multistep universal `P` for building the predictor solver. + + Args: + model_output (`torch.Tensor`): + Direct output from the model. + sample (`torch.Tensor`, *optional*): + Current instance of sample being created. + order (`int`, *optional*): + Order of the solver. If `None`, it will be set based on scheduler configuration. + + Returns: + `torch.Tensor`: The predicted sample for the predictor solver. 
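+
+        Note:
+            This update is applied internally by [`~FlowUniPCMultistepScheduler.step`] together with the
+            corrector in `multistep_uni_c_bh_update`; it is not intended to be called directly.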
+ """ + if order is None: + order = self.config.solver_order + + model_output = self.convert_model_output(model_output, sample=sample) + + # For P(x_{t+1}, x_t, x, x_{t-1} + univeral_coeff * ds_{t-1}) + # DPMSolver only need to save x_{t-1}, x_t, and ds_{t-1} and the higher order. + # We reuse the concept of UniPC and DPMSolver in the `uni_p*_update` function + self.model_outputs.append(model_output.to(sample.dtype)) + self.model_outputs.pop(0) + + time_step = self.timesteps[self._step_index].to(sample.device, model_output.dtype) + prev_time_step = self.timesteps[self._step_index + 1].to(sample.device, model_output.dtype) + + if self._step_index >= len(self.timesteps): + raise ValueError("Requested prediction step cannot advance any further. You cannot advance further.") + + # current_sigma = self.sigmas[self._step_index].to(sample.device, model_output.dtype) + dt = prev_time_step - time_step + + # 1. P(x_{t+1}, x_t, ds_t) + # Define discretized time and compute the time difference + model_output_dagger = model_output + # time_step_array = torch.tensor([1.0, time_step, time_step**2, time_step**3]) + # prev_time_step_array = torch.tensor([1.0, prev_time_step, prev_time_step**2, prev_time_step**3]) + + if order == 1: # predictor with euler steps + if self.config.solver_type == "bh1": + x_t = sample + dt * model_output_dagger + elif self.config.solver_type == "bh2": + x_t = sample + dt * model_output_dagger + else: + self.timestep_list.append(time_step) + self.timestep_list.pop(0) + + # Order matching the UniPC + if 2 <= order <= 3: + current_model_output = model_output_dagger + prev_model_output = self.model_outputs[-2] + + time_coe = dt + + rhos = self.sigmas[self._step_index - 1] / self.sigmas[self._step_index] + rhos = rhos.to(sample.device, model_output.dtype) + + # t -> t + 1 + if order == 2: + # Bh1 + if self.config.solver_type == "bh1": + # Taylor expansion + h_tau = time_coe + h_phi = time_coe + + # 2nd order expansion + x_t = ( + sample + + h_phi * current_model_output + + 0.5 * h_phi**2 * (current_model_output - prev_model_output) / dt + ) + elif self.config.solver_type == "bh2": + # Original DPM Solver ++: https://github.com/LuChengTHU/dpm-solver + h_t, h_t_1 = time_step, self.timestep_list[-2] + # r = rhos + + # prediction: 2nd order expansion from UniPC paper + h = h_t - h_t_1 + x_t = ( + sample + + h * current_model_output + - 0.5 * h**2 * (current_model_output - prev_model_output) / (h_t - h_t_1) + ) + elif order == 3: + prev_prev_model_output = self.model_outputs[-3] + h_t, h_t_1, h_t_2 = time_step, self.timestep_list[-2], self.timestep_list[-3] + # r, r_1 = rhos, self.sigmas[self._step_index - 2] / self.sigmas[self._step_index - 1] + _, r_1 = rhos, self.sigmas[self._step_index - 2] / self.sigmas[self._step_index - 1] + r_1 = r_1.to(sample.device, model_output.dtype) + + # Original DPM Solver ++: https://github.com/LuChengTHU/dpm-solver + if self.config.solver_type == "bh1": + # Taylor expansion + h_tau = time_coe + h_phi = time_coe + h = h_t_1 - h_t_2 + derivative2 = (current_model_output - prev_model_output) / (h_t - h_t_1) + derivative3 = ( + derivative2 - (prev_model_output - prev_prev_model_output) / (h_t_1 - h_t_2) + ) / (h_t - h_t_2) + x_t = ( + sample + + h_tau * current_model_output + + 0.5 * h_tau**2 * derivative2 + + (1.0 / 6.0) * h_tau**3 * derivative3 + ) + elif self.config.solver_type == "bh2": + # From UniC paper: https://github.com/wl-zhao/UniPC + h1 = h_t - h_t_1 + h2 = h_t_1 - h_t_2 + h_left_01 = h_t - h_t_1 + h_left_12 = h_t_1 - h_t_2 + h_left_02 = h_t - 
h_t_2 + taylor1 = current_model_output + taylor2 = (current_model_output - prev_model_output) / h_left_01 + taylor3 = (taylor2 - (prev_model_output - prev_prev_model_output) / h_left_12) / h_left_02 + x_t = sample + h1 * taylor1 + h1**2 * taylor2 / 2 + h1**2 * h2 * taylor3 / 6 + + else: + raise NotImplementedError(f"Multistep UniCI predict with order {order} is not implemented yet.") + + # The format of predictor solvers in DPM-Solver. + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + Multistep universal `C` for updating the corrector solver. + + Args: + this_model_output (`torch.Tensor`): + Direct output from the model of the current scale. + last_sample (`torch.Tensor`, *optional*): + Sample from the previous scale. + this_sample (`torch.Tensor`, *optional*): + Current sample. + order (`int`, *optional*): + Order of the solver. If `None`, it will be set based on scheduler configuration. + + Returns: + `torch.Tensor`: The updated sample for the corrector solver. + """ + if order is None: + order = self.config.solver_order + # Similar structure as the universal P + # Convert to flow matching format + this_model_output = self.convert_model_output(this_model_output, sample=this_sample) + + if self._step_index > self.num_inference_steps - 1: + prev_time_step = torch.tensor(0.0) + else: + prev_time_step = self.timesteps[self._step_index + 1].to(this_sample.device, this_model_output.dtype) + + time_step = self.timesteps[self._step_index].to(this_sample.device, this_model_output.dtype) + dt = prev_time_step - time_step + + if order == 1: + model_output_processor = this_model_output + # Model output is scaled if we used noise with multiscale + # Corrector + if self.config.solver_type == "bh1": + # Normal euler step to compute corrector (UniC) + x_t = last_sample + dt * model_output_processor + elif self.config.solver_type == "bh2": + # Midpoint method for Heun's 2nd order method + midpoint_model_output = 0.5 * (model_output_processor + this_model_output) + # Runge-Kutta 2nd order + x_t = last_sample + dt * midpoint_model_output + else: # order > 1: + self.timestep_list.append(time_step) + self.timestep_list.pop(0) + self.model_outputs.append(this_model_output.to(last_sample.dtype)) + self.model_outputs.pop(0) + + current_model_output = this_model_output + prev_model_output = self.model_outputs[-2] + + time_coe = dt + + rhos = self.sigmas[self._step_index - 1] / self.sigmas[self._step_index] + rhos = rhos.to(last_sample.device, last_sample.dtype) + + # t -> t + 1 + if order == 2: + # Bh1 + if self.config.solver_type == "bh1": + # Taylor expansion + h_tau = time_coe + h_phi = time_coe + + # 2nd order expansion + x_t = ( + last_sample + + h_phi * current_model_output + + 0.5 * h_phi**2 * (current_model_output - prev_model_output) / dt + ) + elif self.config.solver_type == "bh2": + h_t, h_t_1 = time_step, self.timestep_list[-2] + # r = rhos + h = h_t - h_t_1 + x_t = ( + last_sample + + h * current_model_output + - 0.5 * h**2 * (current_model_output - prev_model_output) / (h_t - h_t_1) + ) + elif order == 3: + prev_prev_model_output = self.model_outputs[-3] + h_t, h_t_1, h_t_2 = time_step, self.timestep_list[-2], self.timestep_list[-3] + # r, r_1 = rhos, self.sigmas[self._step_index - 2] / self.sigmas[self._step_index - 1] + _, r_1 = rhos, self.sigmas[self._step_index - 2] / self.sigmas[self._step_index 
- 1] + r_1 = r_1.to(last_sample.device, last_sample.dtype) + + # Original DPM Solver ++: https://github.com/LuChengTHU/dpm-solver + if self.config.solver_type == "bh1": + # Taylor expansion + h_tau = time_coe + h_phi = time_coe + h = h_t_1 - h_t_2 + derivative2 = (current_model_output - prev_model_output) / (h_t - h_t_1) + derivative3 = (derivative2 - (prev_model_output - prev_prev_model_output) / (h_t_1 - h_t_2)) / ( + h_t - h_t_2 + ) + x_t = ( + last_sample + + h_tau * current_model_output + + 0.5 * h_tau**2 * derivative2 + + (1.0 / 6.0) * h_tau**3 * derivative3 + ) + elif self.config.solver_type == "bh2": + # From UniC paper: https://github.com/wl-zhao/UniPC + h1 = h_t - h_t_1 + h2 = h_t_1 - h_t_2 + h_left_01 = h_t - h_t_1 + h_left_12 = h_t_1 - h_t_2 + h_left_02 = h_t - h_t_2 + taylor1 = current_model_output + taylor2 = (current_model_output - prev_model_output) / h_left_01 + taylor3 = (taylor2 - (prev_model_output - prev_prev_model_output) / h_left_12) / h_left_02 + x_t = last_sample + h1 * taylor1 + h1**2 * taylor2 / 2 + h1**2 * h2 * taylor3 / 6 + else: + raise NotImplementedError(f"Multistep UniCI predict with order {order} is not implemented yet.") + + # The format of corrector solvers in DPM-Solver. + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + if self.begin_index is not None: + indices = indices[self.begin_index :] + + if len(indices) == 0: + raise ValueError( + f"could not find timestep {timestep} from self.timesteps, Currently, self.timesteps have shape {self.timesteps.shape}, " + f"and set scale to {self.config.set_scale}" + ) + return indices[0].item() + + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = self.index_for_timestep(timestep) + self._step_index = step_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`torch.Tensor` or `int`): + The discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
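+
+        Example:
+            A schematic of how the SkyReels-V2 pipelines above call `step` inside their denoising loops
+            (`transformer`, `latent_model_input`, `prompt_embeds`, and `latents` are provided by the
+            surrounding pipeline code and are not defined here):
+
+            ```py
+            >>> for i, t in enumerate(scheduler.timesteps):
+            ...     noise_pred = transformer(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
+            ...     latents = scheduler.step(noise_pred, t, latents).prev_sample
+            ```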
+ """ + + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # Initialize the step_index if performing the first step + if self._step_index is None: + if self.begin_index is None: + self._step_index = 0 + else: + self._step_index = self.begin_index + self._init_step_index(timestep) + + # Upcast sample and model_output to float32 for self.sigmas + sample = sample.to(self.sigmas.dtype) + model_output = model_output.to(self.sigmas.dtype) + + # Apply predicctor (P): x_t -> x_t-1 + if self.config.lower_order_final and self._step_index > self.num_inference_steps - 4: + # For DPM-Solver++(2S), we use lower order solver for the final steps to stabilize the long time inference + # it is equivalent to use our coefficients but change the order + target_order = min(int(self.config.solver_order - 1), 2) + + # 3rd order method + 2nd order + 1st order + if self.config.solver_order > 2 and self._step_index > self.num_inference_steps - 2: + # set order to 1 for the final step + target_order = min(int(target_order - 1), 2) + + # Switch to lower order for DPM-Solver++(2S) in the final steps to stabilize the long time inference + lower_order_predict = self.multistep_uni_p_bh_update( + model_output=model_output, sample=sample, order=target_order + ) + next_sample = lower_order_predict + else: + this_predict = self.multistep_uni_p_bh_update( + model_output=model_output, sample=sample, order=self.config.solver_order + ) + next_sample = this_predict + + # Apply a corrector + if self._step_index not in self.config.disable_corrector: + # UniCPC + # predictor: x_1 -> x_t-1, corrector: x_1 -> x_t-1 + if self.solver_p: + # solver_p_output = self.solver_p.step(model_output, timestep, sample, return_dict=False)[0] + _ = self.solver_p.step(model_output, timestep, sample, return_dict=False)[0] + next_sample = self.multistep_uni_c_bh_update( + this_model_output=model_output, + last_sample=next_sample, + this_sample=sample, + order=self.config.solver_order, + ) + + # update step index + self._step_index += 1 + self.last_sample = sample + + if not return_dict: + return (next_sample,) + + return SchedulerOutput(prev_sample=next_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. 
+ """ + return sample + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if timesteps.device != sigmas.device: + timesteps = timesteps.to(sigmas.device) + if timesteps.dtype != torch.int64: + timesteps = timesteps.to(torch.int64) + + schedule_timesteps = self.timesteps + + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps From 607b5baafc4942ee22a050c97665b3723c30847a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 7 May 2025 22:10:21 +0300 Subject: [PATCH 002/264] up --- src/diffusers/pipelines/skyreels_v2/__init__.py | 4 ++-- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 2 +- .../skyreels_v2/pipeline_skyreels_v2_image_to_video.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/__init__.py b/src/diffusers/pipelines/skyreels_v2/__init__.py index 62a8f98feeb9..34086b8a8ffe 100644 --- a/src/diffusers/pipelines/skyreels_v2/__init__.py +++ b/src/diffusers/pipelines/skyreels_v2/__init__.py @@ -21,10 +21,10 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["pipeline_skyreels_v2_text_to_video"] = ["SkyReelsV2TextToVideoPipeline"] - _import_structure["pipeline_skyreels_v2_image_to_video"] = ["SkyReelsV2ImageToVideoPipeline"] _import_structure["pipeline_skyreels_v2_diffusion_forcing"] = ["SkyReelsV2DiffusionForcingPipeline"] + _import_structure["pipeline_skyreels_v2_image_to_video"] = ["SkyReelsV2ImageToVideoPipeline"] + _import_structure["pipeline_skyreels_v2_text_to_video"] = ["SkyReelsV2TextToVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 8757e4fba47c..32aee85fb6ba 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -69,7 +69,7 @@ class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. - transformer ([`SkyReelsV2TransformerModel`]): + transformer ([`WanTransformer3DModel`]): A SkyReels-V2 transformer model for diffusion with diffusion forcing capability. scheduler ([`FlowUniPCMultistepScheduler`]): A scheduler to be used in combination with the transformer to denoise the encoded video latents. 
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py index 8844cc795ca1..e1f8aff98b01 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py @@ -65,7 +65,7 @@ class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. - transformer ([`SkyReelsV2TransformerModel`]): + transformer ([`WanTransformer3DModel`]): A SkyReels-V2 transformer model for diffusion. scheduler ([`FlowUniPCMultistepScheduler`]): A scheduler to be used in combination with the transformer to denoise the encoded video latents. From 3ccf201e816b0a46954e4426a723629f3f70b21e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 8 May 2025 10:50:40 +0300 Subject: [PATCH 003/264] second draft --- src/diffusers/pipelines/__init__.py | 14 + .../pipelines/skyreels_v2/__init__.py | 1 - .../pipeline_skyreels_v2_diffusion_forcing.py | 660 +++++++++--------- .../pipeline_skyreels_v2_image_to_video.py | 406 ++++++----- .../pipeline_skyreels_v2_text_to_video.py | 424 +++++------ 5 files changed, 779 insertions(+), 726 deletions(-) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 90ff30279f96..2a9997f26a6e 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -30,6 +30,7 @@ "ledits_pp": [], "marigold": [], "pag": [], + "skyreels_v2": [], "stable_diffusion": [], "stable_diffusion_xl": [], } @@ -361,6 +362,13 @@ "WuerstchenPriorPipeline", ] _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline"] + _import_structure["skyreels_v2"].extend( + [ + "SkyReelsV2DiffusionForcingPipeline", + "SkyReelsV2ImageToVideoPipeline", + "SkyReelsV2TextToVideoPipeline", + ] + ) try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() @@ -823,6 +831,12 @@ SpectrogramDiffusionPipeline, ) + from .skyreels_v2 import ( + SkyReelsV2DiffusionForcingPipeline, + SkyReelsV2ImageToVideoPipeline, + SkyReelsV2TextToVideoPipeline, + ) + else: import sys diff --git a/src/diffusers/pipelines/skyreels_v2/__init__.py b/src/diffusers/pipelines/skyreels_v2/__init__.py index 34086b8a8ffe..9413b0b3339a 100644 --- a/src/diffusers/pipelines/skyreels_v2/__init__.py +++ b/src/diffusers/pipelines/skyreels_v2/__init__.py @@ -22,7 +22,6 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_skyreels_v2_diffusion_forcing"] = ["SkyReelsV2DiffusionForcingPipeline"] - _import_structure["pipeline_skyreels_v2_image_to_video"] = ["SkyReelsV2ImageToVideoPipeline"] _import_structure["pipeline_skyreels_v2_text_to_video"] = ["SkyReelsV2TextToVideoPipeline"] diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 32aee85fb6ba..b9c4664ca9a2 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -19,6 +19,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer +from ...image_processor import VideoProcessor from ...models 
import AutoencoderKLWan, WanTransformer3DModel from ...schedulers import FlowUniPCMultistepScheduler from ...utils import ( @@ -32,25 +33,51 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -EXAMPLE_DOC_STRING = """ + +EXAMPLE_DOC_STRING = """\ Examples: ```py >>> import torch >>> import PIL.Image >>> from diffusers import SkyReelsV2DiffusionForcingPipeline + >>> from diffusers.utils import export_to_video, load_image >>> # Load the pipeline >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained( - ... "SkyworkAI/SkyReels-V2-DiffusionForcing-4.0B", torch_dtype=torch.float16 + ... "HF_placeholder/SkyReels-V2-DF-1.3B-540P", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") - >>> # Prepare conditioning frames (list of PIL Images or a tensor of shape [frames, height, width, channels]) - >>> frames = [PIL.Image.open(f"frame_{i}.jpg").convert("RGB") for i in range(5)] + >>> # Prepare conditioning frames (list of PIL Images) + >>> # Example: Condition on frames 0, 24, 48, 72 for a 97-frame video + >>> frame_0 = load_image("./frame_0.png") # Placeholder paths + >>> frame_24 = load_image("./frame_24.png") + >>> frame_48 = load_image("./frame_48.png") + >>> frame_72 = load_image("./frame_72.png") + >>> conditioning_frames = [frame_0, frame_24, frame_48, frame_72] + >>> # Create mask: 1 for conditioning frames, 0 for frames to generate - >>> mask = [1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0] + >>> num_frames = 97 # Match the default + >>> conditioning_frame_mask = [0] * num_frames + >>> # Example conditioning indices for a 97-frame video + >>> conditioning_indices = [0, 24, 48, 72] + >>> for idx in conditioning_indices: + ... if idx < num_frames: # Check bounds + ... conditioning_frame_mask[idx] = 1 + >>> prompt = "A person walking in the park" - >>> video = pipe(prompt, conditioning_frames=frames, conditioning_frame_mask=mask, num_frames=16).frames[0] + >>> video = pipe( + ... prompt=prompt, + ... conditioning_frames=conditioning_frames, + ... conditioning_frame_mask=conditioning_frame_mask, + ... num_frames=num_frames, + ... height=544, + ... width=960, + ... num_inference_steps=30, + ... guidance_scale=6.0, + ... custom_shift=8.0, + ... ).frames + >>> export_to_video(video, "skyreels_v2_df.mp4") ``` """ @@ -73,6 +100,8 @@ class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline): A SkyReels-V2 transformer model for diffusion with diffusion forcing capability. scheduler ([`FlowUniPCMultistepScheduler`]): A scheduler to be used in combination with the transformer to denoise the encoded video latents. + video_processor ([`VideoProcessor`]): + Processor for post-processing generated videos (e.g., tensor to numpy array). """ model_cpu_offload_seq = "text_encoder->transformer->vae" @@ -85,6 +114,7 @@ def __init__( tokenizer: CLIPTokenizer, transformer: WanTransformer3DModel, scheduler: FlowUniPCMultistepScheduler, + video_processor: VideoProcessor, ): super().__init__() self.register_modules( @@ -93,6 +123,7 @@ def __init__( tokenizer=tokenizer, transformer=transformer, scheduler=scheduler, + video_processor=video_processor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) @@ -232,111 +263,61 @@ def encode_prompt( return prompt_embeds - def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: - """ - Decode the generated latent sample using the VAE to produce video frames. - - Args: - latents (`torch.Tensor`): Generated latent samples from the diffusion process. - - Returns: - `torch.Tensor`: Decoded video frames. 
- """ - video_length = latents.shape[2] - - latents = 1 / self.vae.config.scaling_factor * latents # scale latents - - # Reshape latents from [batch, channels, frames, height, width] to [batch*frames, channels, height, width] - latents = latents.permute(0, 2, 1, 3, 4).reshape( - latents.shape[0] * latents.shape[2], latents.shape[1], latents.shape[3], latents.shape[4] - ) - - # Decode all frames + def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + # AutoencoderKLWan expects B, C, F, H, W latents directly video = self.vae.decode(latents).sample - - # Reshape back to [batch, frames, channels, height, width] - video = video.reshape(-1, video_length, video.shape[1], video.shape[2], video.shape[3]) - - # Rescale video from [-1, 1] to [0, 1] + video = video.permute(0, 2, 1, 3, 4) video = (video / 2 + 0.5).clamp(0, 1) + return video - # Rescale to pixel values - video = (video * 255).to(torch.uint8) - - # Permute channels to [batch, frames, height, width, channels] - return video.permute(0, 1, 3, 4, 2) - - def encode_frames(self, frames: Union[List[PIL.Image.Image], torch.Tensor, np.ndarray]) -> torch.Tensor: + def encode_frames(self, frames: Union[List[PIL.Image.Image], torch.Tensor]) -> torch.Tensor: """ - Encode the conditioning frames to latent space using VAE. + Encodes conditioning frames into VAE latent space. Args: - frames (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`): - List of frames or tensor/array containing frames to encode. + frames (`List[PIL.Image.Image]` or `torch.Tensor`): + The conditioning frames (batch, frames, channels, height, width) or list of PIL images. Assumes frames + are already preprocessed (e.g., correct size, range [-1, 1] if tensor). Returns: - `torch.Tensor`: Latent representation of the input frames. + `torch.Tensor`: Latent representations of the frames (batch, channels, latent_frames, height, width). """ - device = self._execution_device - dtype = self.vae.dtype - if isinstance(frames, list): - # Convert list of PIL Images to tensor [frames, channels, height, width] - processed_frames = [] - for frame in frames: - if isinstance(frame, PIL.Image.Image): - frame = np.array(frame).astype(np.float32) / 127.5 - 1.0 - frame = torch.from_numpy(frame).permute(2, 0, 1) - processed_frames.append(frame) - frames_tensor = torch.stack(processed_frames) - - elif isinstance(frames, np.ndarray): - # Convert numpy array to tensor - if frames.ndim == 4: # [frames, height, width, channels] - frames = frames.astype(np.float32) / 127.5 - 1.0 - frames_tensor = torch.from_numpy(frames).permute(0, 3, 1, 2) # [frames, channels, height, width] - else: - raise ValueError( - f"Unexpected numpy array shape: {frames.shape}, expected [frames, height, width, channels]" - ) - + # Assume list of PIL Images, needs preprocessing similar to VAE requirements + # Note: This uses a basic preprocessing, might need alignment with VaeImageProcessor + frames_np = np.stack([np.array(frame.convert("RGB")) for frame in frames]) + frames_tensor = torch.from_numpy(frames_np).float() / 127.5 - 1.0 # Range [-1, 1] + frames_tensor = frames_tensor.permute( + 0, 3, 1, 2 + ) # -> (batch*frames, channels, H, W) if flattened? No, needs batch dim. + # Let's assume the input list is for a SINGLE batch item's frames. 
+ # Needs shape (batch=1, frames, channels, H, W) -> permute to (batch=1, channels, frames, H, W) + frames_tensor = frames_tensor.unsqueeze(0).permute(0, 2, 1, 3, 4) elif isinstance(frames, torch.Tensor): - if frames.ndim == 4: - if frames.shape[1] == 3: # [frames, channels, height, width] - frames_tensor = frames - elif frames.shape[3] == 3: # [frames, height, width, channels] - frames_tensor = frames.permute(0, 3, 1, 2) - else: - raise ValueError(f"Unexpected tensor shape: {frames.shape}, cannot determine channel dimension") + # Assume input tensor is already preprocessed and has shape (batch, frames, channels, H, W) or similar + # Ensure range [-1, 1] + if frames.min() >= 0.0 and frames.max() <= 1.0: + frames = 2.0 * frames - 1.0 + # Permute to (batch, channels, frames, H, W) + if frames.ndim == 5 and frames.shape[2] == 3: # Check if channels is dim 2 + frames_tensor = frames.permute(0, 2, 1, 3, 4) + elif frames.ndim == 5 and frames.shape[1] == 3: # Check if channels is dim 1 + frames_tensor = frames # Already in correct channel order else: - raise ValueError(f"Unexpected tensor shape: {frames.shape}, expected 4D tensor") - - # Ensure pixel values are in range [-1, 1] - if frames_tensor.min() >= 0 and frames_tensor.max() <= 1: - frames_tensor = 2.0 * frames_tensor - 1.0 - elif frames_tensor.min() >= 0 and frames_tensor.max() <= 255: - frames_tensor = frames_tensor / 127.5 - 1.0 + raise ValueError("Input tensor shape not recognized. Expected (B, F, C, H, W) or (B, C, F, H, W).") else: - raise ValueError(f"Unsupported frame input type: {type(frames)}") + raise TypeError("`conditioning_frames` must be a list of PIL Images or a torch Tensor.") - # Move to device and correct dtype - frames_tensor = frames_tensor.to(device=device, dtype=dtype) + frames_tensor = frames_tensor.to(device=self.device, dtype=self.vae.dtype) - # Process in batches if there are many frames, to avoid OOM - batch_size = 8 # reasonable batch size for VAE encoding - latents = [] + # Encode frames using VAE + # Note: VAE encode expects (batch, channels, frames, height, width)? Check AutoencoderKLWan docs/code + # AutoencoderKLWan._encode takes (B, C, F, H, W) + conditioning_latents = self.vae.encode(frames_tensor).latent_dist.sample() + conditioning_latents = conditioning_latents * self.vae.config.scaling_factor - for i in range(0, frames_tensor.shape[0], batch_size): - batch = frames_tensor[i : i + batch_size] - with torch.no_grad(): - batch_latents = self.vae.encode(batch).latent_dist.sample() - batch_latents = batch_latents * self.vae.config.scaling_factor - latents.append(batch_latents) - - # Concatenate all batches - latents = torch.cat(latents, dim=0) - - return latents + # Expected output shape: (batch, channels, latent_frames, latent_height, latent_width) + return conditioning_latents def prepare_latents_with_forcing( self, @@ -349,138 +330,169 @@ def prepare_latents_with_forcing( device: torch.device, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, - conditioning_latents: Optional[torch.Tensor] = None, + conditioning_latents_sparse: Optional[torch.Tensor] = None, conditioning_frame_mask: Optional[List[int]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Prepare latent variables for diffusion forcing. + ) -> Tuple[torch.Tensor, torch.Tensor, List[bool]]: + r""" + Prepare latent variables, incorporating conditioning frames based on the mask. Args: batch_size (`int`): Number of samples to generate. 
num_channels_latents (`int`): Number of channels in the latent space. - num_frames (`int`): Number of video frames to generate. - height (`int`): Height of the generated images in pixels. - width (`int`): Width of the generated images in pixels. + num_frames (`int`): Total number of video frames to generate. + height (`int`): Height of the generated video in pixels. + width (`int`): Width of the generated video in pixels. dtype (`torch.dtype`): Data type of the latent variables. device (`torch.device`): Device to generate the latents on. - generator (`torch.Generator` or List[`torch.Generator`], *optional*): One or a list of generators. - latents (`torch.Tensor`, *optional*): Pre-generated noisy latent variables. - conditioning_latents (`torch.Tensor`, *optional*): Latent representations of conditioning frames. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): Generator(s) for noise. + latents (`torch.Tensor`, *optional*): Pre-generated noisy latents. + conditioning_latents_sparse (`torch.Tensor`, *optional*): Latent representations of conditioning frames. + Shape: (batch, channels, num_cond_latent_frames, latent_H, latent_W). conditioning_frame_mask (`List[int]`, *optional*): - Binary mask indicating which frames are conditioning frames. + Mask indicating which output frames are conditioned (1) or generated (0). Length must match + `num_frames`. Returns: - `Tuple[torch.Tensor, torch.Tensor]`: Prepared initial latent variables and forcing frame mask. + `Tuple[torch.Tensor, torch.Tensor, List[bool]]`: + - Prepared initial latent variables (noise). + - Mask tensor in latent space indicating regions to generate (True) vs conditioned (False). + - Boolean list representing the mask at the latent frame level (True=Conditioned). """ - # Check if we have all required inputs for diffusion forcing - if conditioning_frame_mask is None: - raise ValueError("conditioning_frame_mask is required for diffusion forcing") - - if conditioning_latents is None: - raise ValueError("conditioning_latents are required for diffusion forcing") - - # Ensure mask has the right length - if len(conditioning_frame_mask) != num_frames: - raise ValueError( - f"conditioning_frame_mask length ({len(conditioning_frame_mask)}) must match num_frames ({num_frames})" - ) - - # Count conditioning frames in the mask - num_cond_frames = sum(conditioning_frame_mask) - - # Check if conditioning_latents has correct number of frames - if conditioning_latents.shape[0] != num_cond_frames: - raise ValueError( - f"Number of conditioning frames ({conditioning_latents.shape[0]}) must match " - f"number of 1s in conditioning_frame_mask ({num_cond_frames})" - ) - - # Shape for full video latents - shape = ( + # Calculate latent spatial shape + shape_spatial = ( batch_size, num_channels_latents, - num_frames, height // self.vae_scale_factor, width // self.vae_scale_factor, ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"Must provide a list of generators of length {batch_size}, but list of length {len(generator)} was provided." + # Calculate temporal downsampling factor + if hasattr(self.vae.config, "temperal_downsample") and self.vae.config.temperal_downsample is not None: + num_true_temporal_downsamples = sum(1 for td in self.vae.config.temperal_downsample if td) + temporal_downsample_factor = 2**num_true_temporal_downsamples + else: + temporal_downsample_factor = 4 + logger.warning( + "VAE config does not have 'temperal_downsample'. Using default temporal_downsample_factor=4." 
) - # Generate or use provided latents + # Calculate number of latent frames required for the full output sequence + num_latent_frames = (num_frames - 1) // temporal_downsample_factor + 1 + shape = (shape_spatial[0], shape_spatial[1], num_latent_frames, shape_spatial[2], shape_spatial[3]) + + # Create initial noise latents if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + initial_latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") - latents = latents.to(device) - - # Scale initial noise by the standard deviation - latents = latents * self.scheduler.init_noise_sigma - - # Create forcing mask tensor [batch, 1, frames, 1, 1] - forcing_mask = torch.tensor(conditioning_frame_mask, device=device, dtype=dtype) - forcing_mask = forcing_mask.view(1, 1, num_frames, 1, 1).expand(batch_size, 1, -1, 1, 1) - - # Insert conditioning latents at the correct positions based on mask - cond_idx = 0 - for frame_idx, is_cond in enumerate(conditioning_frame_mask): - if is_cond: - # Replace the random noise with the encoded conditioning frame - latents[:, :, frame_idx : frame_idx + 1] = ( - conditioning_latents[cond_idx : cond_idx + 1].unsqueeze(0).expand(batch_size, -1, -1, -1, -1) + initial_latents = latents.to(device) + + # Create latent mask + latent_mask_list_bool = [False] * num_latent_frames # Default: All False (generate) + if conditioning_latents_sparse is not None and conditioning_frame_mask is not None: + if len(conditioning_frame_mask) != num_frames: + raise ValueError("Length of conditioning_frame_mask must equal num_frames.") + + # Correct mapping from frame mask to latent frame mask + num_conditioned_latents_expected = 0 + for j in range(num_latent_frames): + start_frame_idx = j * temporal_downsample_factor + end_frame_idx = min(start_frame_idx + temporal_downsample_factor, num_frames) + # Check if any original frame corresponding to this latent frame is a conditioning frame (mask=1) + is_latent_conditioned = any( + conditioning_frame_mask[k] == 1 for k in range(start_frame_idx, end_frame_idx) + ) + latent_mask_list_bool[j] = ( + is_latent_conditioned # True if this latent frame corresponds to a conditioned frame ) - cond_idx += 1 + if is_latent_conditioned: + num_conditioned_latents_expected += 1 + + # Validate the number of conditioning latents provided vs expected + if conditioning_latents_sparse.shape[2] != num_conditioned_latents_expected: + logger.warning( + f"Number of provided conditioning latents (frame dim: {conditioning_latents_sparse.shape[2]}) does not match " + f"the number of latent frames marked for conditioning ({num_conditioned_latents_expected}) based on the mask and stride. " + f"Ensure encode_frames provides latents only for the necessary frames." + ) + # This indicates a potential mismatch that could cause errors later. + + # Create the tensor mask for computations + # latent_mask_list_bool is True for conditioned frames + # We need a mask where True means GENERATE (inpaint area) + latent_mask_tensor_cond = torch.tensor( + latent_mask_list_bool, device=device, dtype=torch.bool + ) # True=Conditioned + latent_mask = ~latent_mask_tensor_cond.reshape(1, 1, num_latent_frames, 1, 1).expand_as( + initial_latents + ) # True = Generate + else: + # No conditioning, generate everything. Mask is all True (generate). 
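+            # Convention reminder: `latent_mask` is True where the model should generate, while `latent_mask_list_bool`
+            # is True where a latent frame is conditioned, so without conditioning the mask is all True and the list all False.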
+ latent_mask = torch.ones_like(initial_latents, dtype=torch.bool) + latent_mask_list_bool = [False] * num_latent_frames # No frames are conditioned + + # Scale the initial noise by the standard deviation required by the scheduler + initial_latents = initial_latents * self.scheduler.init_noise_sigma - return latents, forcing_mask + # Return initial noise, mask (True=Generate), and boolean list (True=Conditioned) + return initial_latents, latent_mask, latent_mask_list_bool def check_conditioning_inputs( self, - conditioning_frames: Union[List[PIL.Image.Image], torch.Tensor, np.ndarray], - conditioning_frame_mask: List[int], + conditioning_frames: Optional[Union[List[PIL.Image.Image], torch.Tensor]], + conditioning_frame_mask: Optional[List[int]], num_frames: int, ): - """Check validity of conditioning inputs.""" - # Validate mask length - if len(conditioning_frame_mask) != num_frames: - raise ValueError( - f"conditioning_frame_mask length ({len(conditioning_frame_mask)}) must match num_frames ({num_frames})" - ) - - # Validate mask values - if not all(x in [0, 1] for x in conditioning_frame_mask): - raise ValueError("conditioning_frame_mask must only contain 0s and 1s") - - # Count conditioning frames - num_conditioning_frames = sum(conditioning_frame_mask) - - # Validate number of conditioning frames - if isinstance(conditioning_frames, list): - if len(conditioning_frames) != num_conditioning_frames: + if conditioning_frames is None and conditioning_frame_mask is not None: + raise ValueError("`conditioning_frame_mask` provided without `conditioning_frames`.") + if conditioning_frames is not None and conditioning_frame_mask is None: + raise ValueError("`conditioning_frames` provided without `conditioning_frame_mask`.") + + if conditioning_frames is not None: + if not isinstance(conditioning_frame_mask, list) or not all( + isinstance(i, int) for i in conditioning_frame_mask + ): + raise TypeError("`conditioning_frame_mask` must be a list of integers (0 or 1).") + if len(conditioning_frame_mask) != num_frames: raise ValueError( - f"Number of conditioning frames ({len(conditioning_frames)}) must match " - f"number of 1s in conditioning_frame_mask ({num_conditioning_frames})" - ) - elif isinstance(conditioning_frames, (torch.Tensor, np.ndarray)): - if conditioning_frames.shape[0] != num_conditioning_frames: - raise ValueError( - f"Number of conditioning frames ({conditioning_frames.shape[0]}) must match " - f"number of 1s in conditioning_frame_mask ({num_conditioning_frames})" + f"`conditioning_frame_mask` length ({len(conditioning_frame_mask)}) must equal `num_frames` ({num_frames})." ) + if not all(m in [0, 1] for m in conditioning_frame_mask): + raise ValueError("`conditioning_frame_mask` must only contain 0s and 1s.") + + num_masked_frames = sum(conditioning_frame_mask) + + if isinstance(conditioning_frames, list): + if not all(isinstance(f, PIL.Image.Image) for f in conditioning_frames): + raise TypeError("If `conditioning_frames` is a list, it must contain only PIL Images.") + if len(conditioning_frames) != num_masked_frames: + raise ValueError( + f"Number of `conditioning_frames` ({len(conditioning_frames)}) must equal the number of 1s in `conditioning_frame_mask` ({num_masked_frames})." + ) + elif isinstance(conditioning_frames, torch.Tensor): + # Assuming tensor shape is (num_masked_frames, C, H, W) or (B, num_masked_frames, C, H, W) etc. 
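+                # Note: `encode_frames` expects a 5D tensor, (B, F, C, H, W) or (B, C, F, H, W), so a 4D input that
+                # passes this check still needs a batch dimension added before encoding.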
+ # A simple check on the frame dimension assuming it's the first or second dim after batch + if not ( + conditioning_frames.shape[0] == num_masked_frames + or (conditioning_frames.ndim > 1 and conditioning_frames.shape[1] == num_masked_frames) + ): + # This check is basic and might need refinement based on expected tensor layout + logger.warning( + f"Number of frames in `conditioning_frames` tensor ({conditioning_frames.shape}) does not seem to match the number of 1s in `conditioning_frame_mask` ({num_masked_frames}). Ensure tensor shape is correct." + ) + else: + raise TypeError("`conditioning_frames` must be a List[PIL.Image.Image] or torch.Tensor.") @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, - conditioning_frames: Optional[Union[List[PIL.Image.Image], torch.Tensor, np.ndarray]] = None, + conditioning_frames: Optional[Union[List[PIL.Image.Image], torch.Tensor]] = None, conditioning_frame_mask: Optional[List[int]] = None, - num_frames: int = 16, + num_frames: int = 97, height: Optional[int] = None, width: Optional[int] = None, - num_inference_steps: int = 25, - guidance_scale: float = 5.0, + num_inference_steps: int = 30, + guidance_scale: float = 6.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -488,221 +500,217 @@ def __call__( prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, max_sequence_length: Optional[int] = None, - output_type: Optional[str] = "tensor", + output_type: Optional[str] = "np", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - custom_shift: Optional[float] = None, + custom_shift: Optional[float] = 8.0, ) -> Union[SkyReelsV2PipelineOutput, Tuple]: - """ - The call function to the pipeline for generation with diffusion forcing. + r""" + Generate video frames conditioned on text prompts and optionally on specific input frames (diffusion forcing). Args: prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - conditioning_frames (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`, *optional*): - Frames to use as conditioning points during video generation. Should be provided as a list of PIL - images, or as a tensor/array of shape [num_cond_frames, height, width, channels] or [num_cond_frames, - channels, height, width]. + The prompt or prompts to guide video generation. If not defined, prompt_embeds must be. + conditioning_frames (`List[PIL.Image.Image]` or `torch.Tensor`, *optional*): + Frames to condition on. Must be provided if `conditioning_frame_mask` is provided. If a list, should + contain PIL Images. If a Tensor, assumes shape compatible with VAE input after batching. conditioning_frame_mask (`List[int]`, *optional*): - Binary mask indicating which frames are conditioning frames (1) and which are to be generated (0). Must - have the same length as num_frames and the same number of 1s as the number of conditioning_frames. - num_frames (`int`, *optional*, defaults to 16): - The number of video frames to generate. - height (`int`, *optional*, defaults to None): - The height in pixels of the generated video frames. If not provided, height is automatically determined - from the model configuration. 
- width (`int`, *optional*, defaults to None): - The width in pixels of the generated video frames. If not provided, width is automatically determined - from the model configuration. - num_inference_steps (`int`, *optional*, defaults to 25): - The number of denoising steps. More denoising steps usually lead to higher quality videos at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 5.0): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + A list of 0s and 1s with length `num_frames`. 1 indicates a conditioning frame, 0 indicates a frame to + generate. + num_frames (`int`, *optional*, defaults to 97): + The total number of frames to generate in the video sequence. + height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated video. + num_inference_steps (`int`, *optional*, defaults to 30): + The number of denoising steps. + guidance_scale (`float`, *optional*, defaults to 6.0): + Guidance scale for classifier-free guidance. Enabled when > 1. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + Negative prompts for CFG. num_videos_per_prompt (`int`, *optional*, defaults to 1): - The number of videos to generate per prompt. + Number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. + PyTorch Generator object(s) for deterministic generation. latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. + Pre-generated initial latents (noise). If provided, shape should match expected latent shape. prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. + Pre-generated text embeddings. negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + Pre-generated negative text embeddings. max_sequence_length (`int`, *optional*): - Maximum sequence length for input text when generating embeddings. If not provided, defaults to 77. - output_type (`str`, *optional*, defaults to `"tensor"`): - The output format of the generated video. Choose between `tensor` and `np` for `torch.Tensor` or - `numpy.array` output respectively. + Maximum sequence length for tokenizer. Defaults to model max length (e.g., 77). 
+ output_type (`str`, *optional*, defaults to `"np"`): + Output format: `"tensor"` (torch.Tensor) or `"np"` (list of np.ndarray). return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.SkyReelsV2PipelineOutput`] instead of a plain tuple. + Whether to return `SkyReelsV2PipelineOutput` or a tuple. callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + Callback function called every `callback_steps` steps. callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. + Frequency of callback calls. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined in - [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + Keyword arguments passed to the attention processor. custom_shift (`float`, *optional*): - Custom shifting factor to use in the flow matching framework. - - Examples: - ```py - >>> import torch - >>> import PIL.Image - >>> from diffusers import SkyReelsV2DiffusionForcingPipeline - - >>> # Load the pipeline - >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained( - ... "SkyworkAI/SkyReels-V2-DiffusionForcing-4.0B", torch_dtype=torch.float16 - ... ) - >>> pipe = pipe.to("cuda") - - >>> # Prepare conditioning frames (list of PIL Images or a tensor of shape [frames, height, width, channels]) - >>> frames = [PIL.Image.open(f"frame_{i}.jpg").convert("RGB") for i in range(5)] - >>> # Create mask: 1 for conditioning frames, 0 for frames to generate - >>> mask = [1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0] - >>> prompt = "A person walking in the park" - >>> video = pipe(prompt, conditioning_frames=frames, conditioning_frame_mask=mask, num_frames=16).frames[0] - ``` + Shift parameter for the `FlowUniPCMultistepScheduler`. Returns: - [`~pipelines.SkyReelsV2PipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.SkyReelsV2PipelineOutput`] is returned, otherwise a tuple is - returned where the first element is a list with the generated frames. + [`~pipelines.skyreels_v2.pipeline_skyreels_v2_text_to_video.SkyReelsV2PipelineOutput`] or `tuple`. """ - # 0. Default height and width to transformer dimensions - height = height or self.transformer.config.patch_size[1] * 112 # Default from SkyReels-V2: 224 - width = width or self.transformer.config.patch_size[2] * 112 # Default from SkyReels-V2: 224 + # 0. Default height and width + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor # 1. Check inputs self.check_inputs( - prompt, - num_frames, - callback_steps, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) - - # Check diffusion forcing inputs - if conditioning_frames is None or conditioning_frame_mask is None: - raise ValueError("For diffusion forcing, conditioning_frames and conditioning_frame_mask must be provided") - self.check_conditioning_inputs(conditioning_frames, conditioning_frame_mask, num_frames) + has_conditioning = conditioning_frames is not None # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) - else: + elif prompt_embeds is not None: batch_size = prompt_embeds.shape[0] - device = self._execution_device + elif isinstance(conditioning_frames, list) or isinstance(conditioning_frames, torch.Tensor): + batch_size = 1 # Assuming single batch item from frames for now + else: + raise ValueError("Cannot determine batch size.") + if has_conditioning and batch_size > 1: + logger.warning("Batch size > 1 not fully tested with diffusion forcing.") + batch_size = 1 - # 3. Determine whether to apply classifier-free guidance + device = self._execution_device do_classifier_free_guidance = guidance_scale > 1.0 - # 4. Encode input prompt + # 3. Encode input prompt prompt_embeds = self.encode_prompt( prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - max_sequence_length=max_sequence_length, + prompt_embeds, + negative_prompt_embeds, + max_sequence_length, ) + prompt_dtype = prompt_embeds.dtype + + # 4. Encode conditioning frames if provided + conditioning_latents_sparse = None + if has_conditioning: + conditioning_latents_sparse = self.encode_frames(conditioning_frames) + conditioning_latents_sparse = conditioning_latents_sparse.to(device=device, dtype=prompt_dtype) + # Repeat for num_videos_per_prompt + if conditioning_latents_sparse.shape[0] != batch_size * num_videos_per_prompt: + conditioning_latents_sparse = conditioning_latents_sparse.repeat_interleave( + num_videos_per_prompt, dim=0 + ) - # 5. Encode conditioning frames - conditioning_latents = self.encode_frames(conditioning_frames) - - # 6. Prepare timesteps - timestep_shift = None if custom_shift is None else {"shift": custom_shift} - self.scheduler.set_timesteps(num_inference_steps, device=device, **timestep_shift if timestep_shift else {}) + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=custom_shift) timesteps = self.scheduler.timesteps - # 7. Prepare latent variables with forcing + # 6. Prepare latent variables and mask num_channels_latents = self.vae.config.latent_channels - latents, forcing_mask = self.prepare_latents_with_forcing( + # Pass conditioning_latents_sparse to prepare_latents only for validation checks if needed + latents, latent_mask, latent_mask_list_bool = self.prepare_latents_with_forcing( batch_size * num_videos_per_prompt, num_channels_latents, num_frames, height, width, - prompt_embeds.dtype, + prompt_dtype, device, generator, - latents, - conditioning_latents, - conditioning_frame_mask, + latents=latents, + conditioning_latents_sparse=conditioning_latents_sparse, + conditioning_frame_mask=conditioning_frame_mask, ) + # latents = initial noise; latent_mask = True means generate; latent_mask_list_bool = True means conditioned - # 8. Denoising loop + # 7. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - # Expand latents for classifier-free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + # Prepare the known conditioned part (noised) + noised_conditioning_latents_full = None + if has_conditioning: + # Create a full-shaped tensor for the noised conditioning latents + full_conditioning_latents = torch.zeros_like(latents) + sparse_idx_counter = 0 + for latent_idx, is_conditioned in enumerate(latent_mask_list_bool): + if is_conditioned: # True means it *was* a conditioning frame + if sparse_idx_counter < conditioning_latents_sparse.shape[2]: + full_conditioning_latents[:, :, latent_idx, :, :] = conditioning_latents_sparse[ + :, :, sparse_idx_counter, :, : + ] + sparse_idx_counter += 1 + # else: warning already issued in prepare_latents + + noise = randn_tensor( + full_conditioning_latents.shape, generator=generator, device=device, dtype=prompt_dtype + ) + # Noise the 'clean' conditioning latents appropriate for this timestep t + noised_conditioning_latents_full = self.scheduler.add_noise(full_conditioning_latents, noise, t) + + # Combine current latents with noised conditioning latents using the mask + # latent_mask is True for generated regions, False for conditioned regions + model_input = torch.where(latent_mask, latents, noised_conditioning_latents_full) + else: + model_input = latents - # Scale model input + # Expand for CFG + latent_model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # Predict the noise residual using Diffusion Forcing - # Use standard forward pass; forcing logic is applied outside the model - noise_pred = self.transformer( + # Predict noise + # Note: Transformer sees the combined input (noise in generated areas, noised known in conditioned areas) + model_pred = self.transformer( latent_model_input, - t, + timestep=t, encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=None, cross_attention_kwargs=cross_attention_kwargs, ).sample - # Perform guidance + # CFG if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # Update latents with the scheduler step - latents_input = latents - latents_updated = self.scheduler.step(noise_pred, t, latents_input).prev_sample - - # Apply forcing: use original latents for conditioning frames, updated latents for frames to generate - # forcing_mask is 1 for conditioning frames, 0 for frames to generate - latents = torch.where(forcing_mask, latents_input, latents_updated) + model_pred_uncond, model_pred_text = model_pred.chunk(2) + model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond) + + # Scheduler step (operates on the full latents) + step_output = self.scheduler.step(model_pred, t, latents) + current_latents = step_output.prev_sample + + # Re-apply known conditioning information using the mask + # Ensures the conditioned areas stay consistent with their noised versions + if has_conditioning: + # Use the same noised_conditioning_latents_full calculated for timestep t + latents = torch.where(latent_mask, current_latents, noised_conditioning_latents_full) + else: + latents = current_latents - # Call 
the callback, if provided + # Callback if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) - # 9. Post-processing: decode latents - video = self.decode_latents(latents) + # 8. Post-processing + video_tensor = self._decode_latents(latents) + # video_tensor shape should be (batch, frames, channels, height, width) float [0,1] - # 10. Convert output format - if output_type == "np": - video = video.cpu().numpy() - elif output_type == "tensor": - video = video.cpu() + # Use VideoProcessor for standard output formats + video = self.video_processor.postprocess_video(video_tensor, output_type=output_type) - # 11. Offload all models - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() + self.maybe_free_model_hooks() if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py index e1f8aff98b01..9b6dea857b58 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py @@ -14,11 +14,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import numpy as np import PIL.Image import torch -from transformers import CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from ...image_processor import VideoProcessor from ...models import AutoencoderKLWan, WanTransformer3DModel from ...schedulers import FlowUniPCMultistepScheduler from ...utils import ( @@ -32,21 +32,28 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> import PIL.Image >>> from diffusers import SkyReelsV2ImageToVideoPipeline + >>> from diffusers.utils import load_image, export_to_video + >>> # Load the pipeline >>> pipe = SkyReelsV2ImageToVideoPipeline.from_pretrained( - ... "SkyworkAI/SkyReels-V2-DiffusionForcing-4.0B", torch_dtype=torch.float16 + ... "HF_placeholder/SkyReels-V2-I2V-14B-540P", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") - >>> image = PIL.Image.open("input_image.jpg").convert("RGB") - >>> prompt = "A beautiful view of mountains" - >>> video_frames = pipe(prompt, image=image, num_frames=16).frames[0] + >>> # Load the conditioning image + >>> image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png" # Example image + >>> image = load_image(image_url) + + >>> prompt = "A cat running across the grass" + >>> video_frames = pipe(prompt=prompt, image=image, num_frames=97).frames + >>> export_to_video(video_frames, "skyreels_v2_i2v.mp4") ``` """ @@ -58,39 +65,99 @@ class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline): This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a specific device, etc.). + The pipeline is based on the Wan 2.1 architecture (WanTransformer3DModel, AutoencoderKLWan). It uses a + `CLIPVisionModelWithProjection` to encode the conditioning image. 
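+    The resulting image embedding is passed to the transformer through its `encoder_hidden_states_image` argument
+    during denoising.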
It expects checkpoints saved in the standard + diffusers format, typically including subfolders: `vae`, `text_encoder`, `tokenizer`, `image_encoder`, + `transformer`, `scheduler`. + Args: vae ([`AutoencoderKLWan`]): - Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + Variational Auto-Encoder (VAE) model capable of encoding and decoding videos in latent space. text_encoder ([`~transformers.CLIPTextModel`]): - Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + Frozen text-encoder (e.g., CLIP). tokenizer ([`~transformers.CLIPTokenizer`]): - A `CLIPTokenizer` to tokenize text. + Tokenizer corresponding to the `text_encoder`. + image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): + Frozen image encoder (e.g., CLIP Vision Model) to encode the conditioning image. + image_processor ([`~transformers.CLIPImageProcessor`]): + Image processor corresponding to the `image_encoder`. transformer ([`WanTransformer3DModel`]): - A SkyReels-V2 transformer model for diffusion. + The core diffusion transformer model that denoises latents based on text and image conditioning. scheduler ([`FlowUniPCMultistepScheduler`]): - A scheduler to be used in combination with the transformer to denoise the encoded video latents. + A scheduler compatible with the Flow Matching framework used by SkyReels-V2. + video_processor ([`VideoProcessor`]): + Processor for converting VAE output latents to standard video formats (np, tensor, pil). """ - model_cpu_offload_seq = "text_encoder->transformer->vae" - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "image_embeds"] def __init__( self, vae: AutoencoderKLWan, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, + image_encoder: CLIPVisionModelWithProjection, + image_processor: CLIPImageProcessor, transformer: WanTransformer3DModel, scheduler: FlowUniPCMultistepScheduler, + video_processor: VideoProcessor, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, + image_encoder=image_encoder, + image_processor=image_processor, transformer=transformer, scheduler=scheduler, + video_processor=video_processor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + # VaeImageProcessor is not needed here as CLIPImageProcessor handles image preprocessing. + + def _encode_image( + self, + image: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]], + device: torch.device, + num_videos_per_prompt: int, + do_classifier_free_guidance: bool, + dtype: torch.dtype, + ) -> torch.Tensor: + """ + Encodes the input image using the image encoder. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `List[PIL.Image.Image]`): + Image or batch of images to encode. + device (`torch.device`): Target device. + num_videos_per_prompt (`int`): Number of videos per prompt (for repeating embeddings). + do_classifier_free_guidance (`bool`): Whether to generate negative embeddings. + dtype (`torch.dtype`): Target data type for embeddings. + + Returns: + `torch.Tensor`: Encoded image embeddings. 
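+            The embeddings are repeated for `num_videos_per_prompt` and, when classifier-free guidance is enabled,
+            zero image embeddings are prepended along the batch dimension.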
+ """ + if isinstance(image, PIL.Image.Image): + image = [image] # Processor expects a list + + # Preprocess image(s) + image_pixels = self.image_processor(image, return_tensors="pt").pixel_values + image_pixels = image_pixels.to(device=device, dtype=dtype) + + # Get image embeddings + image_embeds = self.image_encoder(image_pixels).image_embeds # [batch_size, seq_len, embed_dim] + + # Duplicate image embeddings for each generation per prompt + image_embeds = image_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + + # Get negative embeddings for CFG + if do_classifier_free_guidance: + negative_image_embeds = torch.zeros_like(image_embeds) + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + + return image_embeds def enable_vae_slicing(self): r""" @@ -131,8 +198,9 @@ def encode_prompt( prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, max_sequence_length: Optional[int] = None, + lora_scale: Optional[float] = None, ): - r""" + """ Encodes the prompt into text encoder hidden states. Args: @@ -156,7 +224,12 @@ def encode_prompt( not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. max_sequence_length (`int`, *optional*): Maximum sequence length for input text when generating embeddings. If not provided, defaults to 77. + lora_scale (`float`, *optional*): + Scale for LoRA-based text embeddings. """ + # Set LoRA scale + lora_scale = lora_scale or self.lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -164,13 +237,10 @@ def encode_prompt( else: batch_size = prompt_embeds.shape[0] - # Define tokenizer parameters if max_sequence_length is None: max_sequence_length = self.tokenizer.model_max_length - # Get prompt text embeddings if prompt_embeds is None: - # Text encoder expects tokens to be of shape (batch_size, context_length) text_inputs = self.tokenizer( prompt, padding="max_length", @@ -187,112 +257,55 @@ def encode_prompt( ) prompt_embeds = prompt_embeds[0] - # Duplicate prompt embeddings for each generation per prompt - if prompt_embeds.shape[0] < batch_size * num_videos_per_prompt: - prompt_embeds = prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) - - # Get negative prompt embeddings - if do_classifier_free_guidance and negative_prompt_embeds is None: - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - elif batch_size != len(negative_prompt): - raise ValueError( - f"Batch size of `negative_prompt` should be {batch_size}, but is {len(negative_prompt)}" - ) + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) - negative_text_inputs = self.tokenizer( - negative_prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_tensors="pt", - ) - negative_input_ids = negative_text_inputs.input_ids - negative_attention_mask = negative_text_inputs.attention_mask - - negative_prompt_embeds = self.text_encoder( - negative_input_ids.to(device), - attention_mask=negative_attention_mask.to(device), - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - # Duplicate negative prompt embeddings for each generation per prompt - if 
negative_prompt_embeds.shape[0] < batch_size * num_videos_per_prompt: - negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) - - # For classifier-free guidance, combine embeddings if do_classifier_free_guidance: + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = "" + if isinstance(negative_prompt, str) and negative_prompt == "": + negative_prompt = [negative_prompt] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + if isinstance(negative_prompt, list) and batch_size != len(negative_prompt): + raise ValueError("Negative prompt batch size mismatch") + uncond_input = self.tokenizer( + negative_prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device) + )[0] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + bs_embed, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - return prompt_embeds - def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor: """ Decode the generated latent sample using the VAE to produce video frames. Args: - latents (`torch.Tensor`): Generated latent samples from the diffusion process. + latents (`torch.Tensor`): + Generated latent samples of shape (batch, channels, latent_frames, height, width). Returns: - `torch.Tensor`: Decoded video frames. + `torch.Tensor`: Decoded video frames of shape (batch, frames, channels, height, width) as a float tensor in + range [0, 1]. """ - video_length = latents.shape[2] - - latents = 1 / self.vae.config.scaling_factor * latents # scale latents - - # Reshape latents from [batch, channels, frames, height, width] to [batch*frames, channels, height, width] - latents = latents.permute(0, 2, 1, 3, 4).reshape( - latents.shape[0] * latents.shape[2], latents.shape[1], latents.shape[3], latents.shape[4] - ) - - # Decode all frames + # AutoencoderKLWan expects B, C, F, H, W latents directly video = self.vae.decode(latents).sample - - # Reshape back to [batch, frames, channels, height, width] - video = video.reshape(-1, video_length, video.shape[1], video.shape[2], video.shape[3]) - - # Rescale video from [-1, 1] to [0, 1] + video = video.permute(0, 2, 1, 3, 4) # B, F, C, H, W video = (video / 2 + 0.5).clamp(0, 1) - - # Rescale to pixel values - video = (video * 255).to(torch.uint8) - - # Permute channels to [batch, frames, height, width, channels] - return video.permute(0, 1, 3, 4, 2) - - def encode_image(self, image: Union[torch.Tensor, PIL.Image.Image]) -> torch.Tensor: - """ - Encode the input image to latent space using VAE. - - Args: - image (`torch.Tensor` or `PIL.Image.Image`): Input image to encode. - - Returns: - `torch.Tensor`: Latent representation of the input image. 
- """ - if isinstance(image, PIL.Image.Image): - # Convert PIL image to tensor - image = torch.from_numpy(np.array(image)).float() / 127.5 - 1.0 - image = image.permute(2, 0, 1).unsqueeze(0) - elif isinstance(image, torch.Tensor) and image.ndim == 3: - # Add batch dimension for single image tensor - image = image.unsqueeze(0) - elif isinstance(image, torch.Tensor) and image.ndim == 4: - # Ensure input is in -1 to 1 range - if image.min() >= 0 and image.max() <= 1: - image = 2.0 * image - 1.0 - else: - raise ValueError(f"Invalid image input type: {type(image)}") - - image = image.to(device=self._execution_device, dtype=self.vae.dtype) - - # Encode the image using VAE - latents = self.vae.encode(image).latent_dist.sample() - latents = latents * self.vae.config.scaling_factor - - return latents + return video def prepare_latents( self, @@ -305,86 +318,95 @@ def prepare_latents( device: torch.device, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, - image_latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ - Prepare latent variables from noise for the diffusion process, optionally incorporating image conditioning. + Prepare latent variables from noise for the diffusion process. Args: - batch_size (`int`): Number of samples to generate. - num_channels_latents (`int`): Number of channels in the latent space. - num_frames (`int`): Number of video frames to generate. - height (`int`): Height of the generated images in pixels. - width (`int`): Width of the generated images in pixels. - dtype (`torch.dtype`): Data type of the latent variables. - device (`torch.device`): Device to generate the latents on. - generator (`torch.Generator` or List[`torch.Generator`], *optional*): One or a list of generators. - latents (`torch.Tensor`, *optional*): Pre-generated noisy latent variables. - image_latents (`torch.Tensor`, *optional*): Latent representation of the conditioning image. + batch_size (`int`): + Number of samples to generate. + num_channels_latents (`int`): + Number of channels in the latent space. + num_frames (`int`): + Number of video frames to generate. + height (`int`): + Height of the generated video in pixels. + width (`int`): + Width of the generated video in pixels. + dtype (`torch.dtype`): + Data type of the latent variables. + device (`torch.device`): + Device to generate the latents on. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a random + noisy latent is generated. Returns: `torch.Tensor`: Prepared initial latent variables. 
""" - shape = ( + vae_scale_factor = self.vae_scale_factor + shape_spatial = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor) + shape_spatial = ( batch_size, num_channels_latents, - num_frames, height // self.vae_scale_factor, width // self.vae_scale_factor, ) + if hasattr(self.vae.config, "temperal_downsample") and self.vae.config.temperal_downsample is not None: + num_true_temporal_downsamples = sum(1 for td in self.vae.config.temperal_downsample if td) + temporal_downsample_factor = 2**num_true_temporal_downsamples + else: + temporal_downsample_factor = 4 + logger.warning( + "VAE config does not have 'temperal_downsample'. Using default temporal_downsample_factor=4." + ) + + num_latent_frames = (num_frames - 1) // temporal_downsample_factor + 1 + shape = (shape_spatial[0], shape_spatial[1], num_latent_frames, shape_spatial[2], shape_spatial[3]) + if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( - f"Must provide a list of generators of length {batch_size}, but list of length {len(generator)} was provided." + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) - # Scale initial noise by the standard deviation latents = latents * self.scheduler.init_noise_sigma - - # If we have image conditioning, incorporate it into the first frame - if image_latents is not None: - # Expand image latents to match the number of frames by repeating along frame dimension - # This helps provide a stronger image signal throughout the video - image_latents = image_latents.unsqueeze(2) - first_frame_latents = image_latents.expand(-1, -1, 1, -1, -1) - - # Create a stronger conditioning for the first frame - # This helps ensure the video starts with the input image - alpha = 0.8 # Higher alpha means stronger image conditioning - latents[:, :, 0:1] = alpha * first_frame_latents + (1 - alpha) * latents[:, :, 0:1] - return latents @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, - image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, - num_frames: int = 16, + image: Optional[Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]]] = None, + num_frames: int = 97, height: Optional[int] = None, width: Optional[int] = None, - num_inference_steps: int = 25, - guidance_scale: float = 5.0, + num_inference_steps: int = 30, + guidance_scale: float = 6.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, max_sequence_length: Optional[int] = None, - output_type: Optional[str] = "tensor", + output_type: Optional[str] = "np", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - custom_shift: Optional[float] = None, + custom_shift: Optional[float] = 8.0, ) -> 
Union[SkyReelsV2PipelineOutput, Tuple]: """ The call function to the pipeline for generation. @@ -394,7 +416,7 @@ def __call__( The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*): The image to use as the starting point for the video generation. - num_frames (`int`, *optional*, defaults to 16): + num_frames (`int`, *optional*, defaults to 97): The number of video frames to generate. height (`int`, *optional*, defaults to None): The height in pixels of the generated video frames. If not provided, height is automatically determined @@ -402,10 +424,10 @@ def __call__( width (`int`, *optional*, defaults to None): The width in pixels of the generated video frames. If not provided, width is automatically determined from the model configuration. - num_inference_steps (`int`, *optional*, defaults to 25): + num_inference_steps (`int`, *optional*, defaults to 30): The number of denoising steps. More denoising steps usually lead to higher quality videos at the expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 5.0): + guidance_scale (`float`, *optional*, defaults to 6.0): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): @@ -428,7 +450,7 @@ def __call__( not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. max_sequence_length (`int`, *optional*): Maximum sequence length for input text when generating embeddings. If not provided, defaults to 77. - output_type (`str`, *optional*, defaults to `"tensor"`): + output_type (`str`, *optional*, defaults to `"np"`): The output format of the generated video. Choose between `tensor` and `np` for `torch.Tensor` or `numpy.array` output respectively. return_dict (`bool`, *optional*, defaults to `True`): @@ -474,6 +496,8 @@ def __call__( # 1. Check inputs self.check_inputs( prompt, + image, + height, num_frames, callback_steps, negative_prompt, @@ -489,14 +513,25 @@ def __call__( batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) - else: + elif image is not None: + if isinstance(image, PIL.Image.Image): + batch_size = 1 + elif isinstance(image, list) and all(isinstance(i, PIL.Image.Image) for i in image): + batch_size = len(image) + elif isinstance(image, torch.Tensor): + batch_size = image.shape[0] + else: + # Fallback or error if image type is not recognized for batch size inference + raise ValueError("Cannot determine batch size from the provided image type.") + elif prompt_embeds is not None: batch_size = prompt_embeds.shape[0] - device = self._execution_device + else: + raise ValueError("Either `prompt`, `image`, or `prompt_embeds` must be provided.") - # 3. Determine whether to apply classifier-free guidance + device = self._execution_device do_classifier_free_guidance = guidance_scale > 1.0 - # 4. Encode input prompt + # 3. Encode input prompt prompt_embeds = self.encode_prompt( prompt, device, @@ -508,20 +543,18 @@ def __call__( max_sequence_length=max_sequence_length, ) - # 5. Encode the input image - image_latents = self.encode_image(image) + # 4. 
Encode input image + if image is None: + # This case should ideally be caught by check_inputs or initial ValueError + raise ValueError("`image` is a required argument for SkyReelsV2ImageToVideoPipeline.") - # Duplicate image latents for each batch and prompt - if isinstance(image, PIL.Image.Image) or (isinstance(image, torch.Tensor) and image.ndim < 4): - # For a single image to be duplicated - image_latents = image_latents.repeat(batch_size * num_videos_per_prompt, 1, 1, 1) + image_embeds = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance) - # 6. Prepare timesteps - timestep_shift = None if custom_shift is None else {"shift": custom_shift} - self.scheduler.set_timesteps(num_inference_steps, device=device, **timestep_shift if timestep_shift else {}) + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=custom_shift) timesteps = self.scheduler.timesteps - # 7. Prepare latent variables + # 6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, @@ -529,59 +562,44 @@ def __call__( num_frames, height, width, - prompt_embeds.dtype, + prompt_embeds.dtype, # Use prompt_embeds.dtype, image_embeds could be different device, generator, - latents, - image_latents, + latents=latents, ) - # 8. Denoising loop + # 7. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - # Expand latents for classifier-free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - - # Scale model input latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # Predict the noise residual - noise_pred = self.transformer( + model_pred = self.transformer( latent_model_input, - t, + timestep=t, encoder_hidden_states=prompt_embeds, - encoder_hidden_states_image=None, + encoder_hidden_states_image=image_embeds, # Pass image_embeds here cross_attention_kwargs=cross_attention_kwargs, ).sample - # Perform guidance if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + model_pred_uncond, model_pred_text = model_pred.chunk(2) + model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond) - # Update latents with the scheduler step - latents = self.scheduler.step(noise_pred, t, latents).prev_sample + latents = self.scheduler.step(model_pred, t, latents).prev_sample - # Call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # 9. Post-processing: decode latents - video = self.decode_latents(latents) + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) - # 10. Convert output format - if output_type == "np": - video = video.cpu().numpy() - elif output_type == "tensor": - video = video.cpu() + # 8. Post-processing + video_tensor = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video_tensor, output_type=output_type) - # 11. 
Offload all models - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() + self.maybe_free_model_hooks() if not return_dict: return (video,) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_text_to_video.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_text_to_video.py index 03dddccdd43b..1423d6f8ffe1 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_text_to_video.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_text_to_video.py @@ -16,9 +16,11 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np +import PIL.Image import torch from transformers import CLIPTextModel, CLIPTokenizer +from ...image_processor import VideoProcessor from ...models import AutoencoderKLWan, WanTransformer3DModel from ...schedulers import FlowUniPCMultistepScheduler from ...utils import ( @@ -32,19 +34,23 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import SkyReelsV2TextToVideoPipeline + >>> from diffusers.utils import export_to_video + >>> # Load the pipeline >>> pipe = SkyReelsV2TextToVideoPipeline.from_pretrained( - ... "SkyworkAI/SkyReels-V2-DiffusionForcing-4.0B", torch_dtype=torch.float16 + ... "HF_placeholder/SkyReels-V2-T2V-14B-540P", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") >>> prompt = "A panda eating bamboo on a rock, 4k, detailed" - >>> video_frames = pipe(prompt, num_frames=16).frames[0] + >>> video_frames = pipe(prompt, num_frames=97).frames # Default num_frames is often higher for video + >>> export_to_video(video_frames, "skyreels_v2_t2v.mp4") ``` """ @@ -56,13 +62,14 @@ class SkyReelsV2PipelineOutput(BaseOutput): Args: frames (`List[np.ndarray]` or `torch.Tensor`): - List of video frames generated by the pipeline. If the pipeline is running on GPU, the output is a - `torch.Tensor` of shape `(n_output_videos, n_frames, height, width, num_channels)`. Otherwise, the output - is a list of numpy arrays:`(n_output_videos, n_frames, height, width, num_channels)` with values in [0, - 255]. + List of video frames generated by the pipeline. Format depends on `output_type` argument. `np.ndarray` list + is default. For `output_type="np"`: list of `np.ndarray` of shape `(num_frames, height, width, + num_channels)` with values in [0, 255]. For `output_type="tensor"`: `torch.Tensor` of shape `(batch_size, + num_frames, channels, height, width)` with values in [0, 1]. For `output_type="pil"`: list of + `PIL.Image.Image`. """ - frames: Union[List[np.ndarray], torch.Tensor] + frames: Union[List[np.ndarray], torch.Tensor, List[PIL.Image.Image]] class SkyReelsV2TextToVideoPipeline(DiffusionPipeline): @@ -72,17 +79,24 @@ class SkyReelsV2TextToVideoPipeline(DiffusionPipeline): This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a specific device, etc.). + The pipeline is based on the Wan 2.1 architecture (WanTransformer3DModel, AutoencoderKLWan). It expects checkpoints + saved in the standard diffusers format, typically including subfolders: `vae`, `text_encoder`, `tokenizer`, + `transformer`, `scheduler`. + Args: vae ([`AutoencoderKLWan`]): - Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. 
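[Editor's note] Because `frames` in the updated `SkyReelsV2PipelineOutput` can be a list of numpy arrays, a tensor, or PIL images depending on `output_type`, downstream code may want to normalize the result. A hedged consumer-side sketch based only on the formats documented above (the helper name is invented):

```py
import numpy as np
import PIL.Image
import torch


def frames_to_uint8(frames):
    """Normalize pipeline output (np / tensor / pil) to a (F, H, W, C) uint8 array for the first video."""
    if isinstance(frames, torch.Tensor):  # output_type="tensor": (B, F, C, H, W) in [0, 1]
        return (frames[0].permute(0, 2, 3, 1).clamp(0, 1) * 255).round().to(torch.uint8).cpu().numpy()
    if isinstance(frames, list) and frames and isinstance(frames[0], np.ndarray):  # output_type="np"
        return frames[0].astype(np.uint8)
    if isinstance(frames, list) and frames and isinstance(frames[0], PIL.Image.Image):  # output_type="pil"
        return np.stack([np.asarray(img) for img in frames])
    raise TypeError(f"Unrecognized frames container: {type(frames)}")
```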
+ Variational Auto-Encoder (VAE) model capable of encoding and decoding videos in latent space. Expected to + handle 3D inputs (temporal dimension). text_encoder ([`~transformers.CLIPTextModel`]): - Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + Frozen text-encoder. SkyReels-V2 typically uses CLIP (e.g., `openai/clip-vit-large-patch14`). tokenizer ([`~transformers.CLIPTokenizer`]): - A `CLIPTokenizer` to tokenize text. + Tokenizer corresponding to the `text_encoder`. transformer ([`WanTransformer3DModel`]): - A SkyReels-V2 transformer model for diffusion. + The core diffusion transformer model that denoises latents based on text conditioning. scheduler ([`FlowUniPCMultistepScheduler`]): - A scheduler to be used in combination with the transformer to denoise the encoded video latents. + A scheduler compatible with the Flow Matching framework used by SkyReels-V2. + video_processor ([`VideoProcessor`]): + Processor for converting VAE output latents to standard video formats (np, tensor, pil). """ model_cpu_offload_seq = "text_encoder->transformer->vae" @@ -95,6 +109,7 @@ def __init__( tokenizer: CLIPTokenizer, transformer: WanTransformer3DModel, scheduler: FlowUniPCMultistepScheduler, + video_processor: VideoProcessor, ): super().__init__() self.register_modules( @@ -103,6 +118,7 @@ def __init__( tokenizer=tokenizer, transformer=transformer, scheduler=scheduler, + video_processor=video_processor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) @@ -145,32 +161,34 @@ def encode_prompt( prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, max_sequence_length: Optional[int] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`): - The prompt or prompts to guide image generation. + The prompt or prompts to guide video generation. device: (`torch.device`): - The torch device to place the resulting embeddings on. + torch device. num_videos_per_prompt (`int`): - The number of videos that should be generated per prompt. + Number of videos to generate per prompt. do_classifier_free_guidance (`bool`): - Whether to use classifier-free guidance or not. + Whether to use classifier-free guidance. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - provide `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if - `guidance_scale` is less than 1). + The negative prompt or prompts. Ignored if `do_classifier_free_guidance` is `False`. prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. + Pre-generated text embeddings. Higher priority than `prompt`. negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + Pre-generated negative text embeddings. Higher priority than `negative_prompt`. max_sequence_length (`int`, *optional*): - Maximum sequence length for input text when generating embeddings. If not provided, defaults to 77. + Maximum sequence length for tokenizer. 
Defaults to `self.tokenizer.model_max_length`. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # Set LoRA scale + lora_scale = lora_scale or self.lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -184,7 +202,6 @@ def encode_prompt( # Get prompt text embeddings if prompt_embeds is None: - # Text encoder expects tokens to be of shape (batch_size, context_length) text_inputs = self.tokenizer( prompt, padding="max_length", @@ -193,88 +210,80 @@ def encode_prompt( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - attention_mask = text_inputs.attention_mask - prompt_embeds = self.text_encoder( text_input_ids.to(device), - attention_mask=attention_mask.to(device), - ) - prompt_embeds = prompt_embeds[0] - - # Duplicate prompt embeddings for each generation per prompt - if prompt_embeds.shape[0] < batch_size * num_videos_per_prompt: - prompt_embeds = prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) - - # Get negative prompt embeddings - if do_classifier_free_guidance and negative_prompt_embeds is None: - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - elif batch_size != len(negative_prompt): - raise ValueError( - f"Batch size of `negative_prompt` should be {batch_size}, but is {len(negative_prompt)}" - ) - - negative_text_inputs = self.tokenizer( - negative_prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_tensors="pt", - ) - negative_input_ids = negative_text_inputs.input_ids - negative_attention_mask = negative_text_inputs.attention_mask + attention_mask=text_inputs.attention_mask.to(device), + output_hidden_states=False, + )[0] - negative_prompt_embeds = self.text_encoder( - negative_input_ids.to(device), - attention_mask=negative_attention_mask.to(device), - ) - negative_prompt_embeds = negative_prompt_embeds[0] + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - # Duplicate negative prompt embeddings for each generation per prompt - if negative_prompt_embeds.shape[0] < batch_size * num_videos_per_prompt: - negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) - # For classifier-free guidance, combine embeddings + # Get negative prompt embeddings if do_classifier_free_guidance: + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = "" # Use empty string + # Do not repeat for multiple prompts if it is empty string + if isinstance(negative_prompt, str) and negative_prompt == "": + negative_prompt = [negative_prompt] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] # Already handled single string case + + if isinstance(negative_prompt, list) and batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`: {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches batch size of `prompt`." 
+ ) + + uncond_input = self.tokenizer( + negative_prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=uncond_input.attention_mask.to(device), + output_hidden_states=False, + )[0] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate negative embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: """ - Decode the generated latent sample using the VAE to produce video frames. + Decode video latents using the VAE. Args: - latents (`torch.Tensor`): Generated latent samples from the diffusion process. + latents (`torch.Tensor`): Latents of shape (batch, channels, latent_frames, height, width). Returns: - `torch.Tensor`: Decoded video frames. + `torch.Tensor`: Decoded video frames of shape (batch, frames, channels, height, width) as float tensor [0, + 1]. """ - video_length = latents.shape[2] - - latents = 1 / self.vae.config.scaling_factor * latents # scale latents - - # Reshape latents from [batch, channels, frames, height, width] to [batch*frames, channels, height, width] - latents = latents.permute(0, 2, 1, 3, 4).reshape( - latents.shape[0] * latents.shape[2], latents.shape[1], latents.shape[3], latents.shape[4] - ) - - # Decode all frames + # AutoencoderKLWan expects B, C, F, H, W latents directly video = self.vae.decode(latents).sample - # Reshape back to [batch, frames, channels, height, width] - video = video.reshape(-1, video_length, video.shape[1], video.shape[2], video.shape[3]) - - # Rescale video from [-1, 1] to [0, 1] + # Output is likely B, C, F, H, W in range [-1, 1] + # Convert to B, F, C, H, W and range [0, 1] + video = video.permute(0, 2, 1, 3, 4) # B, F, C, H, W video = (video / 2 + 0.5).clamp(0, 1) - - # Rescale to pixel values - video = (video * 255).to(torch.uint8) - - # Permute channels to [batch, frames, height, width, channels] - return video.permute(0, 1, 3, 4, 2) + return video def prepare_latents( self, @@ -292,40 +301,64 @@ def prepare_latents( Prepare latent variables from noise for the diffusion process. Args: - batch_size (`int`): Number of samples to generate. - num_channels_latents (`int`): Number of channels in the latent space. - num_frames (`int`): Number of video frames to generate. - height (`int`): Height of the generated images in pixels. - width (`int`): Width of the generated images in pixels. - dtype (`torch.dtype`): Data type of the latent variables. - device (`torch.device`): Device to generate the latents on. - generator (`torch.Generator` or List[`torch.Generator`], *optional*): One or a list of generators. - latents (`torch.Tensor`, *optional*): Pre-generated noisy latent variables. + batch_size (`int`): + Number of samples to generate. 
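[Editor's note] The `repeat(1, num_videos_per_prompt, 1)` followed by `view(...)` pattern used above for both positive and negative embeddings keeps the copies of each prompt adjacent in the batch dimension. A quick, self-contained shape check of that pattern:

```py
import torch

batch_size, seq_len, dim = 2, 77, 768
num_videos_per_prompt = 3

prompt_embeds = torch.randn(batch_size, seq_len, dim)

# (B, S, D) -> (B, S * n, D) -> (B * n, S, D): copies of each prompt stay adjacent.
dup = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
dup = dup.view(batch_size * num_videos_per_prompt, seq_len, dim)

assert dup.shape == (batch_size * num_videos_per_prompt, seq_len, dim)
assert torch.equal(dup[0], dup[1])          # first prompt duplicated
assert torch.equal(dup[0], prompt_embeds[0])
```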
+ num_channels_latents (`int`): + Number of channels in the latent space (e.g., `self.vae.config.latent_channels`). + num_frames (`int`): + Number of video frames *in the final output*. + height (`int`): + Height of the generated video in pixels. + width (`int`): + Width of the generated video in pixels. + dtype (`torch.dtype`): + Data type for the latents. + device (`torch.device`): + Device for the latents. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + PyTorch Generator object(s). + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents. Returns: - `torch.Tensor`: Prepared initial latent variables. + `torch.Tensor`: Initial latent variables (noise) scaled by the scheduler's `init_noise_sigma`. """ - shape = ( + vae_scale_factor = self.vae_scale_factor + shape_spatial = ( batch_size, num_channels_latents, - num_frames, - height // self.vae_scale_factor, - width // self.vae_scale_factor, + height // vae_scale_factor, + width // vae_scale_factor, ) + # Calculate temporal downsampling factor from VAE config + if hasattr(self.vae.config, "temperal_downsample") and self.vae.config.temperal_downsample is not None: + num_true_temporal_downsamples = sum(1 for td in self.vae.config.temperal_downsample if td) + temporal_downsample_factor = 2**num_true_temporal_downsamples + else: + temporal_downsample_factor = 4 # Default from original SkyReels + logger.warning( + "VAE config does not have 'temperal_downsample'. Using default temporal_downsample_factor=4." + ) + + # Calculate the number of latent frames + num_latent_frames = (num_frames - 1) // temporal_downsample_factor + 1 + shape = (shape_spatial[0], shape_spatial[1], num_latent_frames, shape_spatial[2], shape_spatial[3]) + if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( - f"Must provide a list of generators of length {batch_size}, but list of length {len(generator)} was provided." + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + raise ValueError(f"Unexpected latents shape: {latents.shape}. 
Expected {shape}.") latents = latents.to(device) - # Scale initial noise by the standard deviation + # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @@ -333,11 +366,11 @@ def prepare_latents( def __call__( self, prompt: Union[str, List[str]] = None, - num_frames: int = 16, + num_frames: int = 97, height: Optional[int] = None, width: Optional[int] = None, - num_inference_steps: int = 25, - guidance_scale: float = 5.0, + num_inference_steps: int = 30, + guidance_scale: float = 6.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -345,102 +378,87 @@ def __call__( prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, max_sequence_length: Optional[int] = None, - output_type: Optional[str] = "tensor", + output_type: Optional[str] = "np", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - custom_shift: Optional[float] = None, + custom_shift: Optional[float] = 8.0, ) -> Union[SkyReelsV2PipelineOutput, Tuple]: """ - The call function to the pipeline for generation. + The call function to the pipeline for text-to-video generation. Args: prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - num_frames (`int`, *optional*, defaults to 16): - The number of video frames to generate. - height (`int`, *optional*, defaults to None): - The height in pixels of the generated video frames. If not provided, height is automatically determined - from the model configuration. - width (`int`, *optional*, defaults to None): - The width in pixels of the generated video frames. If not provided, width is automatically determined - from the model configuration. - num_inference_steps (`int`, *optional*, defaults to 25): - The number of denoising steps. More denoising steps usually lead to a higher quality videos at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 5.0): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + The prompt(s) to guide video generation. + num_frames (`int`, *optional*, defaults to 97): + The number of frames to generate. + height (`int`, *optional*): + The height in pixels of the generated video. Defaults to VAE output size. + width (`int`, *optional*): + The width in pixels of the generated video. Defaults to VAE output size. + num_inference_steps (`int`, *optional*, defaults to 30): + The number of denoising steps. + guidance_scale (`float`, *optional*, defaults to 6.0): + Scale for classifier-free guidance. `guidance_scale <= 1` disables CFG. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + The negative prompt(s) for CFG. num_videos_per_prompt (`int`, *optional*, defaults to 1): - The number of videos to generate per prompt. + Number of videos to generate per prompt. 
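[Editor's note] The latent-frame count in `prepare_latents` above is derived from the pixel-frame count and the VAE's temporal downsampling. A sketch of the arithmetic, using hypothetical config flags just for illustration:

```py
# Hypothetical Wan-style VAE flags: two of the stages downsample time
# (key spelling follows the legacy config attribute referenced in the patch).
temperal_downsample = [False, True, True]
temporal_downsample_factor = 2 ** sum(1 for td in temperal_downsample if td)  # -> 4

num_frames = 97  # pixel-space frames requested (the new default)
num_latent_frames = (num_frames - 1) // temporal_downsample_factor + 1  # -> 25

assert temporal_downsample_factor == 4
assert num_latent_frames == 25
```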
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. + Generator(s) for deterministic generation. latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. + Pre-generated noisy latents. Shape should match `(batch_size * num_videos_per_prompt, C, F_latent, + H_latent, W_latent)`. prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. + Pre-generated text embeddings. Shape `(batch_size * num_videos_per_prompt, seq_len, embed_dim)`. negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + Pre-generated negative text embeddings. Shape `(batch_size * num_videos_per_prompt, seq_len, + embed_dim)`. max_sequence_length (`int`, *optional*): - Maximum sequence length for input text when generating embeddings. If not provided, defaults to 77. - output_type (`str`, *optional*, defaults to `"tensor"`): - The output format of the generated video. Choose between `tensor` and `np` for `torch.Tensor` or - `numpy.array` output respectively. + Maximum sequence length for tokenizer. Defaults to `self.tokenizer.model_max_length`. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. Choose between `"np"` (list of np.ndarray), `"tensor"` + (torch.Tensor), or `"pil"` (list of PIL.Image.Image). return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. + Whether to return a [`~pipelines.skyreels_v2.SkyReelsV2PipelineOutput`] or a tuple. callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + A function called every `callback_steps` steps during inference. callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. If not specified, the callback is called at - every step. + Frequency of callback calls. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined in - [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - custom_shift (`float`, *optional*): - Custom shifting factor to use in the flow matching framework. + Arguments passed to the attention processor. + custom_shift (`float`, *optional*, defaults to 8.0): + The "shift" parameter for the `FlowUniPCMultistepScheduler`. Controls emphasis on diffusion trajectory + parts. Corresponds to `shift` in the original SkyReels repository. Examples: ```py - >>> import torch - >>> from diffusers import SkyReelsV2TextToVideoPipeline - - >>> pipe = SkyReelsV2TextToVideoPipeline.from_pretrained( - ... 
"SkyworkAI/SkyReels-V2-DiffusionForcing-4.0B", torch_dtype=torch.float16 - ... ) - >>> pipe = pipe.to("cuda") - - >>> prompt = "A panda eating bamboo on a rock, 4k, detailed" - >>> video_frames = pipe(prompt, num_frames=16).frames[0] + >>> # Example usage is included in the EXAMPLE_DOC_STRING variable ``` Returns: - [`~pipelines.SkyReelsV2PipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.SkyReelsV2PipelineOutput`] is returned, otherwise a tuple is - returned where the first element is a list with the generated frames. + [`~pipelines.skyreels_v2.SkyReelsV2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, returns [`~pipelines.skyreels_v2.SkyReelsV2PipelineOutput`]. Otherwise + returns a tuple `(video,)` where `video` format depends on `output_type`. """ - # 0. Default height and width to transformer dimensions - height = height or self.transformer.config.patch_size[1] * 112 # Default from SkyReels-V2: 224 - width = width or self.transformer.config.patch_size[2] * 112 # Default from SkyReels-V2: 224 + # 0. Default height and width to VAE constraints + # Note: WanTransformer3DModel config doesn't have a standard 'sample_size'. + # Relying on VAE scale factor might be sufficient if input H/W are provided or inferred. + # Let's keep the defaults based on user input or raise error if not determinable. + if height is None or width is None: + # Height and width are required for this pipeline. + raise ValueError("Please provide `height` and `width` for video generation.") + + # Ensure height and width are multiples of VAE scale factor + height = height - height % self.vae_scale_factor + width = width - width % self.vae_scale_factor + if height == 0 or width == 0: + raise ValueError("Provided height and width are too small.") # 1. Check inputs self.check_inputs( - prompt, - num_frames, - callback_steps, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) # 2. Define call parameters @@ -450,12 +468,13 @@ def __call__( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] - device = self._execution_device - # 3. Determine whether to apply classifier-free guidance + device = self._execution_device + # Implement LoRA scale handling - requires PeftAdapterMixin setup if LoRA is used + lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None do_classifier_free_guidance = guidance_scale > 1.0 - # 4. Encode input prompt + # 3. Encode input prompt prompt_embeds = self.encode_prompt( prompt, device, @@ -465,14 +484,14 @@ def __call__( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, max_sequence_length=max_sequence_length, + lora_scale=lora_scale, ) - # 5. Prepare timesteps - timestep_shift = None if custom_shift is None else {"shift": custom_shift} - self.scheduler.set_timesteps(num_inference_steps, device=device, **timestep_shift if timestep_shift else {}) + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=custom_shift) timesteps = self.scheduler.timesteps - # 6. Prepare latent variables + # 5. Prepare latent variables num_channels_latents = self.vae.config.latent_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, @@ -483,53 +502,48 @@ def __call__( prompt_embeds.dtype, device, generator, - latents, + latents=latents, ) - # 7. Denoising loop + # 6. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - # Expand latents for classifier-free guidance + # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - - # Scale model input latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - # Predict the noise residual - noise_pred = self.transformer( + # predict the noise residual + model_pred = self.transformer( latent_model_input, - t, + timestep=t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, ).sample - # Perform guidance + # perform guidance if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + model_pred_uncond, model_pred_text = model_pred.chunk(2) + model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond) - # Update latents with the scheduler step - latents = self.scheduler.step(noise_pred, t, latents).prev_sample + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(model_pred, t, latents).prev_sample - # Call the callback, if provided + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) - # 8. Post-processing: decode latents - video = self.decode_latents(latents) + # 7. Post-processing + video_tensor = self.decode_latents(latents) # B, F, C, H, W float [0,1] - # 9. Convert output format - if output_type == "np": - video = video.cpu().numpy() - elif output_type == "tensor": - video = video.cpu() + # 8. Process video output + video = self.video_processor.postprocess_video(video_tensor, output_type=output_type) - # 10. 
Offload all models - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() + # Offload all models + self.maybe_free_model_hooks() if not return_dict: return (video,) From 37ca14f62d832d7f409858cb518f78f7039e7aeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 8 May 2025 13:17:00 +0300 Subject: [PATCH 004/264] up --- .../pipeline_skyreels_v2_diffusion_forcing.py | 41 ++++++++++++++----- .../pipeline_skyreels_v2_image_to_video.py | 4 +- .../pipeline_skyreels_v2_text_to_video.py | 2 +- 3 files changed, 33 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index b9c4664ca9a2..0e3fcce08e5e 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -367,13 +367,22 @@ def prepare_latents_with_forcing( ) # Calculate temporal downsampling factor - if hasattr(self.vae.config, "temperal_downsample") and self.vae.config.temperal_downsample is not None: + if hasattr(self.vae.config, "temporal_downsample") and self.vae.config.temporal_downsample is not None: + num_true_temporal_downsamples = sum(1 for td in self.vae.config.temporal_downsample if td) + temporal_downsample_factor = 2**num_true_temporal_downsamples + elif hasattr(self.vae.config, "temperal_downsample") and self.vae.config.temperal_downsample is not None: + # This case handles old configs with the typo + logger.warning( + "Warning: VAE config uses a misspelled attribute 'temperal_downsample'. " + "Please update the VAE config to use 'temporal_downsample'. " + "Proceeding with the misspelled attribute for now." + ) num_true_temporal_downsamples = sum(1 for td in self.vae.config.temperal_downsample if td) temporal_downsample_factor = 2**num_true_temporal_downsamples else: temporal_downsample_factor = 4 logger.warning( - "VAE config does not have 'temperal_downsample'. Using default temporal_downsample_factor=4." + "VAE config does not specify 'temporal_downsample'. Using default temporal_downsample_factor=4." ) # Calculate number of latent frames required for the full output sequence @@ -559,9 +568,14 @@ def __call__( Returns: [`~pipelines.skyreels_v2.pipeline_skyreels_v2_text_to_video.SkyReelsV2PipelineOutput`] or `tuple`. """ - # 0. Default height and width - height = height or self.transformer.config.sample_size * self.vae_scale_factor - width = width or self.transformer.config.sample_size * self.vae_scale_factor + # 0. Require height and width + if height is None or width is None: + raise ValueError("Please provide `height` and `width` for video generation.") + # Ensure multiples of VAE scale factor + height = height - height % self.vae_scale_factor + width = width - width % self.vae_scale_factor + if height == 0 or width == 0: + raise ValueError("Provided height and width are too small.") # 1. 
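[Editor's note] The typo handling above (legacy `temperal_downsample` vs. the corrected `temporal_downsample` key) could be isolated into one helper. An illustrative sketch, not part of the patch; the function name is invented:

```py
import logging

logger = logging.getLogger(__name__)


def get_temporal_downsample_factor(vae_config, default: int = 4) -> int:
    """Return 2**(number of temporal downsampling stages), tolerating the misspelled legacy key."""
    flags = getattr(vae_config, "temporal_downsample", None)
    if flags is None:
        flags = getattr(vae_config, "temperal_downsample", None)  # legacy misspelling found in some configs
        if flags is not None:
            logger.warning("VAE config uses the misspelled 'temperal_downsample' key; please update it.")
    if flags is None:
        logger.warning("VAE config has no temporal downsample info; falling back to factor=%d.", default)
        return default
    return 2 ** sum(1 for td in flags if td)
```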
Check inputs self.check_inputs( @@ -638,6 +652,8 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + # prepare a 1-element tensor for this timestep + timesteps_tensor = torch.tensor([t], device=device, dtype=torch.int64) # Prepare the known conditioned part (noised) noised_conditioning_latents_full = None if has_conditioning: @@ -656,8 +672,10 @@ def __call__( noise = randn_tensor( full_conditioning_latents.shape, generator=generator, device=device, dtype=prompt_dtype ) - # Noise the 'clean' conditioning latents appropriate for this timestep t - noised_conditioning_latents_full = self.scheduler.add_noise(full_conditioning_latents, noise, t) + # Noise the 'clean' conditioning latents appropriate for this timestep + noised_conditioning_latents_full = self.scheduler.add_noise( + full_conditioning_latents, noise, timesteps_tensor + ) # Combine current latents with noised conditioning latents using the mask # latent_mask is True for generated regions, False for conditioned regions @@ -667,16 +685,17 @@ def __call__( # Expand for CFG latent_model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # Scale model input for this timestep + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timesteps_tensor) # Predict noise # Note: Transformer sees the combined input (noise in generated areas, noised known in conditioned areas) model_pred = self.transformer( latent_model_input, - timestep=t, + timestep=timesteps_tensor, encoder_hidden_states=prompt_embeds, encoder_hidden_states_image=None, - cross_attention_kwargs=cross_attention_kwargs, + attention_kwargs=cross_attention_kwargs, ).sample # CFG @@ -685,7 +704,7 @@ def __call__( model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond) # Scheduler step (operates on the full latents) - step_output = self.scheduler.step(model_pred, t, latents) + step_output = self.scheduler.step(model_pred, timesteps_tensor, latents) current_latents = step_output.prev_sample # Re-apply known conditioning information using the mask diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py index 9b6dea857b58..f804d43e6928 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py @@ -579,8 +579,8 @@ def __call__( latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds, - encoder_hidden_states_image=image_embeds, # Pass image_embeds here - cross_attention_kwargs=cross_attention_kwargs, + encoder_hidden_states_image=image_embeds, + attention_kwargs=cross_attention_kwargs, ).sample if do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_text_to_video.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_text_to_video.py index 1423d6f8ffe1..2bf9169308c1 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_text_to_video.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_text_to_video.py @@ -518,7 +518,7 @@ def __call__( latent_model_input, timestep=t, encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, + 
attention_kwargs=cross_attention_kwargs, ).sample # perform guidance From 95d0621526f9b43e4e1780d8663c862024f5a534 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 8 May 2025 20:01:54 +0300 Subject: [PATCH 005/264] 3rd draft --- .../pipeline_skyreels_v2_diffusion_forcing.py | 1153 +++++++++++++---- 1 file changed, 910 insertions(+), 243 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 0e3fcce08e5e..148b9c746d87 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -75,7 +75,12 @@ ... width=960, ... num_inference_steps=30, ... guidance_scale=6.0, - ... custom_shift=8.0, + ... shift=8.0, + ... # Parameters for long video generation / advanced forcing (optional) + ... # base_num_frames=97, + ... # ar_step=5, + ... # overlap_history=24, # Number of *frames* (not latent frames) for overlap + ... # addnoise_condition=0.0, ... ).frames >>> export_to_video(video, "skyreels_v2_df.mp4") ``` @@ -266,6 +271,7 @@ def encode_prompt( def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor: # AutoencoderKLWan expects B, C, F, H, W latents directly video = self.vae.decode(latents).sample + # Permute from (B, C, F, H, W) to (B, F, C, H, W) for video_processor and standard video format video = video.permute(0, 2, 1, 3, 4) video = (video / 2 + 0.5).clamp(0, 1) return video @@ -319,132 +325,6 @@ def encode_frames(self, frames: Union[List[PIL.Image.Image], torch.Tensor]) -> t # Expected output shape: (batch, channels, latent_frames, latent_height, latent_width) return conditioning_latents - def prepare_latents_with_forcing( - self, - batch_size: int, - num_channels_latents: int, - num_frames: int, - height: int, - width: int, - dtype: torch.dtype, - device: torch.device, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - conditioning_latents_sparse: Optional[torch.Tensor] = None, - conditioning_frame_mask: Optional[List[int]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, List[bool]]: - r""" - Prepare latent variables, incorporating conditioning frames based on the mask. - - Args: - batch_size (`int`): Number of samples to generate. - num_channels_latents (`int`): Number of channels in the latent space. - num_frames (`int`): Total number of video frames to generate. - height (`int`): Height of the generated video in pixels. - width (`int`): Width of the generated video in pixels. - dtype (`torch.dtype`): Data type of the latent variables. - device (`torch.device`): Device to generate the latents on. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): Generator(s) for noise. - latents (`torch.Tensor`, *optional*): Pre-generated noisy latents. - conditioning_latents_sparse (`torch.Tensor`, *optional*): Latent representations of conditioning frames. - Shape: (batch, channels, num_cond_latent_frames, latent_H, latent_W). - conditioning_frame_mask (`List[int]`, *optional*): - Mask indicating which output frames are conditioned (1) or generated (0). Length must match - `num_frames`. - - Returns: - `Tuple[torch.Tensor, torch.Tensor, List[bool]]`: - - Prepared initial latent variables (noise). - - Mask tensor in latent space indicating regions to generate (True) vs conditioned (False). 
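[Editor's note] The core "forcing" operation in the diffusion-forcing loop above is an inpainting-style recombination: the known latents are noised to the current timestep and then overwrite the conditioned positions of the working latents. A toy sketch of just that masking step (shapes are arbitrary and the add-noise call is replaced by a stand-in):

```py
import torch

latents = torch.randn(1, 4, 6, 8, 8)        # current denoising state (B, C, F_latent, H, W)
known_latents = torch.randn(1, 4, 6, 8, 8)  # clean latents for the conditioned frames (zeros elsewhere)
noised_known = known_latents + 0.1 * torch.randn_like(known_latents)  # stand-in for scheduler.add_noise(...)

# True = generate this latent frame, False = it is conditioned.
generate_mask = torch.tensor([False, True, True, True, False, True]).view(1, 1, 6, 1, 1).expand_as(latents)

model_input = torch.where(generate_mask, latents, noised_known)
assert torch.equal(model_input[:, :, 0], noised_known[:, :, 0])  # conditioned frame comes from the noised known part
assert torch.equal(model_input[:, :, 1], latents[:, :, 1])       # generated frame keeps the current state
```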
- - Boolean list representing the mask at the latent frame level (True=Conditioned). - """ - # Calculate latent spatial shape - shape_spatial = ( - batch_size, - num_channels_latents, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - - # Calculate temporal downsampling factor - if hasattr(self.vae.config, "temporal_downsample") and self.vae.config.temporal_downsample is not None: - num_true_temporal_downsamples = sum(1 for td in self.vae.config.temporal_downsample if td) - temporal_downsample_factor = 2**num_true_temporal_downsamples - elif hasattr(self.vae.config, "temperal_downsample") and self.vae.config.temperal_downsample is not None: - # This case handles old configs with the typo - logger.warning( - "Warning: VAE config uses a misspelled attribute 'temperal_downsample'. " - "Please update the VAE config to use 'temporal_downsample'. " - "Proceeding with the misspelled attribute for now." - ) - num_true_temporal_downsamples = sum(1 for td in self.vae.config.temperal_downsample if td) - temporal_downsample_factor = 2**num_true_temporal_downsamples - else: - temporal_downsample_factor = 4 - logger.warning( - "VAE config does not specify 'temporal_downsample'. Using default temporal_downsample_factor=4." - ) - - # Calculate number of latent frames required for the full output sequence - num_latent_frames = (num_frames - 1) // temporal_downsample_factor + 1 - shape = (shape_spatial[0], shape_spatial[1], num_latent_frames, shape_spatial[2], shape_spatial[3]) - - # Create initial noise latents - if latents is None: - initial_latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - initial_latents = latents.to(device) - - # Create latent mask - latent_mask_list_bool = [False] * num_latent_frames # Default: All False (generate) - if conditioning_latents_sparse is not None and conditioning_frame_mask is not None: - if len(conditioning_frame_mask) != num_frames: - raise ValueError("Length of conditioning_frame_mask must equal num_frames.") - - # Correct mapping from frame mask to latent frame mask - num_conditioned_latents_expected = 0 - for j in range(num_latent_frames): - start_frame_idx = j * temporal_downsample_factor - end_frame_idx = min(start_frame_idx + temporal_downsample_factor, num_frames) - # Check if any original frame corresponding to this latent frame is a conditioning frame (mask=1) - is_latent_conditioned = any( - conditioning_frame_mask[k] == 1 for k in range(start_frame_idx, end_frame_idx) - ) - latent_mask_list_bool[j] = ( - is_latent_conditioned # True if this latent frame corresponds to a conditioned frame - ) - if is_latent_conditioned: - num_conditioned_latents_expected += 1 - - # Validate the number of conditioning latents provided vs expected - if conditioning_latents_sparse.shape[2] != num_conditioned_latents_expected: - logger.warning( - f"Number of provided conditioning latents (frame dim: {conditioning_latents_sparse.shape[2]}) does not match " - f"the number of latent frames marked for conditioning ({num_conditioned_latents_expected}) based on the mask and stride. " - f"Ensure encode_frames provides latents only for the necessary frames." - ) - # This indicates a potential mismatch that could cause errors later. 
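[Editor's note] The mask mapping discussed above groups every `temporal_downsample_factor` pixel frames into one latent frame and marks the latent frame as conditioned if any member frame is conditioned. A self-contained sketch of that grouping:

```py
num_frames = 16
temporal_downsample_factor = 4
num_latent_frames = (num_frames - 1) // temporal_downsample_factor + 1  # -> 4

conditioning_frame_mask = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]  # 1 = conditioned pixel frame

latent_is_conditioned = []
for j in range(num_latent_frames):
    start = j * temporal_downsample_factor
    end = min(start + temporal_downsample_factor, num_frames)
    latent_is_conditioned.append(any(conditioning_frame_mask[k] == 1 for k in range(start, end)))

assert latent_is_conditioned == [True, False, False, True]
```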
- - # Create the tensor mask for computations - # latent_mask_list_bool is True for conditioned frames - # We need a mask where True means GENERATE (inpaint area) - latent_mask_tensor_cond = torch.tensor( - latent_mask_list_bool, device=device, dtype=torch.bool - ) # True=Conditioned - latent_mask = ~latent_mask_tensor_cond.reshape(1, 1, num_latent_frames, 1, 1).expand_as( - initial_latents - ) # True = Generate - else: - # No conditioning, generate everything. Mask is all True (generate). - latent_mask = torch.ones_like(initial_latents, dtype=torch.bool) - latent_mask_list_bool = [False] * num_latent_frames # No frames are conditioned - - # Scale the initial noise by the standard deviation required by the scheduler - initial_latents = initial_latents * self.scheduler.init_noise_sigma - - # Return initial noise, mask (True=Generate), and boolean list (True=Conditioned) - return initial_latents, latent_mask, latent_mask_list_bool - def check_conditioning_inputs( self, conditioning_frames: Optional[Union[List[PIL.Image.Image], torch.Tensor]], @@ -491,6 +371,274 @@ def check_conditioning_inputs( else: raise TypeError("`conditioning_frames` must be a List[PIL.Image.Image] or torch.Tensor.") + def _generate_timestep_matrix( + self, + num_latent_frames: int, + step_template: torch.Tensor, + base_latent_frames: int, + ar_step: int = 5, + num_latent_frames_pre_ready: int = 0, + causal_block_size: int = 1, + shrink_interval_with_mask: bool = False, # Not used in original SkyReels-V2 call, kept for completeness + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[Tuple[int, int]]]: + """ + Generates the timestep matrix for autoregressive scheduling, adapted from SkyReels-V2. Operates on latent frame + counts. + """ + step_matrix, step_index = [], [] + update_mask, valid_interval = [], [] + num_iterations = len(step_template) + 1 # num_inference_steps + 1 effectively + + # Ensure operations are on latent frames, assuming inputs are already latent frame counts + num_frames_block = num_latent_frames // causal_block_size + base_num_frames_block = base_latent_frames // causal_block_size + + if base_num_frames_block > 0 and base_num_frames_block < num_frames_block: + infer_step_num = len(step_template) + gen_block = base_num_frames_block + if gen_block > 0: + min_ar_step = infer_step_num / gen_block + if ar_step < min_ar_step: + logger.warning( + f"ar_step ({ar_step}) is less than the suggested minimum ({np.ceil(min_ar_step)}) " + f"for base_latent_frames={base_latent_frames} and num_inference_steps={infer_step_num}. " + "This might lead to suboptimal scheduling." + ) + else: + # Should not happen if base_num_frames_block is 0 and causal_block_size > 0 + logger.warning("base_num_frames_block is zero, ar_step check skipped.") + + # Add sentinel values to step_template for indexing logic + # Original SkyReels-V2 uses [999, ..., 0] + # self.scheduler.timesteps are typically [high, ..., low] + # We need to ensure indexing works correctly. + # The original logic `step_template[new_row]` implies new_row contains indices into step_template. + # `new_row` counts from 0 to num_iterations. Let's adjust `step_template` to be 0-indexed + # from num_iterations-1 down to 0. + # Example: if step_template is [980, 960 ... 20, 0]) + # The values in new_row are essentially "how many steps have been processed for this frame" + # from 0 (not started) to num_iterations (fully denoised). + # step_matrix.append(step_template[new_row]) -> This seems problematic if new_row is 0 to num_iterations. 
+ # original: step_template = torch.cat([torch.tensor([999]), timesteps, torch.tensor([0])]) + # This padding makes step_template 1-indexed essentially. + # Let's use a direct mapping from "number of steps processed" to actual timestep value. + # If new_row[i] = k, it means frame i has undergone k denoising iterations. + # The corresponding timestep should be init_timesteps[k-1] if new_row is 1-indexed for steps. + # Original `pre_row` starts at 0. `new_row` increments. `new_row` goes from 0 to `num_iterations`. + # `step_template[new_row]` means `new_row` values are indices into a padded step_template. + # Let's use `step_template` (which are the actual timesteps from the scheduler) directly. + # if new_row[i] = k: use step_template[k-1] + # if new_row[i] = 0: this block is still pure noise / at initial state, use first timestep for processing. + # The original `step_matrix.append(step_template[new_row])` used a 1-indexed padded template. + # Our `new_row` is 0-indexed for states (0 to num_inference_steps). + # Timestep for state k (1 <= k <= num_inference_steps) is step_template[k-1]. + # Timestep for state 0 (initial) is step_template[0]. + # So, for a state `s` in `new_row` (0 to num_inference_steps), the timestep is `step_template[s.clamp(min=0, max=len(step_template)-1)]` + # No, simpler: if state is `k`, it means it has undergone `k` steps. The *next* step to apply is `step_template[k]`. + # So `new_row` (clamped 0 to `len(step_template)-1`) can directly index `step_template`. + # This gives the timestep *for the current operation*. + timesteps_for_matrix = step_template # These are the actual t values + # `new_row` will count how many steps a frame has been processed. Ranges 0 to `len(timesteps_for_matrix)`. + # 0 = initial noise state. `len(timesteps_for_matrix)` = fully processed by all timesteps. + # `num_iterations` here is `len(timesteps_for_matrix)`. + # Original `num_iterations = len(step_template) + 1`. + # Let's stick to original logic for `num_iterations` for `pre_row` and `new_row` counters. + # `num_iterations` = number of denoising *states* (0=initial noise, 1=after 1st step, ..., N=after Nth step) + # So, if N inference steps, there are N+1 states. `num_iterations = len(step_template) + 1`. + + pre_row = torch.zeros(num_frames_block, dtype=torch.long, device=step_template.device) + if num_latent_frames_pre_ready > 0: + # Ensure pre_ready frames are marked as fully processed through all steps. + pre_row[: num_latent_frames_pre_ready // causal_block_size] = ( + num_iterations - 1 + ) # Mark as if processed by all steps + + # The loop condition `torch.all(pre_row >= (num_iterations - 1))` means loop until all blocks are fully processed. + while not torch.all(pre_row >= (num_iterations - 1)): + new_row = torch.zeros(num_frames_block, dtype=torch.long, device=step_template.device) + for i in range(num_frames_block): + if i == 0 or pre_row[i - 1] >= (num_iterations - 1): # first block or previous block is fully denoised + new_row[i] = pre_row[i] + 1 + else: + new_row[i] = new_row[i - 1] - ar_step + new_row = torch.clamp( + new_row, 0, num_iterations - 1 + ) # Clamp to valid state indices (0 to num_inference_steps) + + current_update_mask = (new_row != pre_row) & ( + new_row != (num_iterations - 1) + ) # Original: & (new_row != num_iterations) + # If new_row == num_iterations-1, it means it just reached the final denoised state. It *should* be updated. 
+ # Let's use original: (new_row != pre_row) & (new_row < (num_iterations -1)) + # A frame is updated if its state changes AND it's not yet in the "fully processed" state. + # The original logic: update_mask.append((new_row != pre_row) & (new_row != num_iterations)) + # This seems to imply that even the step *to* num_iterations is not in update_mask. + # Let's stick to the original: + # update_mask is True if state changes AND it is not yet at the state corresponding to the last timestep. + # However, new_row is clamped to num_iterations-1 (max index for timesteps). + # So new_row == num_iterations will not happen here. + # Update if state changes AND it is not yet at the state corresponding to the last timestep. + current_update_mask = new_row != pre_row # True: need to update this frame at this stage + update_mask.append(current_update_mask) + + step_index.append(new_row.clone()) # Stores the "state index" for each block + + # Map state index (0 to N_steps) to actual timestep values. + # new_row values are 0 (initial noise) to N_steps (processed by last timestep). + # If new_row[j] = k: use timesteps_for_matrix[k-1] if k > 0. + # If new_row[j] = 0: this block is still pure noise / at initial state, use first timestep for processing. + # The original `step_matrix.append(step_template[new_row])` used a 1-indexed padded template. + # Our `new_row` is 0-indexed for states (0 to num_inference_steps). + # Timestep for state k (1 <= k <= num_inference_steps) is timesteps_for_matrix[k-1]. + # Timestep for state 0 (initial) is timesteps_for_matrix[0]. + # So, for a state `s` in `new_row` (0 to N_steps), the timestep is `timesteps_for_matrix[s.clamp(min=0, max=len(timesteps_for_matrix)-1)]` + # No, simpler: if state is `k`, it means it has undergone `k` steps. The *next* step to apply is `timesteps_for_matrix[k]`. + # So `new_row` (clamped 0 to `len(timesteps_for_matrix)-1`) can directly index `timesteps_for_matrix`. + # This gives the timestep *for the current operation*. 
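[Editor's note] The scheduling rule being reworked above staggers denoising across frame blocks: block `i` lags block `i-1` by `ar_step` scheduler steps until the previous block is fully denoised. A toy, self-contained simulation of the same update rule (independent of the pipeline code) that prints which state index each block occupies per iteration:

```py
import torch

num_blocks = 4   # latent-frame blocks
num_states = 6   # num_inference_steps; state 0 = pure noise, state 5 = fully denoised
ar_step = 2

pre_row = torch.zeros(num_blocks, dtype=torch.long)
rows = []
while not torch.all(pre_row >= num_states - 1):
    new_row = torch.zeros(num_blocks, dtype=torch.long)
    for i in range(num_blocks):
        if i == 0 or pre_row[i - 1] >= num_states - 1:  # first block, or previous block already finished
            new_row[i] = pre_row[i] + 1
        else:
            new_row[i] = new_row[i - 1] - ar_step        # lag behind the neighbour by ar_step steps
    new_row = new_row.clamp(0, num_states - 1)
    rows.append(new_row.tolist())
    pre_row = new_row

for r in rows:
    print(r)  # e.g. [1, 0, 0, 0], [2, 0, 0, 0], [3, 1, 0, 0], ... until [5, 5, 5, 5]
```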
+ current_timesteps_for_blocks = timesteps_for_matrix[new_row.clamp(0, len(timesteps_for_matrix) - 1)] + step_matrix.append(current_timesteps_for_blocks) + + pre_row = new_row + + # for long video we split into several sequences, base_num_frames is set to the model max length (for training) + terminal_flag = base_num_frames_block # Latent blocks + if shrink_interval_with_mask: # This was not used in original calls we saw + idx_sequence = torch.arange(num_frames_block, dtype=torch.long, device=step_template.device) + if update_mask: # Ensure update_mask is not empty + # Consider the update mask from the first iteration where meaningful updates happen + first_meaningful_update_mask = None + for um in update_mask: + if um.any(): + first_meaningful_update_mask = um + break + if first_meaningful_update_mask is not None: + update_mask_idx = idx_sequence[first_meaningful_update_mask] + if len(update_mask_idx) > 0: + last_update_idx = update_mask_idx[-1].item() + terminal_flag = last_update_idx + 1 + + for curr_mask_row in update_mask: # Iterate over rows of update masks + # Original: if terminal_flag < num_frames_block and curr_mask[terminal_flag]: + # Needs to check if terminal_flag is a valid index for curr_mask_row + if ( + terminal_flag < num_frames_block + and terminal_flag < len(curr_mask_row) + and curr_mask_row[terminal_flag] + ): + terminal_flag += 1 + # Ensure start of interval is not negative + current_interval_start = max(terminal_flag - base_num_frames_block, 0) + valid_interval.append( + (current_interval_start, terminal_flag) # These are in terms of blocks + ) + + if not step_matrix: # Handle case where loop doesn't run (e.g. num_latent_frames is 0) + # This case should ideally be caught earlier. + # Return empty tensors of appropriate shape if possible, or raise error. + # For now, let's assume num_latent_frames > 0. + # If num_frames_block is 0, then pre_row is empty, loop condition is true, returns empty lists. + # This needs robust handling if num_frames_block can be 0. + # Assuming num_frames_block > 0 from here. + if num_frames_block == 0: # If no blocks, then step_matrix etc will be empty + # Return empty tensors, but shapes need to be (0,0) or (0, num_latent_frames) if causal_block_size > 1 + # This edge case means num_latent_frames < causal_block_size + # The original code seems to assume num_latent_frames >= causal_block_size for block logic. + # Let's assume for now this means no processing needed for the matrix. + # The actual latents will be handled by the main loop. + # The matrix generation might not make sense. + # Let's return empty tensors that can be concatenated, or handle this in the caller. + # For now, if step_matrix is empty (e.g. num_frames_block=0), stack will fail. + # If num_frames_block is 0, then num_latent_frames < causal_block_size. + # The matrix logic might not apply. The caller should handle this. + # Or, we make it work for this by bypassing block logic. + # For now, assume num_frames_block > 0. + # The caller will ensure num_latent_frames is appropriate. + # If num_frames_block is 0, the while loop condition is met immediately, + # update_mask, step_index, step_matrix are empty. + # Stacking empty lists will raise an error. + # If step_matrix is empty, it implies no steps defined by matrix. + # This could happen if num_latent_frames_pre_ready covers all frames. + # Or if num_frames_block = 0. + + # If no iterations in while loop (e.g. all pre_row already >= num_iterations-1) + # this can happen if all frames are pre_ready. 
+ # In this case, step_matrix will be empty. + # The caller needs to handle this (e.g., no denoising loop needed). + # For safety, if they are empty, create dummy tensors. + if not update_mask: + update_mask.append(torch.zeros(num_frames_block, dtype=torch.bool, device=step_template.device)) + if not step_index: + step_index.append(torch.zeros(num_frames_block, dtype=torch.long, device=step_template.device)) + if not step_matrix: + step_matrix.append( + torch.zeros(num_frames_block, dtype=step_template.dtype, device=step_template.device) + ) + if not valid_interval: + valid_interval.append((0, 0)) + + step_update_mask_stacked = torch.stack(update_mask, dim=0) + step_index_stacked = torch.stack(step_index, dim=0) + step_matrix_stacked = torch.stack(step_matrix, dim=0) + + if causal_block_size > 1: + # Expand from blocks back to latent frames + step_update_mask_stacked = ( + step_update_mask_stacked.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + ) + step_index_stacked = ( + step_index_stacked.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + ) + step_matrix_stacked = ( + step_matrix_stacked.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + ) + + # Adjust valid_interval from block indices to latent frame indices + valid_interval_frames = [] + for s_block, e_block in valid_interval: + s_frame = s_block * causal_block_size + e_frame = e_block * causal_block_size + # Ensure the end frame does not exceed total latent frames + e_frame = min(e_frame, num_latent_frames) + valid_interval_frames.append((s_frame, e_frame)) + valid_interval = valid_interval_frames + else: # causal_block_size is 1, valid_interval is already in terms of latent frames + valid_interval_frames = [] + for s_idx, e_idx in valid_interval: + valid_interval_frames.append((s_idx, min(e_idx, num_latent_frames))) + valid_interval = valid_interval_frames + + # Ensure all returned tensors cover the full num_latent_frames if causal_block_size expansion happened + # This might be needed if num_latent_frames is not perfectly divisible by causal_block_size + # The original code implies that num_frames is handled by block logic. + # If num_latent_frames = 7, causal_block_size = 4. num_frames_block = 1. + # step_matrix_stacked would be (num_iterations_in_loop, 1, 4) -> (num_iterations_in_loop, 4) + # We need it to be (num_iterations_in_loop, 7). + # This flattening and repeating assumes num_frames_block * causal_block_size = num_latent_frames. + # This is only true if num_latent_frames is a multiple of causal_block_size. + # If not, the original code seems to truncate: `num_frames_block = num_frames // casual_block_size` + # The output matrices will then only cover `num_frames_block * causal_block_size` frames. + # This needs to be clarified or handled. For now, assume it covers up to num_latent_frames or truncates. + # The original code in `__call__` uses `latent_length` (which is num_latent_frames) for schedulers, + # but then `generate_timestep_matrix` is called with this `latent_length`. + # The `valid_interval` then slices these. + # It seems the matrix dimensions should align with `num_latent_frames`. 
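[Editor's note] When `causal_block_size > 1`, the per-block schedules above are broadcast back to per-latent-frame tensors with an `unsqueeze`/`repeat`/`flatten` pattern. A toy shape check of that expansion:

```py
import torch

num_iterations, num_blocks, causal_block_size = 3, 2, 4

step_index_blocks = torch.tensor([[0, 0], [1, 0], [2, 1]])  # (iterations, blocks)

# Expand each block entry to `causal_block_size` latent frames.
step_index_frames = (
    step_index_blocks.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
)
assert step_index_frames.shape == (num_iterations, num_blocks * causal_block_size)
assert step_index_frames[2].tolist() == [2, 2, 2, 2, 1, 1, 1, 1]
```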
+ + # If causal_block_size > 1 and num_latent_frames is not a multiple: + if causal_block_size > 1 and step_matrix_stacked.shape[1] < num_latent_frames: + padding_size = num_latent_frames - step_matrix_stacked.shape[1] + # Pad with the values from the last valid frame/block + step_update_mask_stacked = torch.cat( + [step_update_mask_stacked, step_update_mask_stacked[:, -1:].repeat(1, padding_size)], dim=1 + ) + step_index_stacked = torch.cat( + [step_index_stacked, step_index_stacked[:, -1:].repeat(1, padding_size)], dim=1 + ) + step_matrix_stacked = torch.cat( + [step_matrix_stacked, step_matrix_stacked[:, -1:].repeat(1, padding_size)], dim=1 + ) + + return step_matrix_stacked, step_index_stacked, step_update_mask_stacked, valid_interval + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, @@ -514,7 +662,13 @@ def __call__( callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - custom_shift: Optional[float] = 8.0, + fps: int = 24, # Add missing fps parameter + shift: Optional[float] = 8.0, + # New parameters for SkyReels-V2 original-style forcing and long video + base_num_frames: Optional[int] = None, # Max frames processed in one segment by transformer (pixel space) + ar_step: int = 5, + overlap_history: Optional[int] = None, + addnoise_condition: float = 0.0, ) -> Union[SkyReelsV2PipelineOutput, Tuple]: r""" Generate video frames conditioned on text prompts and optionally on specific input frames (diffusion forcing). @@ -562,42 +716,85 @@ def __call__( Frequency of callback calls. cross_attention_kwargs (`dict`, *optional*): Keyword arguments passed to the attention processor. - custom_shift (`float`, *optional*): - Shift parameter for the `FlowUniPCMultistepScheduler`. + fps (`int`, *optional*, defaults to 24): + Target frames per second for the video, passed to the transformer if supported. + shift (`float`, *optional*, defaults to 8.0): + Shift parameter for the `FlowUniPCMultistepScheduler` (if used as main scheduler). + base_num_frames (`int`, *optional*): + Maximum number of frames the transformer processes in a single segment when using original-style long + video generation. If None or if `num_frames` is less than this, processes `num_frames`. Corresponds to + `base_num_frames` in original SkyReels-V2 (pixel space). + ar_step (`int`, *optional*, defaults to 5): + Autoregressive step size used in `_generate_timestep_matrix` for scheduling timesteps across frames + within a segment. + overlap_history (`int`, *optional*): + Number of frames to overlap between segments for long video generation. If None, long video generation + with overlap is disabled. Uses pixel frame count. + addnoise_condition (`float`, *optional*, defaults to 0.0): + Controls the amount of noise added to conditioned latents (prefix or user-provided) during the + denoising loop when using original-style forcing. A value > 0 enables it. Corresponds to + `addnoise_condition` in SkyReels-V2. Returns: [`~pipelines.skyreels_v2.pipeline_skyreels_v2_text_to_video.SkyReelsV2PipelineOutput`] or `tuple`. """ - # 0. Require height and width + # 0. 
Require height and width & VAE spatial scale factor if height is None or width is None: raise ValueError("Please provide `height` and `width` for video generation.") - # Ensure multiples of VAE scale factor - height = height - height % self.vae_scale_factor - width = width - width % self.vae_scale_factor + vae_spatial_scale_factor = self.vae_scale_factor + height = height - height % vae_spatial_scale_factor + width = width - width % vae_spatial_scale_factor if height == 0 or width == 0: - raise ValueError("Provided height and width are too small.") + raise ValueError( + f"Provided height {height} and width {width} are too small. Must be divisible by {vae_spatial_scale_factor}." + ) + + # Determine VAE temporal downsample factor + if hasattr(self.vae.config, "temporal_downsample") and self.vae.config.temporal_downsample is not None: + num_true_temporal_downsamples = sum(1 for td in self.vae.config.temporal_downsample if td) + vae_temporal_scale_factor = 2**num_true_temporal_downsamples + elif ( + hasattr(self.vae.config, "temperal_downsample") and self.vae.config.temperal_downsample is not None + ): # Typo in some old configs + num_true_temporal_downsamples = sum(1 for td in self.vae.config.temperal_downsample if td) + vae_temporal_scale_factor = 2**num_true_temporal_downsamples + logger.warning("VAE config has misspelled 'temperal_downsample'. Using it.") + else: + vae_temporal_scale_factor = 4 # Default if not specified + logger.warning( + f"VAE config does not specify 'temporal_downsample'. Using default temporal_downsample_factor={vae_temporal_scale_factor}." + ) + + def to_latent_frames(pixel_frames): + if pixel_frames is None or pixel_frames <= 0: + return 0 + return (pixel_frames - 1) // vae_temporal_scale_factor + 1 + + num_latent_frames_total = to_latent_frames(num_frames) + if num_latent_frames_total <= 0: + raise ValueError( + f"num_frames {num_frames} results in {num_latent_frames_total} latent frames. Must be > 0." + ) + + # Determine causal_block_size for _generate_timestep_matrix from transformer config or default + causal_block_size = getattr(self.transformer.config, "causal_block_size", 1) + if not isinstance(causal_block_size, int) or causal_block_size <= 0: + causal_block_size = 1 # 1. Check inputs self.check_inputs( prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds ) - self.check_conditioning_inputs(conditioning_frames, conditioning_frame_mask, num_frames) - has_conditioning = conditioning_frames is not None + self.check_conditioning_inputs(conditioning_frames, conditioning_frame_mask, num_frames) # Will need review + has_initial_conditioning = conditioning_frames is not None # 2. 
Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) - elif prompt_embeds is not None: - batch_size = prompt_embeds.shape[0] - elif isinstance(conditioning_frames, list) or isinstance(conditioning_frames, torch.Tensor): - batch_size = 1 # Assuming single batch item from frames for now - else: - raise ValueError("Cannot determine batch size.") - if has_conditioning and batch_size > 1: - logger.warning("Batch size > 1 not fully tested with diffusion forcing.") - batch_size = 1 + else: # prompt_embeds must be provided + batch_size = prompt_embeds.shape[0] // num_videos_per_prompt # Correct batch_size from prompt_embeds device = self._execution_device do_classifier_free_guidance = guidance_scale > 1.0 @@ -609,129 +806,599 @@ def __call__( num_videos_per_prompt, do_classifier_free_guidance, negative_prompt, - prompt_embeds, - negative_prompt_embeds, - max_sequence_length, + prompt_embeds=prompt_embeds, # Pass through pre-generated embeds + negative_prompt_embeds=negative_prompt_embeds, # Pass through + max_sequence_length=max_sequence_length, ) - prompt_dtype = prompt_embeds.dtype - - # 4. Encode conditioning frames if provided - conditioning_latents_sparse = None - if has_conditioning: - conditioning_latents_sparse = self.encode_frames(conditioning_frames) - conditioning_latents_sparse = conditioning_latents_sparse.to(device=device, dtype=prompt_dtype) - # Repeat for num_videos_per_prompt - if conditioning_latents_sparse.shape[0] != batch_size * num_videos_per_prompt: - conditioning_latents_sparse = conditioning_latents_sparse.repeat_interleave( - num_videos_per_prompt, dim=0 - ) + effective_batch_size = batch_size * num_videos_per_prompt - # 5. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device, shift=custom_shift) + # 4. Prepare scheduler and timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) timesteps = self.scheduler.timesteps - # 6. Prepare latent variables and mask - num_channels_latents = self.vae.config.latent_channels - # Pass conditioning_latents_sparse to prepare_latents only for validation checks if needed - latents, latent_mask, latent_mask_list_bool = self.prepare_latents_with_forcing( - batch_size * num_videos_per_prompt, - num_channels_latents, - num_frames, - height, - width, - prompt_dtype, - device, - generator, - latents=latents, - conditioning_latents_sparse=conditioning_latents_sparse, - conditioning_frame_mask=conditioning_frame_mask, - ) - # latents = initial noise; latent_mask = True means generate; latent_mask_list_bool = True means conditioned - - # 7. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # prepare a 1-element tensor for this timestep - timesteps_tensor = torch.tensor([t], device=device, dtype=torch.int64) - # Prepare the known conditioned part (noised) - noised_conditioning_latents_full = None - if has_conditioning: - # Create a full-shaped tensor for the noised conditioning latents - full_conditioning_latents = torch.zeros_like(latents) - sparse_idx_counter = 0 - for latent_idx, is_conditioned in enumerate(latent_mask_list_bool): - if is_conditioned: # True means it *was* a conditioning frame - if sparse_idx_counter < conditioning_latents_sparse.shape[2]: - full_conditioning_latents[:, :, latent_idx, :, :] = conditioning_latents_sparse[ - :, :, sparse_idx_counter, :, : - ] - sparse_idx_counter += 1 - # else: warning already issued in prepare_latents - - noise = randn_tensor( - full_conditioning_latents.shape, generator=generator, device=device, dtype=prompt_dtype + # 5. Prepare initial conditioning information from user-provided frames for the ENTIRE video duration. + # This section prepares data structures that will be used by both short and long video paths + # to incorporate user-specified conditioning frames using the `addnoise_condition` logic. + + initial_clean_conditioning_latents = None + # Stores VAE encoded clean latents from user `conditioning_frames` at their respective + # positions in the full video timeline. Shape: (eff_batch_size, C, num_latent_frames_total, H_latent, W_latent) + + initial_conditioning_latent_mask = torch.zeros(num_latent_frames_total, dtype=torch.bool, device=device) + # Boolean mask indicating which *latent frames* along the total video duration are directly conditioned by user input. + # True if a latent frame `i` has user-provided conditioning data in `initial_clean_conditioning_latents`. + + num_latent_frames_pre_ready_from_user = 0 + # This specific variable counts how many *contiguous latent frames from the very beginning* of the video + # are to be considered "pre-ready" or "frozen" for the `_generate_timestep_matrix` function. + # For typical diffusion forcing where specific frames are conditioned (not necessarily a prefix), + # this will often be 0. True video prefixes would set this. + # All other user-specified conditioning (sparse, non-prefix) will be handled by `addnoise_condition` logic + # guided by `initial_conditioning_latent_mask` within the denoising loops. + + if has_initial_conditioning: + if conditioning_frame_mask is None: + raise ValueError("If conditioning_frames are provided, conditioning_frame_mask must also be provided.") + if len(conditioning_frame_mask) != num_frames: + raise ValueError( + f"conditioning_frame_mask length ({len(conditioning_frame_mask)}) must equal num_frames ({num_frames})." + ) + + # Encode the user-provided frames. self.encode_frames is expected to return appropriately batched latents. + # Assuming self.encode_frames returns: (effective_batch_size, C, N_sparse_encoded_latents, Hl, Wl) + # where N_sparse_encoded_latents matches the number of 1s in conditioning_frame_mask (potentially after VAE temporal compression). 
+ sparse_user_latents = self.encode_frames(conditioning_frames) + + if sparse_user_latents.shape[0] != effective_batch_size: + if sparse_user_latents.shape[0] == 1 and effective_batch_size > 1: + sparse_user_latents = sparse_user_latents.repeat(effective_batch_size, 1, 1, 1, 1) + else: + raise ValueError( + f"Batch size mismatch: encoded conditioning frames have batch {sparse_user_latents.shape[0]}, expected {effective_batch_size}." ) - # Noise the 'clean' conditioning latents appropriate for this timestep - noised_conditioning_latents_full = self.scheduler.add_noise( - full_conditioning_latents, noise, timesteps_tensor + + latent_channels_for_cond = self.vae.config.latent_channels # Should match sparse_user_latents.shape[1] + latent_height_for_cond = sparse_user_latents.shape[-2] + latent_width_for_cond = sparse_user_latents.shape[-1] + + initial_clean_conditioning_latents = torch.zeros( + effective_batch_size, + latent_channels_for_cond, + num_latent_frames_total, + latent_height_for_cond, + latent_width_for_cond, + dtype=sparse_user_latents.dtype, + device=device, + ) + + processed_sparse_count = 0 + # Map the 1s in pixel-space `conditioning_frame_mask` to latent frame indices + # and place the corresponding `sparse_user_latents`. + for pixel_idx, is_cond_pixel in enumerate(conditioning_frame_mask): + if is_cond_pixel == 1: + if processed_sparse_count >= sparse_user_latents.shape[2]: + logger.warning( + f"More 1s in conditioning_frame_mask than available encoded conditioning frames ({sparse_user_latents.shape[2]})." + ) + break # Stop if we've run out of provided conditioning latents + + # Determine the target latent frame index for this pixel-space conditioned frame + # to_latent_frames expects 1-indexed pixel frame, returns 1-indexed latent frame count up to that point. + # So, for a pixel_idx (0-indexed), its corresponding latent frame index (0-indexed) is to_latent_frames(pixel_idx + 1) - 1. + target_latent_idx = to_latent_frames(pixel_idx + 1) - 1 + + if 0 <= target_latent_idx < num_latent_frames_total: + initial_clean_conditioning_latents[:, :, target_latent_idx, :, :] = sparse_user_latents[ + :, :, processed_sparse_count, :, : + ] + initial_conditioning_latent_mask[target_latent_idx] = True + processed_sparse_count += 1 + else: + logger.warning( + f"Pixel frame {pixel_idx} maps to latent index {target_latent_idx} out of bounds [0, {num_latent_frames_total - 1}]. Skipping." ) - # Combine current latents with noised conditioning latents using the mask - # latent_mask is True for generated regions, False for conditioned regions - model_input = torch.where(latent_mask, latents, noised_conditioning_latents_full) + if processed_sparse_count < sparse_user_latents.shape[2]: + logger.warning( + f"Only used {processed_sparse_count} out of {sparse_user_latents.shape[2]} provided conditioning latents. " + "Ensure conditioning_frame_mask aligns with the number of conditioning_frames provided and video length." + ) + + # For num_latent_frames_pre_ready_from_user: count contiguous conditioned frames from latent start. + # This is specifically for _generate_timestep_matrix's `num_pre_ready` which expects a prefix. + # Other sparse conditioning is handled by `addnoise_condition` using the mask and clean latents. 
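+            # Illustrative example of the prefix count below: with
+            # initial_conditioning_latent_mask = [True, True, False, True, ...], the loop yields
+            # num_latent_frames_pre_ready_from_user = 2; the isolated True at index 3 is not part of the
+            # contiguous prefix and is handled later through the `addnoise_condition` logic instead.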
+ current_pre_ready_count = 0 + for i in range(num_latent_frames_total): + if initial_conditioning_latent_mask[i]: + current_pre_ready_count += 1 else: - model_input = latents + break + num_latent_frames_pre_ready_from_user = current_pre_ready_count + if num_latent_frames_pre_ready_from_user > 0: + logger.info( + f"{num_latent_frames_pre_ready_from_user} latent frames from the start are user-conditioned and will be treated as pre-ready." + ) - # Expand for CFG - latent_model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input - # Scale model input for this timestep - latent_model_input = self.scheduler.scale_model_input(latent_model_input, timesteps_tensor) + # Latent Dims + num_channels_latents = self.vae.config.latent_channels + latent_height = height // vae_spatial_scale_factor + latent_width = width // vae_spatial_scale_factor + + # Determine if using long video path + use_long_video_path = False + base_latent_frames_seg = num_latent_frames_total # Default to full length if not long video mode + overlap_latent_frames_seg = 0 + + if overlap_history is not None and base_num_frames is not None and num_frames > base_num_frames: + if base_num_frames <= 0: + raise ValueError("base_num_frames must be positive.") + if overlap_history < 0: + raise ValueError("overlap_history must be non-negative.") + if overlap_history >= base_num_frames: + raise ValueError("overlap_history must be < base_num_frames.") + + # Check if long video generation is actually needed after converting to latent frames + base_latent_frames_seg = to_latent_frames(base_num_frames) + overlap_latent_frames_seg = to_latent_frames(overlap_history) + + if base_latent_frames_seg <= 0: + base_latent_frames_seg = 1 # Ensure minimum segment length + # Overlap can be 0 in latent space even if > 0 in pixel space, handle this + if overlap_latent_frames_seg < 0: + overlap_latent_frames_seg = 0 # Should not happen with to_latent_frames but safety + + if num_latent_frames_total > base_latent_frames_seg: + use_long_video_path = True + if overlap_latent_frames_seg >= base_latent_frames_seg: + logger.warning( + f"Calculated overlap_latent_frames ({overlap_latent_frames_seg}) >= base_latent_frames_seg ({base_latent_frames_seg}). Disabling overlap for long video." 
+ ) + overlap_latent_frames_seg = 0 - # Predict noise - # Note: Transformer sees the combined input (noise in generated areas, noised known in conditioned areas) + # Prepare initial latents for the full video (initial noise or user provided latents) + if latents is None: + shape = (effective_batch_size, num_channels_latents, num_latent_frames_total, latent_height, latent_width) + full_video_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + full_video_latents = full_video_latents * self.scheduler.init_noise_sigma + else: + expected_shape = ( + effective_batch_size, + num_channels_latents, + num_latent_frames_total, + latent_height, + latent_width, + ) + if latents.shape != expected_shape: + raise ValueError(f"Provided latents shape {latents.shape} does not match expected {expected_shape}.") + full_video_latents = latents.to(device, dtype=prompt_embeds.dtype) + + # Helper method for denoising a single segment + def _denoise_segment( + self, + segment_latents: torch.Tensor, # Latents for the current segment (slice of full_video_latents) + segment_start_global_idx: int, # Start index of this segment in the total video + num_latent_frames_this_segment: int, # Number of latent frames in this segment + num_pre_ready_for_this_segment: int, # Number of contiguous pre-ready frames at segment start + total_num_latent_frames: int, # Total latent frames in the whole video + initial_clean_conditioning_latents: Optional[torch.Tensor], # Clean conditioning for the whole video + initial_conditioning_latent_mask: Optional[torch.Tensor], # Mask for conditioned frames in the whole video + addnoise_condition: float, + timesteps: torch.Tensor, + prompt_embeds: torch.Tensor, + guidance_scale: float, + do_classifier_free_guidance: bool, + cross_attention_kwargs: Optional[Dict[str, Any]], + causal_block_size: int, + generator: Optional[Union[torch.Generator, List[torch.Generator]]], + progress_bar, + callback: Optional[Callable[[int, int, torch.Tensor], None]], + callback_steps: int, + fps: Optional[int] = None, # Add fps parameter + # Optional: segment index for logging + segment_index: int = 0, + num_segments: int = 1, + ) -> torch.Tensor: + # This method encapsulates the denoising loop logic previously in the short video path + # It will denoise `segment_latents` in place or return updated latents. + + # Generate the timestep matrix for this segment + step_matrix, step_index_matrix, update_mask_matrix, valid_interval_list = self._generate_timestep_matrix( + num_latent_frames=num_latent_frames_this_segment, + step_template=timesteps, + base_latent_frames=num_latent_frames_this_segment, # Base is segment length for matrix calc + ar_step=ar_step, + num_latent_frames_pre_ready=num_pre_ready_for_this_segment, + causal_block_size=causal_block_size, + ) + + if ( + not step_matrix.numel() + and num_latent_frames_this_segment > 0 + and num_pre_ready_for_this_segment < num_latent_frames_this_segment + ): + # Check if step_matrix is empty but should not be (i.e. not all frames are pre-ready) + logger.warning( + f"Segment {segment_index + 1}/{num_segments}: _generate_timestep_matrix returned empty." + ) + # If no steps, latents remain as is. 
+ return segment_latents # Return unchanged if no denoising steps generated + + # Denoising loop for the current segment + # The progress bar total is managed by the main loop (either short path or long video outer loop) + for i_matrix_step in range(len(step_matrix)): + current_timesteps_for_frames = step_matrix[i_matrix_step] # Timestamps for each frame in the segment + current_update_mask_for_frames = update_mask_matrix[ + i_matrix_step + ] # Update mask for each frame in the segment + valid_interval_start_local, valid_interval_end_local = valid_interval_list[ + i_matrix_step + ] # Local indices within segment + + # Slice segment latents for the current valid processing window + latent_model_input = segment_latents[ + :, :, valid_interval_start_local:valid_interval_end_local + ].clone() # Clone for modification + # Timesteps for the transformer input - corresponds to the sliced latents + timestep_tensor_for_transformer = current_timesteps_for_frames[ + valid_interval_start_local:valid_interval_end_local + ] + + # === Implement addnoise_condition logic === + if addnoise_condition > 0.0 and initial_clean_conditioning_latents is not None: + # Iterate over frames within the current valid interval slice (local index j_local) + for j_local in range(valid_interval_end_local - valid_interval_start_local): + # Map local segment index to global video index + j_global = segment_start_global_idx + valid_interval_start_local + j_local + + # Check if this global frame is user-conditioned AND NOT considered pre-ready/frozen by _generate_timestep_matrix. + # _generate_timestep_matrix should handle num_pre_ready by setting update_mask=False for those frames. + # So we apply addnoise to frames marked as conditioned (globally) AND are being updated in this matrix step. + if ( + j_global < total_num_latent_frames + and initial_conditioning_latent_mask[j_global] + and current_update_mask_for_frames[valid_interval_start_local + j_local] + ): + # This is a conditioned frame that's being processed, apply addnoise logic + # Get the clean conditioned frame from the global conditioning tensor + clean_cond_frame = initial_clean_conditioning_latents[:, :, j_global, :, :] + # Original code used 0.001 * addnoise_condition for noise factor. Let's use param directly. + noise_factor = addnoise_condition + + # Add noise to the clean conditioned frame + noise = randn_tensor( + clean_cond_frame.shape, + generator=generator, + device=device, + dtype=clean_cond_frame.dtype, + ) + noised_cond_frame = clean_cond_frame * (1.0 - noise_factor) + noise * noise_factor + + # Replace the noisy latent in the model input slice with the noised conditioned frame + latent_model_input[:, :, j_local] = noised_cond_frame + + # Original code also clamped the timestep for conditioned frames. + # Let's clamp the specific frame's timestep in the transformer input tensor. + # Use addnoise_condition value as the clamping threshold. + if addnoise_condition > 0: # Avoid min with 0 if addnoise_condition is 0 + # Ensure addnoise_condition is treated as a valid timestep index or value. + # Assuming addnoise_condition is intended as a timestep value or something to clamp against. 
+ # clamped_timestep_value = torch.tensor(addnoise_condition, device=device, dtype=timestep_tensor_for_transformer.dtype) + # Original clamped: `timestep_tensor_for_transformer[j_local] = torch.min(timestep_tensor_for_transformer[j_local], clamped_timestep_value)` + # Let's remove timestep clamping based on `addnoise_condition` for now, as it's less standard in Diffusers schedulers + # and its exact effect in original is tied to their scheduler step logic. + # addnoise_condition will primarily function as a noise amount control. + pass # Clamping logic removed + + # === End of addnoise_condition logic === + + # Model input for transformer (potentially with CFG duplication) + model_input_for_transformer = latent_model_input + if do_classifier_free_guidance: + model_input_for_transformer = torch.cat([model_input_for_transformer] * 2) + + # Timesteps for transformer (duplicated for CFG) + if do_classifier_free_guidance: + timestep_tensor_for_transformer_cfg = torch.cat([timestep_tensor_for_transformer] * 2) + else: + timestep_tensor_for_transformer_cfg = timestep_tensor_for_transformer + + # Transformer forward pass model_pred = self.transformer( - latent_model_input, - timestep=timesteps_tensor, + model_input_for_transformer, + timestep=timestep_tensor_for_transformer_cfg, # Use per-frame timesteps encoder_hidden_states=prompt_embeds, - encoder_hidden_states_image=None, - attention_kwargs=cross_attention_kwargs, + cross_attention_kwargs=cross_attention_kwargs, + # Pass fps_embeds if fps is provided + fps=torch.tensor([fps] * model_input_for_transformer.shape[0], device=self.device) + if fps is not None + else None, ).sample - # CFG + # CFG guidance if do_classifier_free_guidance: model_pred_uncond, model_pred_text = model_pred.chunk(2) model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond) - # Scheduler step (operates on the full latents) - step_output = self.scheduler.step(model_pred, timesteps_tensor, latents) - current_latents = step_output.prev_sample + # Scheduler step per frame if updated + # Iterate over the frames in the current valid interval (local index idx_local_in_segment) + for idx_local_in_segment in range(model_pred.shape[2]): + # Global frame index corresponding to this local segment index + # g_idx = segment_start_global_idx + valid_interval_start_local + idx_local_in_segment # Not directly used in indexing below + + # Check the update mask *for the corresponding frame in the current segment* to see if it should be updated + # update_mask_matrix is (num_matrix_steps, num_latent_frames_this_segment) + # The update mask for the current frame is at `current_update_mask_for_frames[valid_interval_start_local + idx_local_in_segment]` + # which is equivalent to `current_update_mask_for_frames[idx_local_in_segment]` within the sliced valid interval + if current_update_mask_for_frames[valid_interval_start_local + idx_local_in_segment]: + frame_pred = model_pred[:, :, idx_local_in_segment] # Prediction for this frame (local index) + frame_latent = segment_latents[ + :, :, valid_interval_start_local + idx_local_in_segment + ] # Current latent for this frame (local in segment) + frame_timestep = current_timesteps_for_frames[ + valid_interval_start_local + idx_local_in_segment + ] # Timestep for this frame from matrix + + # Apply scheduler step for this single frame's latent + # Update the corresponding frame directly in the segment_latents tensor + segment_latents[:, :, valid_interval_start_local + idx_local_in_segment] = self.scheduler.step( + 
+                            frame_pred,
+                            frame_timestep,
+                            frame_latent,
+                            return_dict=False,
+                            generator=generator,  # Pass generator
+                        )[0]
+
+                # Progress bar update - handled by the outer loop caller
+
+                # Callback - handled by the outer loop caller
+
+            return segment_latents  # Return the denoised segment latents
+
+        # Main generation loop(s)
+        if not use_long_video_path:
+            logger.info(f"Short video path: {num_latent_frames_total} latent frames.")
+            # Denoise the full video as a single segment, passing the progress bar and callback
+            # down to _denoise_segment.
+            with self.progress_bar(total=len(timesteps)) as progress_bar:
+                # Need to determine the actual number of matrix steps for the progress bar total,
+                # so call _generate_timestep_matrix first to get the matrix size.
+                temp_step_matrix, _, _, _ = self._generate_timestep_matrix(
+                    num_latent_frames=num_latent_frames_total,
+                    step_template=timesteps,
+                    base_latent_frames=num_latent_frames_total,  # For short path, base is the full length
+                    ar_step=ar_step,
+                    num_latent_frames_pre_ready=num_latent_frames_pre_ready_from_user,
+                    causal_block_size=causal_block_size,
+                )
+                progress_bar.total = (
+                    len(temp_step_matrix) if len(temp_step_matrix) > 0 else num_inference_steps
+                )  # Adjusted total
+                if (
+                    progress_bar.total == 0
+                    and num_latent_frames_total > 0
+                    and num_latent_frames_pre_ready_from_user < num_latent_frames_total
+                ):
+                    logger.warning(
+                        "Progress bar total is 0 but video needs denoising. Setting total to num_inference_steps."
+ ) + progress_bar.total = num_inference_steps + + denoised_latents = _denoise_segment( + self, + segment_latents=full_video_latents, # Denoise the whole video + segment_start_global_idx=0, # Starts at the beginning of the video + num_latent_frames_this_segment=num_latent_frames_total, # Segment is the whole video + num_pre_ready_for_this_segment=num_latent_frames_pre_ready_from_user, # Use user-provided pre-ready count + total_num_latent_frames=num_latent_frames_total, + initial_clean_conditioning_latents=initial_clean_conditioning_latents, + initial_conditioning_latent_mask=initial_conditioning_latent_mask, + addnoise_condition=addnoise_condition, + timesteps=timesteps, + prompt_embeds=prompt_embeds, + guidance_scale=guidance_scale, + do_classifier_free_guidance=do_classifier_free_guidance, + cross_attention_kwargs=cross_attention_kwargs, + causal_block_size=causal_block_size, + generator=generator, + progress_bar=progress_bar, # Pass the segment progress bar + callback=callback, # Pass callback + callback_steps=callback_steps, + segment_index=0, + num_segments=1, + ) + + else: # Long video path - Implementation + logger.info( + f"Long video path: {num_latent_frames_total} total latents, {base_latent_frames_seg} base/segment, {overlap_latent_frames_seg} overlap." + ) + + # Calculate number of segments + non_overlapping_part_len = base_latent_frames_seg - overlap_latent_frames_seg + if non_overlapping_part_len <= 0: + logger.error("Non-overlapping part of segment is <=0. Adjust base_num_frames or overlap_history.") + raise ValueError("Non-overlapping part of segment must be positive.") - # Re-apply known conditioning information using the mask - # Ensures the conditioned areas stay consistent with their noised versions - if has_conditioning: - # Use the same noised_conditioning_latents_full calculated for timestep t - latents = torch.where(latent_mask, current_latents, noised_conditioning_latents_full) - else: - latents = current_latents + num_iterations = ( + 1 + + (num_latent_frames_total - base_latent_frames_seg + non_overlapping_part_len - 1) + // non_overlapping_part_len + ) + logger.info(f"Long video: processing in {num_iterations} segments.") + + # Initialize tensor to store the final denoised latents for the whole video + # final_denoised_latents = torch.zeros_like(full_video_latents) # This was unused + # Or accumulate in a list and concatenate later, might be better to avoid pre-allocation issues. + # Let's accumulate in a list for now. + final_denoised_segments_list = [] + + # Keep track of the last part of the previously denoised segment for overlap + previous_segment_denoised_overlap_latents = None # Will be (B, C, overlap_latent_frames_seg, H_l, W_l) + + # Main loop for processing segments + # Total progress bar should cover all matrix steps across all segments. + # This is hard to pre-calculate exactly due to matrix generation per segment. + # Alternative: progress bar tracks segments, and each segment denoising shows internal progress (if tqdm nested allowed). + # Let's track segments in the main progress bar for simplicity first. 
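+            # Illustrative segment arithmetic (values are only an example): with
+            # num_latent_frames_total=40, base_latent_frames_seg=20 and overlap_latent_frames_seg=5,
+            # non_overlapping_part_len=15 and num_iterations = 1 + (40 - 20 + 15 - 1) // 15 = 3,
+            # i.e. each iteration advances the segment start by non_overlapping_part_len latent frames.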
+ + with self.progress_bar( + total=num_iterations, desc="Generating Long Video Segments" + ) as progress_bar_segments: + for i_segment in range(num_iterations): + # Determine the global latent frame indices for the current segment + segment_start_global_latent_idx = i_segment * non_overlapping_part_len + # The end index is start + base_segment_length, clamped to total length + segment_end_global_latent_idx = min( + segment_start_global_latent_idx + base_latent_frames_seg, num_latent_frames_total + ) - # Callback - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) + # Adjust start index to include overlap from previous segment (if not the first segment) + current_segment_global_start_with_overlap = segment_start_global_latent_idx + if i_segment > 0: + current_segment_global_start_with_overlap -= overlap_latent_frames_seg - # 8. Post-processing - video_tensor = self._decode_latents(latents) - # video_tensor shape should be (batch, frames, channels, height, width) float [0,1] + # Determine the actual number of latent frames in this current segment + num_latent_frames_this_segment = ( + segment_end_global_latent_idx - current_segment_global_start_with_overlap + ) + + logger.info( + f" Processing segment {i_segment + 1}/{num_iterations} (global latent frames {current_segment_global_start_with_overlap} to {segment_end_global_latent_idx - 1})." + ) + + # Prepare latents for the current segment + # Start with the initial noise (or user provided) for this segment's range + current_segment_latents = full_video_latents[ + :, :, current_segment_global_start_with_overlap:segment_end_global_latent_idx + ].clone() - # Use VideoProcessor for standard output formats + # If there's overlap from the previous segment, overwrite the initial part with the denoised overlap + num_pre_ready_for_this_segment = ( + num_latent_frames_pre_ready_from_user # Start with user prefix count + ) + if i_segment > 0 and previous_segment_denoised_overlap_latents is not None: + # Overwrite the first `overlap_latent_frames_seg` of current_segment_latents + # with the denoised overlap from the previous segment. + # The number of pre-ready frames for the matrix generation should be the overlap length. + current_segment_latents[:, :, :overlap_latent_frames_seg] = ( + previous_segment_denoised_overlap_latents + ) + num_pre_ready_for_this_segment = ( + overlap_latent_frames_seg # Overlap serves as the frozen prefix for matrix generation + ) + logger.info( + f" Segment includes {overlap_latent_frames_seg} latent frames of overlap from previous segment." + ) + elif i_segment == 0 and num_latent_frames_pre_ready_from_user > 0: + # First segment, use user-provided prefix as pre-ready + num_pre_ready_for_this_segment = num_latent_frames_pre_ready_from_user + + # Denoise the current segment + # Pass the segment_latents to the helper function. + # The helper will operate on this tensor and return the denoised result for the segment. 
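+                    # Note: _denoise_segment updates `segment_latents` in place and returns the same tensor,
+                    # so `denoised_segment_latents` below aliases `current_segment_latents` (a clone of the
+                    # corresponding slice of `full_video_latents`); the full-video tensor itself is not modified.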
+ denoised_segment_latents = _denoise_segment( + self, + segment_latents=current_segment_latents, # Latents for this segment + segment_start_global_idx=current_segment_global_start_with_overlap, # Global start index (including overlap) + num_latent_frames_this_segment=num_latent_frames_this_segment, # Length of this segment + num_pre_ready_for_this_segment=num_pre_ready_for_this_segment, # Pre-ready frames for matrix generation + total_num_latent_frames=num_latent_frames_total, + initial_clean_conditioning_latents=initial_clean_conditioning_latents, # Full video conditioning + initial_conditioning_latent_mask=initial_conditioning_latent_mask, # Full video mask + addnoise_condition=addnoise_condition, + timesteps=timesteps, + prompt_embeds=prompt_embeds, + guidance_scale=guidance_scale, + do_classifier_free_guidance=do_classifier_free_guidance, + cross_attention_kwargs=cross_attention_kwargs, + causal_block_size=causal_block_size, + generator=generator, + progress_bar=progress_bar_segments, # Pass the segment progress bar + callback=callback, # Pass callback + callback_steps=callback_steps, + fps=fps, # Pass fps from __call__ scope + segment_index=i_segment, + num_segments=num_iterations, + ) + + # Extract the non-overlapping part of the denoised segment + # For the first segment, this is from the start up to base_latent_frames_seg. + # For subsequent segments, this is from overlap_latent_frames_seg onwards. + non_overlapping_segment_start_local_idx = 0 if i_segment == 0 else overlap_latent_frames_seg + non_overlapping_segment_latents = denoised_segment_latents[ + :, :, non_overlapping_segment_start_local_idx: + ].clone() + + # Add the non-overlapping part to the final list + final_denoised_segments_list.append(non_overlapping_segment_latents) + + # Prepare overlap for the next segment (if not the last segment) + if i_segment < num_iterations - 1: + # The overlap is the last `overlap_latent_frames_seg` of the *denoised* current segment. + overlap_start_local_idx = num_latent_frames_this_segment - overlap_latent_frames_seg + overlap_latents_to_process = denoised_segment_latents[:, :, overlap_start_local_idx:].clone() + + # Implementing original SkyReels V2 overlap handling (decode and re-encode): + # 1. Decode overlap latents to pixel space + # VAE decode expects (B, C, F, H, W). overlap_latents_to_process is (B, C, F_overlap, H_l, W_l) + decoded_overlap_pixels = self.vae.decode(overlap_latents_to_process).sample + # decoded_overlap_pixels is (B, F_overlap, C, H, W) after vae.decode (check vae_outputs) + + # 2. Re-encode pixel frames back to latent space + # VAE encode expects (B, C, F, H, W), so permute decoded pixels + # decoded_overlap_pixels needs permuting from (B, F_overlap, C, H, W) to (B, C, F_overlap, H, W) + encoded_overlap_latents = self.vae.encode( + decoded_overlap_pixels.permute(0, 2, 1, 3, 4) # Permute to (B, C, F, H, W) + ).latent_dist.sample() + + # Apply VAE scaling factor after encoding + previous_segment_denoised_overlap_latents = ( + encoded_overlap_latents * self.vae.config.scaling_factor + ) + + # Update segment progress bar + progress_bar_segments.update(1) # Update by 1 for each completed segment + + # Concatenate all denoised segments to get the full video latents + denoised_latents = torch.cat( + final_denoised_segments_list, dim=2 + ) # Concatenate along the frame dimension (dim=2) + + # 7. 
Post-processing (decode final denoised_latents) + video_tensor = self._decode_latents(denoised_latents) video = self.video_processor.postprocess_video(video_tensor, output_type=output_type) self.maybe_free_model_hooks() if not return_dict: return (video,) - return SkyReelsV2PipelineOutput(frames=video) From 6f8a945678703d71fce9170974e8c500d2ff534c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 9 May 2025 17:49:18 +0300 Subject: [PATCH 006/264] 4th draft --- .../pipelines/skyreels_v2/pipeline_output.py | 20 + .../skyreels_v2/pipeline_skyreels_v2.py | 595 ++++++ .../pipeline_skyreels_v2_diffusion_forcing.py | 1761 +++++------------ .../pipeline_skyreels_v2_image_to_video.py | 953 +++++---- .../pipeline_skyreels_v2_text_to_video.py | 551 ------ 5 files changed, 1700 insertions(+), 2180 deletions(-) create mode 100644 src/diffusers/pipelines/skyreels_v2/pipeline_output.py create mode 100644 src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py delete mode 100644 src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_text_to_video.py diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_output.py b/src/diffusers/pipelines/skyreels_v2/pipeline_output.py new file mode 100644 index 000000000000..7a170d24c39a --- /dev/null +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class SkyReelsV2PipelineOutput(BaseOutput): + r""" + Output class for SkyReelsV2 pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py new file mode 100644 index 000000000000..2563b16c2afc --- /dev/null +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py @@ -0,0 +1,595 @@ +# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import html +from typing import Any, Callable, Dict, List, Optional, Union + +import regex as re +import torch +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import WanLoraLoaderMixin +from ...models import AutoencoderKLWan, WanTransformer3DModel +from ...schedulers import FlowUniPCMultistepScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import SkyReelsV2PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers.utils import export_to_video + >>> from diffusers import AutoencoderKLWan, SkyReelsV2Pipeline + >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler + + >>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers + >>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + >>> pipe = SkyReelsV2Pipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + >>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P + >>> pipe.scheduler = FlowUniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) + >>> pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=720, + ... width=1280, + ... num_frames=81, + ... guidance_scale=5.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=16) + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +class SkyReelsV2Pipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. 
+ text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`WanTransformer3DModel`]): + Conditional Transformer to denoise the input latents. + scheduler ([`FlowUniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: WanTransformer3DModel, + vae: AutoencoderKLWan, + scheduler: FlowUniPCMultistepScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. 
If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." 
+ ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. 
+ height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`SkyReelsOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): + The dtype to use for the torch.amp.autocast. 
+ + Examples: + + Returns: + [`~SkyReelsOutput`] or `tuple`: + If `return_dict` is `True`, [`SkyReelsOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return SkyReelsV2PipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 148b9c746d87..abb4e7cabd1b 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -1,4 +1,4 @@ -# Copyright 2024 The SkyReels-V2 Authors and The HuggingFace Team. All rights reserved. +# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
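The denoising loop in the first version of this file (above) runs the transformer twice per step when classifier-free guidance is active and blends the two predictions before calling `scheduler.step`. A minimal, self-contained sketch of that blending rule follows; the tensors are placeholders standing in for the real conditional and unconditional transformer outputs, and the shape is only illustrative of the 480x832, 81-frame defaults.

```py
import torch

guidance_scale = 5.0
# Stand-ins for the two transformer passes: (B, C, F_latent, H/8, W/8)
noise_cond = torch.randn(1, 16, 21, 60, 104)    # prediction with the text prompt
noise_uncond = torch.randn(1, 16, 21, 60, 104)  # prediction with the negative / empty prompt

# Classifier-free guidance: move away from the unconditional prediction by `guidance_scale`.
noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
```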
@@ -17,9 +17,13 @@ import numpy as np import PIL.Image import torch -from transformers import CLIPTextModel, CLIPTokenizer +import ftfy +import html +import re +from transformers import CLIPVisionModel, CLIPImageProcessor, AutoTokenizer, UMT5EncoderModel -from ...image_processor import VideoProcessor +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...video_processor import VideoProcessor from ...models import AutoencoderKLWan, WanTransformer3DModel from ...schedulers import FlowUniPCMultistepScheduler from ...utils import ( @@ -28,7 +32,10 @@ ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline -from .pipeline_skyreels_v2_text_to_video import SkyReelsV2PipelineOutput +from ...loaders import WanLoraLoaderMixin +from ...image_processor import PipelineImageInput +from .pipeline_output import SkyReelsV2PipelineOutput + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -48,1357 +55,667 @@ ... ) >>> pipe = pipe.to("cuda") - >>> # Prepare conditioning frames (list of PIL Images) - >>> # Example: Condition on frames 0, 24, 48, 72 for a 97-frame video - >>> frame_0 = load_image("./frame_0.png") # Placeholder paths - >>> frame_24 = load_image("./frame_24.png") - >>> frame_48 = load_image("./frame_48.png") - >>> frame_72 = load_image("./frame_72.png") - >>> conditioning_frames = [frame_0, frame_24, frame_48, frame_72] - - >>> # Create mask: 1 for conditioning frames, 0 for frames to generate - >>> num_frames = 97 # Match the default - >>> conditioning_frame_mask = [0] * num_frames - >>> # Example conditioning indices for a 97-frame video - >>> conditioning_indices = [0, 24, 48, 72] - >>> for idx in conditioning_indices: - ... if idx < num_frames: # Check bounds - ... conditioning_frame_mask[idx] = 1 - - >>> prompt = "A person walking in the park" - >>> video = pipe( - ... prompt=prompt, - ... conditioning_frames=conditioning_frames, - ... conditioning_frame_mask=conditioning_frame_mask, - ... num_frames=num_frames, - ... height=544, - ... width=960, - ... num_inference_steps=30, - ... guidance_scale=6.0, - ... shift=8.0, - ... # Parameters for long video generation / advanced forcing (optional) - ... # base_num_frames=97, - ... # ar_step=5, - ... # overlap_history=24, # Number of *frames* (not latent frames) for overlap - ... # addnoise_condition=0.0, - ... ).frames - >>> export_to_video(video, "skyreels_v2_df.mp4") + ... ... ``` """ -class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline): - """ - Pipeline for video generation with diffusion forcing (conditioning on specific frames) using SkyReels-V2. 
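+# Prompt normalization helpers used before tokenization: `basic_clean` repairs text encoding
+# with ftfy and unescapes HTML entities (applied twice to catch double-escaped input),
+# `whitespace_clean` collapses runs of whitespace, and `prompt_clean` chains the two.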
+def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): + """ + Pipeline for video generation with diffusion forcing using SkyReels-V2. + This pipeline supports two main tasks: Text-to-Video (t2v) and Image-to-Video (i2v) + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a specific device, etc.). Args: - vae ([`AutoencoderKLWan`]): - Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. - text_encoder ([`~transformers.CLIPTextModel`]): - Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). - tokenizer ([`~transformers.CLIPTokenizer`]): - A `CLIPTokenizer` to tokenize text. + tokenizer ([`AutoTokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`UMT5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + image_encoder ([`CLIPVisionModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically + the + [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) + variant. transformer ([`WanTransformer3DModel`]): - A SkyReels-V2 transformer model for diffusion with diffusion forcing capability. + Conditional Transformer to denoise the encoded image latents. scheduler ([`FlowUniPCMultistepScheduler`]): - A scheduler to be used in combination with the transformer to denoise the encoded video latents. - video_processor ([`VideoProcessor`]): - Processor for post-processing generated videos (e.g., tensor to numpy array). + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
""" - model_cpu_offload_seq = "text_encoder->transformer->vae" + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, - vae: AutoencoderKLWan, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + image_encoder: CLIPVisionModel, + image_processor: CLIPImageProcessor, transformer: WanTransformer3DModel, + vae: AutoencoderKLWan, scheduler: FlowUniPCMultistepScheduler, - video_processor: VideoProcessor, ): super().__init__() + self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, + image_encoder=image_encoder, transformer=transformer, scheduler=scheduler, - video_processor=video_processor, + image_processor=image_processor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.image_processor = image_processor - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.vae.enable_tiling() + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. 
- """ - self.vae.disable_tiling() + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + def encode_image( + self, + image: PipelineImageInput, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + image = self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], - device: torch.device, - num_videos_per_prompt: int, - do_classifier_free_guidance: bool, negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - max_sequence_length: Optional[int] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide image generation. - device: (`torch.device`): - The torch device to place the resulting embeddings on. - num_videos_per_prompt (`int`): - The number of videos that should be generated per prompt. - do_classifier_free_guidance (`bool`): - Whether to use classifier-free guidance or not. + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - provide `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if - `guidance_scale` is less than 1). + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. + Pre-generated text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - max_sequence_length (`int`, *optional*): - Maximum sequence length for input text when generating embeddings. If not provided, defaults to 77. + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] - # Define tokenizer parameters - if max_sequence_length is None: - max_sequence_length = self.tokenizer.model_max_length - - # Get prompt text embeddings if prompt_embeds is None: - # Text encoder expects tokens to be of shape (batch_size, context_length) - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - attention_mask = text_inputs.attention_mask - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask.to(device), + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, ) - prompt_embeds = prompt_embeds[0] - - # Duplicate prompt embeddings for each generation per prompt - if prompt_embeds.shape[0] < batch_size * num_videos_per_prompt: - prompt_embeds = prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) - # Get negative prompt embeddings if do_classifier_free_guidance and negative_prompt_embeds is None: - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) elif batch_size != len(negative_prompt): raise ValueError( - f"Batch size of `negative_prompt` should be {batch_size}, but is {len(negative_prompt)}" + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
) - negative_text_inputs = self.tokenizer( - negative_prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_tensors="pt", - ) - negative_input_ids = negative_text_inputs.input_ids - negative_attention_mask = negative_text_inputs.attention_mask - - negative_prompt_embeds = self.text_encoder( - negative_input_ids.to(device), - attention_mask=negative_attention_mask.to(device), + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, ) - negative_prompt_embeds = negative_prompt_embeds[0] - - # Duplicate negative prompt embeddings for each generation per prompt - if negative_prompt_embeds.shape[0] < batch_size * num_videos_per_prompt: - negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_videos_per_prompt, dim=0) - - # For classifier-free guidance, combine embeddings - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor: - # AutoencoderKLWan expects B, C, F, H, W latents directly - video = self.vae.decode(latents).sample - # Permute from (B, C, F, H, W) to (B, F, C, H, W) for video_processor and standard video format - video = video.permute(0, 2, 1, 3, 4) - video = (video / 2 + 0.5).clamp(0, 1) - return video - def encode_frames(self, frames: Union[List[PIL.Image.Image], torch.Tensor]) -> torch.Tensor: - """ - Encodes conditioning frames into VAE latent space. - - Args: - frames (`List[PIL.Image.Image]` or `torch.Tensor`): - The conditioning frames (batch, frames, channels, height, width) or list of PIL images. Assumes frames - are already preprocessed (e.g., correct size, range [-1, 1] if tensor). + return prompt_embeds, negative_prompt_embeds - Returns: - `torch.Tensor`: Latent representations of the frames (batch, channels, latent_frames, height, width). - """ - if isinstance(frames, list): - # Assume list of PIL Images, needs preprocessing similar to VAE requirements - # Note: This uses a basic preprocessing, might need alignment with VaeImageProcessor - frames_np = np.stack([np.array(frame.convert("RGB")) for frame in frames]) - frames_tensor = torch.from_numpy(frames_np).float() / 127.5 - 1.0 # Range [-1, 1] - frames_tensor = frames_tensor.permute( - 0, 3, 1, 2 - ) # -> (batch*frames, channels, H, W) if flattened? No, needs batch dim. - # Let's assume the input list is for a SINGLE batch item's frames. - # Needs shape (batch=1, frames, channels, H, W) -> permute to (batch=1, channels, frames, H, W) - frames_tensor = frames_tensor.unsqueeze(0).permute(0, 2, 1, 3, 4) - elif isinstance(frames, torch.Tensor): - # Assume input tensor is already preprocessed and has shape (batch, frames, channels, H, W) or similar - # Ensure range [-1, 1] - if frames.min() >= 0.0 and frames.max() <= 1.0: - frames = 2.0 * frames - 1.0 - # Permute to (batch, channels, frames, H, W) - if frames.ndim == 5 and frames.shape[2] == 3: # Check if channels is dim 2 - frames_tensor = frames.permute(0, 2, 1, 3, 4) - elif frames.ndim == 5 and frames.shape[1] == 3: # Check if channels is dim 1 - frames_tensor = frames # Already in correct channel order - else: - raise ValueError("Input tensor shape not recognized. 
Expected (B, F, C, H, W) or (B, C, F, H, W).") - else: - raise TypeError("`conditioning_frames` must be a list of PIL Images or a torch Tensor.") - - frames_tensor = frames_tensor.to(device=self.device, dtype=self.vae.dtype) - - # Encode frames using VAE - # Note: VAE encode expects (batch, channels, frames, height, width)? Check AutoencoderKLWan docs/code - # AutoencoderKLWan._encode takes (B, C, F, H, W) - conditioning_latents = self.vae.encode(frames_tensor).latent_dist.sample() - conditioning_latents = conditioning_latents * self.vae.config.scaling_factor - - # Expected output shape: (batch, channels, latent_frames, latent_height, latent_width) - return conditioning_latents - - def check_conditioning_inputs( + def check_inputs( self, - conditioning_frames: Optional[Union[List[PIL.Image.Image], torch.Tensor]], - conditioning_frame_mask: Optional[List[int]], - num_frames: int, + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + callback_on_step_end_tensor_inputs=None, ): - if conditioning_frames is None and conditioning_frame_mask is not None: - raise ValueError("`conditioning_frame_mask` provided without `conditioning_frames`.") - if conditioning_frames is not None and conditioning_frame_mask is None: - raise ValueError("`conditioning_frames` provided without `conditioning_frame_mask`.") - - if conditioning_frames is not None: - if not isinstance(conditioning_frame_mask, list) or not all( - isinstance(i, int) for i in conditioning_frame_mask - ): - raise TypeError("`conditioning_frame_mask` must be a list of integers (0 or 1).") - if len(conditioning_frame_mask) != num_frames: - raise ValueError( - f"`conditioning_frame_mask` length ({len(conditioning_frame_mask)}) must equal `num_frames` ({num_frames})." - ) - if not all(m in [0, 1] for m in conditioning_frame_mask): - raise ValueError("`conditioning_frame_mask` must only contain 0s and 1s.") - - num_masked_frames = sum(conditioning_frame_mask) - - if isinstance(conditioning_frames, list): - if not all(isinstance(f, PIL.Image.Image) for f in conditioning_frames): - raise TypeError("If `conditioning_frames` is a list, it must contain only PIL Images.") - if len(conditioning_frames) != num_masked_frames: - raise ValueError( - f"Number of `conditioning_frames` ({len(conditioning_frames)}) must equal the number of 1s in `conditioning_frame_mask` ({num_masked_frames})." - ) - elif isinstance(conditioning_frames, torch.Tensor): - # Assuming tensor shape is (num_masked_frames, C, H, W) or (B, num_masked_frames, C, H, W) etc. - # A simple check on the frame dimension assuming it's the first or second dim after batch - if not ( - conditioning_frames.shape[0] == num_masked_frames - or (conditioning_frames.ndim > 1 and conditioning_frames.shape[1] == num_masked_frames) - ): - # This check is basic and might need refinement based on expected tensor layout - logger.warning( - f"Number of frames in `conditioning_frames` tensor ({conditioning_frames.shape}) does not seem to match the number of 1s in `conditioning_frame_mask` ({num_masked_frames}). Ensure tensor shape is correct." 
- ) - else: - raise TypeError("`conditioning_frames` must be a List[PIL.Image.Image] or torch.Tensor.") - - def _generate_timestep_matrix( - self, - num_latent_frames: int, - step_template: torch.Tensor, - base_latent_frames: int, - ar_step: int = 5, - num_latent_frames_pre_ready: int = 0, - causal_block_size: int = 1, - shrink_interval_with_mask: bool = False, # Not used in original SkyReels-V2 call, kept for completeness - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[Tuple[int, int]]]: - """ - Generates the timestep matrix for autoregressive scheduling, adapted from SkyReels-V2. Operates on latent frame - counts. - """ - step_matrix, step_index = [], [] - update_mask, valid_interval = [], [] - num_iterations = len(step_template) + 1 # num_inference_steps + 1 effectively - - # Ensure operations are on latent frames, assuming inputs are already latent frame counts - num_frames_block = num_latent_frames // causal_block_size - base_num_frames_block = base_latent_frames // causal_block_size - - if base_num_frames_block > 0 and base_num_frames_block < num_frames_block: - infer_step_num = len(step_template) - gen_block = base_num_frames_block - if gen_block > 0: - min_ar_step = infer_step_num / gen_block - if ar_step < min_ar_step: - logger.warning( - f"ar_step ({ar_step}) is less than the suggested minimum ({np.ceil(min_ar_step)}) " - f"for base_latent_frames={base_latent_frames} and num_inference_steps={infer_step_num}. " - "This might lead to suboptimal scheduling." - ) - else: - # Should not happen if base_num_frames_block is 0 and causal_block_size > 0 - logger.warning("base_num_frames_block is zero, ar_step check skipped.") - - # Add sentinel values to step_template for indexing logic - # Original SkyReels-V2 uses [999, ..., 0] - # self.scheduler.timesteps are typically [high, ..., low] - # We need to ensure indexing works correctly. - # The original logic `step_template[new_row]` implies new_row contains indices into step_template. - # `new_row` counts from 0 to num_iterations. Let's adjust `step_template` to be 0-indexed - # from num_iterations-1 down to 0. - # Example: if step_template is [980, 960 ... 20, 0]) - # The values in new_row are essentially "how many steps have been processed for this frame" - # from 0 (not started) to num_iterations (fully denoised). - # step_matrix.append(step_template[new_row]) -> This seems problematic if new_row is 0 to num_iterations. - # original: step_template = torch.cat([torch.tensor([999]), timesteps, torch.tensor([0])]) - # This padding makes step_template 1-indexed essentially. - # Let's use a direct mapping from "number of steps processed" to actual timestep value. - # If new_row[i] = k, it means frame i has undergone k denoising iterations. - # The corresponding timestep should be init_timesteps[k-1] if new_row is 1-indexed for steps. - # Original `pre_row` starts at 0. `new_row` increments. `new_row` goes from 0 to `num_iterations`. - # `step_template[new_row]` means `new_row` values are indices into a padded step_template. - # Let's use `step_template` (which are the actual timesteps from the scheduler) directly. - # if new_row[i] = k: use step_template[k-1] - # if new_row[i] = 0: this block is still pure noise / at initial state, use first timestep for processing. - # The original `step_matrix.append(step_template[new_row])` used a 1-indexed padded template. - # Our `new_row` is 0-indexed for states (0 to num_inference_steps). - # Timestep for state k (1 <= k <= num_inference_steps) is step_template[k-1]. 
- # Timestep for state 0 (initial) is step_template[0]. - # So, for a state `s` in `new_row` (0 to num_inference_steps), the timestep is `step_template[s.clamp(min=0, max=len(step_template)-1)]` - # No, simpler: if state is `k`, it means it has undergone `k` steps. The *next* step to apply is `step_template[k]`. - # So `new_row` (clamped 0 to `len(step_template)-1`) can directly index `step_template`. - # This gives the timestep *for the current operation*. - timesteps_for_matrix = step_template # These are the actual t values - # `new_row` will count how many steps a frame has been processed. Ranges 0 to `len(timesteps_for_matrix)`. - # 0 = initial noise state. `len(timesteps_for_matrix)` = fully processed by all timesteps. - # `num_iterations` here is `len(timesteps_for_matrix)`. - # Original `num_iterations = len(step_template) + 1`. - # Let's stick to original logic for `num_iterations` for `pre_row` and `new_row` counters. - # `num_iterations` = number of denoising *states* (0=initial noise, 1=after 1st step, ..., N=after Nth step) - # So, if N inference steps, there are N+1 states. `num_iterations = len(step_template) + 1`. - - pre_row = torch.zeros(num_frames_block, dtype=torch.long, device=step_template.device) - if num_latent_frames_pre_ready > 0: - # Ensure pre_ready frames are marked as fully processed through all steps. - pre_row[: num_latent_frames_pre_ready // causal_block_size] = ( - num_iterations - 1 - ) # Mark as if processed by all steps - - # The loop condition `torch.all(pre_row >= (num_iterations - 1))` means loop until all blocks are fully processed. - while not torch.all(pre_row >= (num_iterations - 1)): - new_row = torch.zeros(num_frames_block, dtype=torch.long, device=step_template.device) - for i in range(num_frames_block): - if i == 0 or pre_row[i - 1] >= (num_iterations - 1): # first block or previous block is fully denoised - new_row[i] = pre_row[i] + 1 - else: - new_row[i] = new_row[i - 1] - ar_step - new_row = torch.clamp( - new_row, 0, num_iterations - 1 - ) # Clamp to valid state indices (0 to num_inference_steps) - - current_update_mask = (new_row != pre_row) & ( - new_row != (num_iterations - 1) - ) # Original: & (new_row != num_iterations) - # If new_row == num_iterations-1, it means it just reached the final denoised state. It *should* be updated. - # Let's use original: (new_row != pre_row) & (new_row < (num_iterations -1)) - # A frame is updated if its state changes AND it's not yet in the "fully processed" state. - # The original logic: update_mask.append((new_row != pre_row) & (new_row != num_iterations)) - # This seems to imply that even the step *to* num_iterations is not in update_mask. - # Let's stick to the original: - # update_mask is True if state changes AND it is not yet at the state corresponding to the last timestep. - # However, new_row is clamped to num_iterations-1 (max index for timesteps). - # So new_row == num_iterations will not happen here. - # Update if state changes AND it is not yet at the state corresponding to the last timestep. - current_update_mask = new_row != pre_row # True: need to update this frame at this stage - update_mask.append(current_update_mask) - - step_index.append(new_row.clone()) # Stores the "state index" for each block - - # Map state index (0 to N_steps) to actual timestep values. - # new_row values are 0 (initial noise) to N_steps (processed by last timestep). - # If new_row[j] = k: use timesteps_for_matrix[k-1] if k > 0. 
- # If new_row[j] = 0: this block is still pure noise / at initial state, use first timestep for processing. - # The original `step_matrix.append(step_template[new_row])` used a 1-indexed padded template. - # Our `new_row` is 0-indexed for states (0 to num_inference_steps). - # Timestep for state k (1 <= k <= num_inference_steps) is timesteps_for_matrix[k-1]. - # Timestep for state 0 (initial) is timesteps_for_matrix[0]. - # So, for a state `s` in `new_row` (0 to N_steps), the timestep is `timesteps_for_matrix[s.clamp(min=0, max=len(timesteps_for_matrix)-1)]` - # No, simpler: if state is `k`, it means it has undergone `k` steps. The *next* step to apply is `timesteps_for_matrix[k]`. - # So `new_row` (clamped 0 to `len(timesteps_for_matrix)-1`) can directly index `timesteps_for_matrix`. - # This gives the timestep *for the current operation*. - current_timesteps_for_blocks = timesteps_for_matrix[new_row.clamp(0, len(timesteps_for_matrix) - 1)] - step_matrix.append(current_timesteps_for_blocks) - - pre_row = new_row - - # for long video we split into several sequences, base_num_frames is set to the model max length (for training) - terminal_flag = base_num_frames_block # Latent blocks - if shrink_interval_with_mask: # This was not used in original calls we saw - idx_sequence = torch.arange(num_frames_block, dtype=torch.long, device=step_template.device) - if update_mask: # Ensure update_mask is not empty - # Consider the update mask from the first iteration where meaningful updates happen - first_meaningful_update_mask = None - for um in update_mask: - if um.any(): - first_meaningful_update_mask = um - break - if first_meaningful_update_mask is not None: - update_mask_idx = idx_sequence[first_meaningful_update_mask] - if len(update_mask_idx) > 0: - last_update_idx = update_mask_idx[-1].item() - terminal_flag = last_update_idx + 1 - - for curr_mask_row in update_mask: # Iterate over rows of update masks - # Original: if terminal_flag < num_frames_block and curr_mask[terminal_flag]: - # Needs to check if terminal_flag is a valid index for curr_mask_row - if ( - terminal_flag < num_frames_block - and terminal_flag < len(curr_mask_row) - and curr_mask_row[terminal_flag] - ): - terminal_flag += 1 - # Ensure start of interval is not negative - current_interval_start = max(terminal_flag - base_num_frames_block, 0) - valid_interval.append( - (current_interval_start, terminal_flag) # These are in terms of blocks + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." 
+ ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) - if not step_matrix: # Handle case where loop doesn't run (e.g. num_latent_frames is 0) - # This case should ideally be caught earlier. - # Return empty tensors of appropriate shape if possible, or raise error. - # For now, let's assume num_latent_frames > 0. - # If num_frames_block is 0, then pre_row is empty, loop condition is true, returns empty lists. - # This needs robust handling if num_frames_block can be 0. - # Assuming num_frames_block > 0 from here. - if num_frames_block == 0: # If no blocks, then step_matrix etc will be empty - # Return empty tensors, but shapes need to be (0,0) or (0, num_latent_frames) if causal_block_size > 1 - # This edge case means num_latent_frames < causal_block_size - # The original code seems to assume num_latent_frames >= causal_block_size for block logic. - # Let's assume for now this means no processing needed for the matrix. - # The actual latents will be handled by the main loop. - # The matrix generation might not make sense. - # Let's return empty tensors that can be concatenated, or handle this in the caller. - # For now, if step_matrix is empty (e.g. num_frames_block=0), stack will fail. - # If num_frames_block is 0, then num_latent_frames < causal_block_size. - # The matrix logic might not apply. The caller should handle this. - # Or, we make it work for this by bypassing block logic. - # For now, assume num_frames_block > 0. - # The caller will ensure num_latent_frames is appropriate. - # If num_frames_block is 0, the while loop condition is met immediately, - # update_mask, step_index, step_matrix are empty. - # Stacking empty lists will raise an error. - # If step_matrix is empty, it implies no steps defined by matrix. - # This could happen if num_latent_frames_pre_ready covers all frames. - # Or if num_frames_block = 0. - - # If no iterations in while loop (e.g. all pre_row already >= num_iterations-1) - # this can happen if all frames are pre_ready. - # In this case, step_matrix will be empty. - # The caller needs to handle this (e.g., no denoising loop needed). - # For safety, if they are empty, create dummy tensors. 
- if not update_mask: - update_mask.append(torch.zeros(num_frames_block, dtype=torch.bool, device=step_template.device)) - if not step_index: - step_index.append(torch.zeros(num_frames_block, dtype=torch.long, device=step_template.device)) - if not step_matrix: - step_matrix.append( - torch.zeros(num_frames_block, dtype=step_template.dtype, device=step_template.device) - ) - if not valid_interval: - valid_interval.append((0, 0)) - - step_update_mask_stacked = torch.stack(update_mask, dim=0) - step_index_stacked = torch.stack(step_index, dim=0) - step_matrix_stacked = torch.stack(step_matrix, dim=0) - - if causal_block_size > 1: - # Expand from blocks back to latent frames - step_update_mask_stacked = ( - step_update_mask_stacked.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." ) - step_index_stacked = ( - step_index_stacked.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." ) - step_matrix_stacked = ( - step_matrix_stacked.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) - - # Adjust valid_interval from block indices to latent frame indices - valid_interval_frames = [] - for s_block, e_block in valid_interval: - s_frame = s_block * causal_block_size - e_frame = e_block * causal_block_size - # Ensure the end frame does not exceed total latent frames - e_frame = min(e_frame, num_latent_frames) - valid_interval_frames.append((s_frame, e_frame)) - valid_interval = valid_interval_frames - else: # causal_block_size is 1, valid_interval is already in terms of latent frames - valid_interval_frames = [] - for s_idx, e_idx in valid_interval: - valid_interval_frames.append((s_idx, min(e_idx, num_latent_frames))) - valid_interval = valid_interval_frames - - # Ensure all returned tensors cover the full num_latent_frames if causal_block_size expansion happened - # This might be needed if num_latent_frames is not perfectly divisible by causal_block_size - # The original code implies that num_frames is handled by block logic. - # If num_latent_frames = 7, causal_block_size = 4. num_frames_block = 1. - # step_matrix_stacked would be (num_iterations_in_loop, 1, 4) -> (num_iterations_in_loop, 4) - # We need it to be (num_iterations_in_loop, 7). - # This flattening and repeating assumes num_frames_block * causal_block_size = num_latent_frames. - # This is only true if num_latent_frames is a multiple of causal_block_size. - # If not, the original code seems to truncate: `num_frames_block = num_frames // casual_block_size` - # The output matrices will then only cover `num_frames_block * causal_block_size` frames. - # This needs to be clarified or handled. For now, assume it covers up to num_latent_frames or truncates. - # The original code in `__call__` uses `latent_length` (which is num_latent_frames) for schedulers, - # but then `generate_timestep_matrix` is called with this `latent_length`. 
- # The `valid_interval` then slices these. - # It seems the matrix dimensions should align with `num_latent_frames`. - - # If causal_block_size > 1 and num_latent_frames is not a multiple: - if causal_block_size > 1 and step_matrix_stacked.shape[1] < num_latent_frames: - padding_size = num_latent_frames - step_matrix_stacked.shape[1] - # Pad with the values from the last valid frame/block - step_update_mask_stacked = torch.cat( - [step_update_mask_stacked, step_update_mask_stacked[:, -1:].repeat(1, padding_size)], dim=1 + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + def prepare_latents( + self, + image: PipelineImageInput, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + last_image: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) - step_index_stacked = torch.cat( - [step_index_stacked, step_index_stacked[:, -1:].repeat(1, padding_size)], dim=1 + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + image = image.unsqueeze(2) + if last_image is None: + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 ) - step_matrix_stacked = torch.cat( - [step_matrix_stacked, step_matrix_stacked[:, -1:].repeat(1, padding_size)], dim=1 + else: + last_image = last_image.unsqueeze(2) + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], + dim=2, ) + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) - return step_matrix_stacked, step_index_stacked, step_update_mask_stacked, valid_interval + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + ] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + latent_condition = latent_condition.to(dtype) + latent_condition = (latent_condition - latents_mean) * latents_std + + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + + if last_image is None: + mask_lat_size[:, :, list(range(1, num_frames))] = 0 + else: + mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(latent_condition.device) + + return latents, torch.concat([mask_lat_size, latent_condition], dim=1) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]] = None, - conditioning_frames: Optional[Union[List[PIL.Image.Image], torch.Tensor]] = None, - conditioning_frame_mask: Optional[List[int]] = None, + prompt: Union[str, List[str]], + negative_prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + height: int = 480, + width: int = 832, num_frames: int = 97, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 30, - guidance_scale: float = 6.0, - negative_prompt: 
Optional[Union[str, List[str]]] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - max_sequence_length: Optional[int] = None, + image_embeds: Optional[torch.Tensor] = None, + last_image: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - fps: int = 24, # Add missing fps parameter - shift: Optional[float] = 8.0, - # New parameters for SkyReels-V2 original-style forcing and long video - base_num_frames: Optional[int] = None, # Max frames processed in one segment by transformer (pixel space) - ar_step: int = 5, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, overlap_history: Optional[int] = None, + shift: float = 1.0, addnoise_condition: float = 0.0, - ) -> Union[SkyReelsV2PipelineOutput, Tuple]: + base_num_frames: int = 97, + ar_step: int = 5, + causal_block_size: Optional[int] = None, + fps: int = 24, + ): r""" - Generate video frames conditioned on text prompts and optionally on specific input frames (diffusion forcing). + The call function to the pipeline for generation. Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide video generation. If not defined, prompt_embeds must be. - conditioning_frames (`List[PIL.Image.Image]` or `torch.Tensor`, *optional*): - Frames to condition on. Must be provided if `conditioning_frame_mask` is provided. If a list, should - contain PIL Images. If a Tensor, assumes shape compatible with VAE input after batching. - conditioning_frame_mask (`List[int]`, *optional*): - A list of 0s and 1s with length `num_frames`. 1 indicates a conditioning frame, 0 indicates a frame to - generate. - num_frames (`int`, *optional*, defaults to 97): - The total number of frames to generate in the video sequence. - height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`): - The height in pixels of the generated video. - width (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`): - The width in pixels of the generated video. - num_inference_steps (`int`, *optional*, defaults to 30): - The number of denoising steps. - guidance_scale (`float`, *optional*, defaults to 6.0): - Guidance scale for classifier-free guidance. Enabled when > 1. + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. negative_prompt (`str` or `List[str]`, *optional*): - Negative prompts for CFG. + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, defaults to `480`): + The height of the generated video. 
+ width (`int`, defaults to `832`): + The width of the generated video. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. num_videos_per_prompt (`int`, *optional*, defaults to 1): - Number of videos to generate per prompt. + The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - PyTorch Generator object(s) for deterministic generation. + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. latents (`torch.Tensor`, *optional*): - Pre-generated initial latents (noise). If provided, shape should match expected latent shape. + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. - max_sequence_length (`int`, *optional*): - Maximum sequence length for tokenizer. Defaults to model max length (e.g., 77). + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `negative_prompt` input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, + image embeddings are generated from the `image` input argument. output_type (`str`, *optional*, defaults to `"np"`): - Output format: `"tensor"` (torch.Tensor) or `"np"` (list of np.ndarray). + The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether to return `SkyReelsV2PipelineOutput` or a tuple. - callback (`Callable`, *optional*): - Callback function called every `callback_steps` steps. - callback_steps (`int`, *optional*, defaults to 1): - Frequency of callback calls. - cross_attention_kwargs (`dict`, *optional*): - Keyword arguments passed to the attention processor. - fps (`int`, *optional*, defaults to 24): - Target frames per second for the video, passed to the transformer if supported. - shift (`float`, *optional*, defaults to 8.0): - Shift parameter for the `FlowUniPCMultistepScheduler` (if used as main scheduler). - base_num_frames (`int`, *optional*): - Maximum number of frames the transformer processes in a single segment when using original-style long - video generation. 
If None or if `num_frames` is less than this, processes `num_frames`. Corresponds to - `base_num_frames` in original SkyReels-V2 (pixel space). - ar_step (`int`, *optional*, defaults to 5): - Autoregressive step size used in `_generate_timestep_matrix` for scheduling timesteps across frames - within a segment. - overlap_history (`int`, *optional*): - Number of frames to overlap between segments for long video generation. If None, long video generation - with overlap is disabled. Uses pixel frame count. - addnoise_condition (`float`, *optional*, defaults to 0.0): - Controls the amount of noise added to conditioned latents (prefix or user-provided) during the - denoising loop when using original-style forcing. A value > 0 enables it. Corresponds to - `addnoise_condition` in SkyReels-V2. + Whether or not to return a [`SkyReelsV2DiffusionForcingPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `512`): + The maximum sequence length of the prompt. + shift (`float`, *optional*, defaults to `5.0`): + The shift of the flow. + autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): + The dtype to use for the torch.amp.autocast. + Examples: Returns: - [`~pipelines.skyreels_v2.pipeline_skyreels_v2_text_to_video.SkyReelsV2PipelineOutput`] or `tuple`. + [`~SkyReelsV2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ - # 0. Require height and width & VAE spatial scale factor - if height is None or width is None: - raise ValueError("Please provide `height` and `width` for video generation.") - vae_spatial_scale_factor = self.vae_scale_factor - height = height - height % vae_spatial_scale_factor - width = width - width % vae_spatial_scale_factor - if height == 0 or width == 0: - raise ValueError( - f"Provided height {height} and width {width} are too small. Must be divisible by {vae_spatial_scale_factor}." 
- ) - # Determine VAE temporal downsample factor - if hasattr(self.vae.config, "temporal_downsample") and self.vae.config.temporal_downsample is not None: - num_true_temporal_downsamples = sum(1 for td in self.vae.config.temporal_downsample if td) - vae_temporal_scale_factor = 2**num_true_temporal_downsamples - elif ( - hasattr(self.vae.config, "temperal_downsample") and self.vae.config.temperal_downsample is not None - ): # Typo in some old configs - num_true_temporal_downsamples = sum(1 for td in self.vae.config.temperal_downsample if td) - vae_temporal_scale_factor = 2**num_true_temporal_downsamples - logger.warning("VAE config has misspelled 'temperal_downsample'. Using it.") - else: - vae_temporal_scale_factor = 4 # Default if not specified - logger.warning( - f"VAE config does not specify 'temporal_downsample'. Using default temporal_downsample_factor={vae_temporal_scale_factor}." - ) - def to_latent_frames(pixel_frames): - if pixel_frames is None or pixel_frames <= 0: - return 0 - return (pixel_frames - 1) // vae_temporal_scale_factor + 1 + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - num_latent_frames_total = to_latent_frames(num_frames) - if num_latent_frames_total <= 0: - raise ValueError( - f"num_frames {num_frames} results in {num_latent_frames_total} latent frames. Must be > 0." + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + callback_on_step_end_tensor_inputs, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) - # Determine causal_block_size for _generate_timestep_matrix from transformer config or default - causal_block_size = getattr(self.transformer.config, "causal_block_size", 1) - if not isinstance(causal_block_size, int) or causal_block_size <= 0: - causal_block_size = 1 + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False - # 1. Check inputs - self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds - ) - self.check_conditioning_inputs(conditioning_frames, conditioning_frame_mask, num_frames) # Will need review - has_initial_conditioning = conditioning_frames is not None + device = self._execution_device # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) - else: # prompt_embeds must be provided - batch_size = prompt_embeds.shape[0] // num_videos_per_prompt # Correct batch_size from prompt_embeds - - device = self._execution_device - do_classifier_free_guidance = guidance_scale > 1.0 + else: + batch_size = prompt_embeds.shape[0] # 3. 
Encode input prompt - prompt_embeds = self.encode_prompt( - prompt, - device, - num_videos_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, # Pass through pre-generated embeds - negative_prompt_embeds=negative_prompt_embeds, # Pass through + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, max_sequence_length=max_sequence_length, + device=device, ) - effective_batch_size = batch_size * num_videos_per_prompt - # 4. Prepare scheduler and timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) - timesteps = self.scheduler.timesteps + # Encode image embedding + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) - # 5. Prepare initial conditioning information from user-provided frames for the ENTIRE video duration. - # This section prepares data structures that will be used by both short and long video paths - # to incorporate user-specified conditioning frames using the `addnoise_condition` logic. - - initial_clean_conditioning_latents = None - # Stores VAE encoded clean latents from user `conditioning_frames` at their respective - # positions in the full video timeline. Shape: (eff_batch_size, C, num_latent_frames_total, H_latent, W_latent) - - initial_conditioning_latent_mask = torch.zeros(num_latent_frames_total, dtype=torch.bool, device=device) - # Boolean mask indicating which *latent frames* along the total video duration are directly conditioned by user input. - # True if a latent frame `i` has user-provided conditioning data in `initial_clean_conditioning_latents`. - - num_latent_frames_pre_ready_from_user = 0 - # This specific variable counts how many *contiguous latent frames from the very beginning* of the video - # are to be considered "pre-ready" or "frozen" for the `_generate_timestep_matrix` function. - # For typical diffusion forcing where specific frames are conditioned (not necessarily a prefix), - # this will often be 0. True video prefixes would set this. - # All other user-specified conditioning (sparse, non-prefix) will be handled by `addnoise_condition` logic - # guided by `initial_conditioning_latent_mask` within the denoising loops. - - if has_initial_conditioning: - if conditioning_frame_mask is None: - raise ValueError("If conditioning_frames are provided, conditioning_frame_mask must also be provided.") - if len(conditioning_frame_mask) != num_frames: - raise ValueError( - f"conditioning_frame_mask length ({len(conditioning_frame_mask)}) must equal num_frames ({num_frames})." - ) + if image_embeds is None: + if last_image is None: + image_embeds = self.encode_image(image, device) + else: + image_embeds = self.encode_image([image, last_image], device) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) - # Encode the user-provided frames. self.encode_frames is expected to return appropriately batched latents. 
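Note: a minimal sketch of the frame-count snapping done near the top of the new `__call__` above (the `num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1` line), assuming a temporal VAE factor of 4; the helper name is illustrative only, not part of the pipeline:

```py
def round_num_frames(num_frames: int, vae_scale_factor_temporal: int = 4) -> int:
    # Snap `num_frames` to the form k * factor + 1 so that `num_frames - 1`
    # is divisible by the temporal VAE factor, mirroring the pipeline's arithmetic.
    num_frames = num_frames // vae_scale_factor_temporal * vae_scale_factor_temporal + 1
    return max(num_frames, 1)


assert round_num_frames(81) == 81    # already of the form 4k + 1
assert round_num_frames(83) == 81    # snapped to the valid count below
assert round_num_frames(100) == 101  # snapped to the valid count above
```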
- # Assuming self.encode_frames returns: (effective_batch_size, C, N_sparse_encoded_latents, Hl, Wl) - # where N_sparse_encoded_latents matches the number of 1s in conditioning_frame_mask (potentially after VAE temporal compression). - sparse_user_latents = self.encode_frames(conditioning_frames) - - if sparse_user_latents.shape[0] != effective_batch_size: - if sparse_user_latents.shape[0] == 1 and effective_batch_size > 1: - sparse_user_latents = sparse_user_latents.repeat(effective_batch_size, 1, 1, 1, 1) - else: - raise ValueError( - f"Batch size mismatch: encoded conditioning frames have batch {sparse_user_latents.shape[0]}, expected {effective_batch_size}." - ) - - latent_channels_for_cond = self.vae.config.latent_channels # Should match sparse_user_latents.shape[1] - latent_height_for_cond = sparse_user_latents.shape[-2] - latent_width_for_cond = sparse_user_latents.shape[-1] - - initial_clean_conditioning_latents = torch.zeros( - effective_batch_size, - latent_channels_for_cond, - num_latent_frames_total, - latent_height_for_cond, - latent_width_for_cond, - dtype=sparse_user_latents.dtype, - device=device, + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + if last_image is not None: + last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( + device, dtype=torch.float32 ) + latents, condition = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + last_image, + ) - processed_sparse_count = 0 - # Map the 1s in pixel-space `conditioning_frame_mask` to latent frame indices - # and place the corresponding `sparse_user_latents`. - for pixel_idx, is_cond_pixel in enumerate(conditioning_frame_mask): - if is_cond_pixel == 1: - if processed_sparse_count >= sparse_user_latents.shape[2]: - logger.warning( - f"More 1s in conditioning_frame_mask than available encoded conditioning frames ({sparse_user_latents.shape[2]})." - ) - break # Stop if we've run out of provided conditioning latents - - # Determine the target latent frame index for this pixel-space conditioned frame - # to_latent_frames expects 1-indexed pixel frame, returns 1-indexed latent frame count up to that point. - # So, for a pixel_idx (0-indexed), its corresponding latent frame index (0-indexed) is to_latent_frames(pixel_idx + 1) - 1. - target_latent_idx = to_latent_frames(pixel_idx + 1) - 1 - - if 0 <= target_latent_idx < num_latent_frames_total: - initial_clean_conditioning_latents[:, :, target_latent_idx, :, :] = sparse_user_latents[ - :, :, processed_sparse_count, :, : - ] - initial_conditioning_latent_mask[target_latent_idx] = True - processed_sparse_count += 1 - else: - logger.warning( - f"Pixel frame {pixel_idx} maps to latent index {target_latent_idx} out of bounds [0, {num_latent_frames_total - 1}]. Skipping." - ) - - if processed_sparse_count < sparse_user_latents.shape[2]: - logger.warning( - f"Only used {processed_sparse_count} out of {sparse_user_latents.shape[2]} provided conditioning latents. " - "Ensure conditioning_frame_mask aligns with the number of conditioning_frames provided and video length." - ) + # 6. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) - # For num_latent_frames_pre_ready_from_user: count contiguous conditioned frames from latent start. - # This is specifically for _generate_timestep_matrix's `num_pre_ready` which expects a prefix. - # Other sparse conditioning is handled by `addnoise_condition` using the mask and clean latents. - current_pre_ready_count = 0 - for i in range(num_latent_frames_total): - if initial_conditioning_latent_mask[i]: - current_pre_ready_count += 1 - else: - break - num_latent_frames_pre_ready_from_user = current_pre_ready_count - if num_latent_frames_pre_ready_from_user > 0: - logger.info( - f"{num_latent_frames_pre_ready_from_user} latent frames from the start are user-conditioned and will be treated as pre-ready." - ) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue - # Latent Dims - num_channels_latents = self.vae.config.latent_channels - latent_height = height // vae_spatial_scale_factor - latent_width = width // vae_spatial_scale_factor - - # Determine if using long video path - use_long_video_path = False - base_latent_frames_seg = num_latent_frames_total # Default to full length if not long video mode - overlap_latent_frames_seg = 0 - - if overlap_history is not None and base_num_frames is not None and num_frames > base_num_frames: - if base_num_frames <= 0: - raise ValueError("base_num_frames must be positive.") - if overlap_history < 0: - raise ValueError("overlap_history must be non-negative.") - if overlap_history >= base_num_frames: - raise ValueError("overlap_history must be < base_num_frames.") - - # Check if long video generation is actually needed after converting to latent frames - base_latent_frames_seg = to_latent_frames(base_num_frames) - overlap_latent_frames_seg = to_latent_frames(overlap_history) - - if base_latent_frames_seg <= 0: - base_latent_frames_seg = 1 # Ensure minimum segment length - # Overlap can be 0 in latent space even if > 0 in pixel space, handle this - if overlap_latent_frames_seg < 0: - overlap_latent_frames_seg = 0 # Should not happen with to_latent_frames but safety - - if num_latent_frames_total > base_latent_frames_seg: - use_long_video_path = True - if overlap_latent_frames_seg >= base_latent_frames_seg: - logger.warning( - f"Calculated overlap_latent_frames ({overlap_latent_frames_seg}) >= base_latent_frames_seg ({base_latent_frames_seg}). Disabling overlap for long video." 
- ) - overlap_latent_frames_seg = 0 - - # Prepare initial latents for the full video (initial noise or user provided latents) - if latents is None: - shape = (effective_batch_size, num_channels_latents, num_latent_frames_total, latent_height, latent_width) - full_video_latents = randn_tensor(shape, generator=generator, device=device, dtype=prompt_embeds.dtype) - full_video_latents = full_video_latents * self.scheduler.init_noise_sigma - else: - expected_shape = ( - effective_batch_size, - num_channels_latents, - num_latent_frames_total, - latent_height, - latent_width, - ) - if latents.shape != expected_shape: - raise ValueError(f"Provided latents shape {latents.shape} does not match expected {expected_shape}.") - full_video_latents = latents.to(device, dtype=prompt_embeds.dtype) - - # Helper method for denoising a single segment - def _denoise_segment( - self, - segment_latents: torch.Tensor, # Latents for the current segment (slice of full_video_latents) - segment_start_global_idx: int, # Start index of this segment in the total video - num_latent_frames_this_segment: int, # Number of latent frames in this segment - num_pre_ready_for_this_segment: int, # Number of contiguous pre-ready frames at segment start - total_num_latent_frames: int, # Total latent frames in the whole video - initial_clean_conditioning_latents: Optional[torch.Tensor], # Clean conditioning for the whole video - initial_conditioning_latent_mask: Optional[torch.Tensor], # Mask for conditioned frames in the whole video - addnoise_condition: float, - timesteps: torch.Tensor, - prompt_embeds: torch.Tensor, - guidance_scale: float, - do_classifier_free_guidance: bool, - cross_attention_kwargs: Optional[Dict[str, Any]], - causal_block_size: int, - generator: Optional[Union[torch.Generator, List[torch.Generator]]], - progress_bar, - callback: Optional[Callable[[int, int, torch.Tensor], None]], - callback_steps: int, - fps: Optional[int] = None, # Add fps parameter - # Optional: segment index for logging - segment_index: int = 0, - num_segments: int = 1, - ) -> torch.Tensor: - # This method encapsulates the denoising loop logic previously in the short video path - # It will denoise `segment_latents` in place or return updated latents. - - # Generate the timestep matrix for this segment - step_matrix, step_index_matrix, update_mask_matrix, valid_interval_list = self._generate_timestep_matrix( - num_latent_frames=num_latent_frames_this_segment, - step_template=timesteps, - base_latent_frames=num_latent_frames_this_segment, # Base is segment length for matrix calc - ar_step=ar_step, - num_latent_frames_pre_ready=num_pre_ready_for_this_segment, - causal_block_size=causal_block_size, - ) + self._current_timestep = t + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) - if ( - not step_matrix.numel() - and num_latent_frames_this_segment > 0 - and num_pre_ready_for_this_segment < num_latent_frames_this_segment - ): - # Check if step_matrix is empty but should not be (i.e. not all frames are pre-ready) - logger.warning( - f"Segment {segment_index + 1}/{num_segments}: _generate_timestep_matrix returned empty." - ) - # If no steps, latents remain as is. 
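Note: the removed `_denoise_segment` helper steps each latent frame with its own timestep and only where the update mask produced by `_generate_timestep_matrix` is set. A simplified, hypothetical sketch of that per-frame update pattern (the function name and shapes are illustrative, and the matrix generation itself is omitted):

```py
import torch


def step_masked_frames(scheduler, latents, noise_pred, frame_timesteps, update_mask):
    # latents, noise_pred: (B, C, F, H, W); frame_timesteps, update_mask: (F,)
    # Only frames flagged in `update_mask` are advanced; pre-ready / frozen
    # frames (e.g. an overlap prefix) keep their current latents.
    for f in range(latents.shape[2]):
        if update_mask[f]:
            latents[:, :, f] = scheduler.step(
                noise_pred[:, :, f], frame_timesteps[f], latents[:, :, f], return_dict=False
            )[0]
    return latents
```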
- return segment_latents # Return unchanged if no denoising steps generated - - # Denoising loop for the current segment - # The progress bar total is managed by the main loop (either short path or long video outer loop) - for i_matrix_step in range(len(step_matrix)): - current_timesteps_for_frames = step_matrix[i_matrix_step] # Timestamps for each frame in the segment - current_update_mask_for_frames = update_mask_matrix[ - i_matrix_step - ] # Update mask for each frame in the segment - valid_interval_start_local, valid_interval_end_local = valid_interval_list[ - i_matrix_step - ] # Local indices within segment - - # Slice segment latents for the current valid processing window - latent_model_input = segment_latents[ - :, :, valid_interval_start_local:valid_interval_end_local - ].clone() # Clone for modification - # Timesteps for the transformer input - corresponds to the sliced latents - timestep_tensor_for_transformer = current_timesteps_for_frames[ - valid_interval_start_local:valid_interval_end_local - ] - - # === Implement addnoise_condition logic === - if addnoise_condition > 0.0 and initial_clean_conditioning_latents is not None: - # Iterate over frames within the current valid interval slice (local index j_local) - for j_local in range(valid_interval_end_local - valid_interval_start_local): - # Map local segment index to global video index - j_global = segment_start_global_idx + valid_interval_start_local + j_local - - # Check if this global frame is user-conditioned AND NOT considered pre-ready/frozen by _generate_timestep_matrix. - # _generate_timestep_matrix should handle num_pre_ready by setting update_mask=False for those frames. - # So we apply addnoise to frames marked as conditioned (globally) AND are being updated in this matrix step. - if ( - j_global < total_num_latent_frames - and initial_conditioning_latent_mask[j_global] - and current_update_mask_for_frames[valid_interval_start_local + j_local] - ): - # This is a conditioned frame that's being processed, apply addnoise logic - # Get the clean conditioned frame from the global conditioning tensor - clean_cond_frame = initial_clean_conditioning_latents[:, :, j_global, :, :] - # Original code used 0.001 * addnoise_condition for noise factor. Let's use param directly. - noise_factor = addnoise_condition - - # Add noise to the clean conditioned frame - noise = randn_tensor( - clean_cond_frame.shape, - generator=generator, - device=device, - dtype=clean_cond_frame.dtype, - ) - noised_cond_frame = clean_cond_frame * (1.0 - noise_factor) + noise * noise_factor - - # Replace the noisy latent in the model input slice with the noised conditioned frame - latent_model_input[:, :, j_local] = noised_cond_frame - - # Original code also clamped the timestep for conditioned frames. - # Let's clamp the specific frame's timestep in the transformer input tensor. - # Use addnoise_condition value as the clamping threshold. - if addnoise_condition > 0: # Avoid min with 0 if addnoise_condition is 0 - # Ensure addnoise_condition is treated as a valid timestep index or value. - # Assuming addnoise_condition is intended as a timestep value or something to clamp against. 
- # clamped_timestep_value = torch.tensor(addnoise_condition, device=device, dtype=timestep_tensor_for_transformer.dtype) - # Original clamped: `timestep_tensor_for_transformer[j_local] = torch.min(timestep_tensor_for_transformer[j_local], clamped_timestep_value)` - # Let's remove timestep clamping based on `addnoise_condition` for now, as it's less standard in Diffusers schedulers - # and its exact effect in original is tied to their scheduler step logic. - # addnoise_condition will primarily function as a noise amount control. - pass # Clamping logic removed - - # === End of addnoise_condition logic === - - # Model input for transformer (potentially with CFG duplication) - model_input_for_transformer = latent_model_input - if do_classifier_free_guidance: - model_input_for_transformer = torch.cat([model_input_for_transformer] * 2) - - # Timesteps for transformer (duplicated for CFG) - if do_classifier_free_guidance: - timestep_tensor_for_transformer_cfg = torch.cat([timestep_tensor_for_transformer] * 2) - else: - timestep_tensor_for_transformer_cfg = timestep_tensor_for_transformer - - # Transformer forward pass - model_pred = self.transformer( - model_input_for_transformer, - timestep=timestep_tensor_for_transformer_cfg, # Use per-frame timesteps + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - # Pass fps_embeds if fps is provided - fps=torch.tensor([fps] * model_input_for_transformer.shape[0], device=self.device) - if fps is not None - else None, - ).sample - - # CFG guidance - if do_classifier_free_guidance: - model_pred_uncond, model_pred_text = model_pred.chunk(2) - model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond) - - # Scheduler step per frame if updated - # Iterate over the frames in the current valid interval (local index idx_local_in_segment) - for idx_local_in_segment in range(model_pred.shape[2]): - # Global frame index corresponding to this local segment index - # g_idx = segment_start_global_idx + valid_interval_start_local + idx_local_in_segment # Not directly used in indexing below - - # Check the update mask *for the corresponding frame in the current segment* to see if it should be updated - # update_mask_matrix is (num_matrix_steps, num_latent_frames_this_segment) - # The update mask for the current frame is at `current_update_mask_for_frames[valid_interval_start_local + idx_local_in_segment]` - # which is equivalent to `current_update_mask_for_frames[idx_local_in_segment]` within the sliced valid interval - if current_update_mask_for_frames[valid_interval_start_local + idx_local_in_segment]: - frame_pred = model_pred[:, :, idx_local_in_segment] # Prediction for this frame (local index) - frame_latent = segment_latents[ - :, :, valid_interval_start_local + idx_local_in_segment - ] # Current latent for this frame (local in segment) - frame_timestep = current_timesteps_for_frames[ - valid_interval_start_local + idx_local_in_segment - ] # Timestep for this frame from matrix - - # Apply scheduler step for this single frame's latent - # Update the corresponding frame directly in the segment_latents tensor - segment_latents[:, :, valid_interval_start_local + idx_local_in_segment] = self.scheduler.step( - frame_pred, - frame_timestep, - frame_latent, - return_dict=False, - generator=generator, # Pass generator - )[0] - - # Progress bar update - handled by the outer loop caller - - # Callback - handled by the 
outer loop caller - - return segment_latents # Return the denoised segment latents - - # Main generation loop(s) - if not use_long_video_path: - logger.info(f"Short video path: {num_latent_frames_total} latent frames.") - # Denoise the full video (single segment) - denoised_latents = _denoise_segment( - self, - segment_latents=full_video_latents, # Denoise the whole video - segment_start_global_idx=0, # Starts at the beginning of the video - num_latent_frames_this_segment=num_latent_frames_total, # Segment is the whole video - num_pre_ready_for_this_segment=num_latent_frames_pre_ready_from_user, # Use user-provided pre-ready count - total_num_latent_frames=num_latent_frames_total, - initial_clean_conditioning_latents=initial_clean_conditioning_latents, - initial_conditioning_latent_mask=initial_conditioning_latent_mask, - addnoise_condition=addnoise_condition, - timesteps=timesteps, - prompt_embeds=prompt_embeds, - guidance_scale=guidance_scale, - do_classifier_free_guidance=do_classifier_free_guidance, - cross_attention_kwargs=cross_attention_kwargs, - causal_block_size=causal_block_size, - generator=generator, - progress_bar=None, # Progress bar handled below the if block if needed, or manage here - callback=callback, - callback_steps=callback_steps, - segment_index=0, - num_segments=1, # For logging in helper + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) ) - # The progress bar and callback for the short video path need to be managed around the _denoise_segment call - # Or, pass the progress_bar and callback down and manage inside _denoise_segment. - # Let's manage the progress bar and callback *inside* _denoise_segment. - # Need to pass the progress_bar object to _denoise_segment. - # Refactor: pass progress_bar and callback objects to _denoise_segment. 
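Note: in the new denoising loop above, guidance runs the transformer twice (conditional and unconditional) and blends the two predictions before the scheduler step. A minimal sketch of that blending, with hypothetical tensor shapes:

```py
import torch


def apply_cfg(noise_pred_cond: torch.Tensor, noise_pred_uncond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Same combination as in the loop: uncond + w * (cond - uncond).
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)


cond = torch.randn(1, 16, 21, 60, 104)  # hypothetical (B, C, F, H, W) prediction
uncond = torch.randn_like(cond)
guided = apply_cfg(cond, uncond, guidance_scale=5.0)
```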
- - # Rerun _denoise_segment call with progress_bar and callback passed: - logger.info("Short video path: Starting denoising.") - with self.progress_bar(total=len(timesteps)) as progress_bar: - # Need to determine the actual number of matrix steps for progress bar total - # Call _generate_timestep_matrix first to get matrix size - temp_step_matrix, _, _, _ = self._generate_timestep_matrix( - num_latent_frames=num_latent_frames_total, - step_template=timesteps, - base_latent_frames=num_latent_frames_total, # For short path, base is the full length - ar_step=ar_step, - num_latent_frames_pre_ready=num_latent_frames_pre_ready_from_user, - causal_block_size=causal_block_size, - ) - progress_bar.total = ( - len(temp_step_matrix) if len(temp_step_matrix) > 0 else num_inference_steps - ) # Adjusted total - if ( - progress_bar.total == 0 - and num_latent_frames_total > 0 - and num_latent_frames_pre_ready_from_user < num_latent_frames_total - ): - logger.warning( - "Progress bar total is 0 but video needs denoising. Setting total to num_inference_steps." - ) - progress_bar.total = num_inference_steps - - denoised_latents = _denoise_segment( - self, - segment_latents=full_video_latents, # Denoise the whole video - segment_start_global_idx=0, # Starts at the beginning of the video - num_latent_frames_this_segment=num_latent_frames_total, # Segment is the whole video - num_pre_ready_for_this_segment=num_latent_frames_pre_ready_from_user, # Use user-provided pre-ready count - total_num_latent_frames=num_latent_frames_total, - initial_clean_conditioning_latents=initial_clean_conditioning_latents, - initial_conditioning_latent_mask=initial_conditioning_latent_mask, - addnoise_condition=addnoise_condition, - timesteps=timesteps, - prompt_embeds=prompt_embeds, - guidance_scale=guidance_scale, - do_classifier_free_guidance=do_classifier_free_guidance, - cross_attention_kwargs=cross_attention_kwargs, - causal_block_size=causal_block_size, - generator=generator, - progress_bar=progress_bar, # Pass the segment progress bar - callback=callback, # Pass callback - callback_steps=callback_steps, - segment_index=0, - num_segments=1, - ) - - else: # Long video path - Implementation - logger.info( - f"Long video path: {num_latent_frames_total} total latents, {base_latent_frames_seg} base/segment, {overlap_latent_frames_seg} overlap." + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents - # Calculate number of segments - non_overlapping_part_len = base_latent_frames_seg - overlap_latent_frames_seg - if non_overlapping_part_len <= 0: - logger.error("Non-overlapping part of segment is <=0. Adjust base_num_frames or overlap_history.") - raise ValueError("Non-overlapping part of segment must be positive.") - - num_iterations = ( - 1 - + (num_latent_frames_total - base_latent_frames_seg + non_overlapping_part_len - 1) - // non_overlapping_part_len - ) - logger.info(f"Long video: processing in {num_iterations} segments.") - - # Initialize tensor to store the final denoised latents for the whole video - # final_denoised_latents = torch.zeros_like(full_video_latents) # This was unused - # Or accumulate in a list and concatenate later, might be better to avoid pre-allocation issues. - # Let's accumulate in a list for now. 
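Note: before decoding, the new code above un-normalizes the latents with the per-channel statistics from the VAE config; the local variable `latents_std` actually holds the reciprocal of `vae.config.latents_std`, so `latents / latents_std + latents_mean` is the usual `latents * std + mean`. A tiny sketch with hypothetical shapes and statistics:

```py
import torch

z_dim = 16  # hypothetical latent channel count (vae.config.z_dim)
latents = torch.randn(1, z_dim, 21, 60, 104)
latents_mean = torch.zeros(1, z_dim, 1, 1, 1)      # stands in for vae.config.latents_mean
latents_std = 1.0 / torch.ones(1, z_dim, 1, 1, 1)  # reciprocal, as in the pipeline code

latents = latents / latents_std + latents_mean     # == latents * std + mean
```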
- final_denoised_segments_list = [] - - # Keep track of the last part of the previously denoised segment for overlap - previous_segment_denoised_overlap_latents = None # Will be (B, C, overlap_latent_frames_seg, H_l, W_l) - - # Main loop for processing segments - # Total progress bar should cover all matrix steps across all segments. - # This is hard to pre-calculate exactly due to matrix generation per segment. - # Alternative: progress bar tracks segments, and each segment denoising shows internal progress (if tqdm nested allowed). - # Let's track segments in the main progress bar for simplicity first. - - with self.progress_bar( - total=num_iterations, desc="Generating Long Video Segments" - ) as progress_bar_segments: - for i_segment in range(num_iterations): - # Determine the global latent frame indices for the current segment - segment_start_global_latent_idx = i_segment * non_overlapping_part_len - # The end index is start + base_segment_length, clamped to total length - segment_end_global_latent_idx = min( - segment_start_global_latent_idx + base_latent_frames_seg, num_latent_frames_total - ) - - # Adjust start index to include overlap from previous segment (if not the first segment) - current_segment_global_start_with_overlap = segment_start_global_latent_idx - if i_segment > 0: - current_segment_global_start_with_overlap -= overlap_latent_frames_seg - - # Determine the actual number of latent frames in this current segment - num_latent_frames_this_segment = ( - segment_end_global_latent_idx - current_segment_global_start_with_overlap - ) - - logger.info( - f" Processing segment {i_segment + 1}/{num_iterations} (global latent frames {current_segment_global_start_with_overlap} to {segment_end_global_latent_idx - 1})." - ) - - # Prepare latents for the current segment - # Start with the initial noise (or user provided) for this segment's range - current_segment_latents = full_video_latents[ - :, :, current_segment_global_start_with_overlap:segment_end_global_latent_idx - ].clone() - - # If there's overlap from the previous segment, overwrite the initial part with the denoised overlap - num_pre_ready_for_this_segment = ( - num_latent_frames_pre_ready_from_user # Start with user prefix count - ) - if i_segment > 0 and previous_segment_denoised_overlap_latents is not None: - # Overwrite the first `overlap_latent_frames_seg` of current_segment_latents - # with the denoised overlap from the previous segment. - # The number of pre-ready frames for the matrix generation should be the overlap length. - current_segment_latents[:, :, :overlap_latent_frames_seg] = ( - previous_segment_denoised_overlap_latents - ) - num_pre_ready_for_this_segment = ( - overlap_latent_frames_seg # Overlap serves as the frozen prefix for matrix generation - ) - logger.info( - f" Segment includes {overlap_latent_frames_seg} latent frames of overlap from previous segment." - ) - elif i_segment == 0 and num_latent_frames_pre_ready_from_user > 0: - # First segment, use user-provided prefix as pre-ready - num_pre_ready_for_this_segment = num_latent_frames_pre_ready_from_user - - # Denoise the current segment - # Pass the segment_latents to the helper function. - # The helper will operate on this tensor and return the denoised result for the segment. 
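Note: the segmenting arithmetic used by the removed long-video path above (segment count, per-segment window, and the overlap prefix) can be reproduced in a few lines; the numbers below are illustrative latent-frame counts, not defaults:

```py
def segment_windows(total_latent_frames: int, base: int, overlap: int):
    # Mirrors the removed code: segments advance by (base - overlap) latent frames,
    # and every segment after the first starts `overlap` frames early so the
    # previously denoised overlap can serve as a frozen prefix.
    stride = base - overlap
    num_iterations = 1 + (total_latent_frames - base + stride - 1) // stride
    windows = []
    for i in range(num_iterations):
        start = i * stride
        end = min(start + base, total_latent_frames)
        start_with_overlap = start - overlap if i > 0 else start
        windows.append((start_with_overlap, end))
    return windows


print(segment_windows(40, 25, 5))  # [(0, 25), (15, 40)]
```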
- denoised_segment_latents = _denoise_segment( - self, - segment_latents=current_segment_latents, # Latents for this segment - segment_start_global_idx=current_segment_global_start_with_overlap, # Global start index (including overlap) - num_latent_frames_this_segment=num_latent_frames_this_segment, # Length of this segment - num_pre_ready_for_this_segment=num_pre_ready_for_this_segment, # Pre-ready frames for matrix generation - total_num_latent_frames=num_latent_frames_total, - initial_clean_conditioning_latents=initial_clean_conditioning_latents, # Full video conditioning - initial_conditioning_latent_mask=initial_conditioning_latent_mask, # Full video mask - addnoise_condition=addnoise_condition, - timesteps=timesteps, - prompt_embeds=prompt_embeds, - guidance_scale=guidance_scale, - do_classifier_free_guidance=do_classifier_free_guidance, - cross_attention_kwargs=cross_attention_kwargs, - causal_block_size=causal_block_size, - generator=generator, - progress_bar=progress_bar_segments, # Pass the segment progress bar - callback=callback, # Pass callback - callback_steps=callback_steps, - fps=fps, # Pass fps from __call__ scope - segment_index=i_segment, - num_segments=num_iterations, - ) - - # Extract the non-overlapping part of the denoised segment - # For the first segment, this is from the start up to base_latent_frames_seg. - # For subsequent segments, this is from overlap_latent_frames_seg onwards. - non_overlapping_segment_start_local_idx = 0 if i_segment == 0 else overlap_latent_frames_seg - non_overlapping_segment_latents = denoised_segment_latents[ - :, :, non_overlapping_segment_start_local_idx: - ].clone() - - # Add the non-overlapping part to the final list - final_denoised_segments_list.append(non_overlapping_segment_latents) - - # Prepare overlap for the next segment (if not the last segment) - if i_segment < num_iterations - 1: - # The overlap is the last `overlap_latent_frames_seg` of the *denoised* current segment. - overlap_start_local_idx = num_latent_frames_this_segment - overlap_latent_frames_seg - overlap_latents_to_process = denoised_segment_latents[:, :, overlap_start_local_idx:].clone() - - # Implementing original SkyReels V2 overlap handling (decode and re-encode): - # 1. Decode overlap latents to pixel space - # VAE decode expects (B, C, F, H, W). overlap_latents_to_process is (B, C, F_overlap, H_l, W_l) - decoded_overlap_pixels = self.vae.decode(overlap_latents_to_process).sample - # decoded_overlap_pixels is (B, F_overlap, C, H, W) after vae.decode (check vae_outputs) - - # 2. Re-encode pixel frames back to latent space - # VAE encode expects (B, C, F, H, W), so permute decoded pixels - # decoded_overlap_pixels needs permuting from (B, F_overlap, C, H, W) to (B, C, F_overlap, H, W) - encoded_overlap_latents = self.vae.encode( - decoded_overlap_pixels.permute(0, 2, 1, 3, 4) # Permute to (B, C, F, H, W) - ).latent_dist.sample() - - # Apply VAE scaling factor after encoding - previous_segment_denoised_overlap_latents = ( - encoded_overlap_latents * self.vae.config.scaling_factor - ) - - # Update segment progress bar - progress_bar_segments.update(1) # Update by 1 for each completed segment - - # Concatenate all denoised segments to get the full video latents - denoised_latents = torch.cat( - final_denoised_segments_list, dim=2 - ) # Concatenate along the frame dimension (dim=2) - - # 7. 
Post-processing (decode final denoised_latents) - video_tensor = self._decode_latents(denoised_latents) - video = self.video_processor.postprocess_video(video_tensor, output_type=output_type) - + # Offload all models self.maybe_free_model_hooks() if not return_dict: return (video,) + return SkyReelsV2PipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py index f804d43e6928..11e2aea8a5d7 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py @@ -1,4 +1,4 @@ -# Copyright 2024 The SkyReels-V2 Authors and The HuggingFace Team. All rights reserved. +# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,363 +12,381 @@ # See the License for the specific language governing permissions and # limitations under the License. +import html from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import PIL.Image +import PIL +import regex as re import torch -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel -from ...image_processor import VideoProcessor +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import WanLoraLoaderMixin from ...models import AutoencoderKLWan, WanTransformer3DModel from ...schedulers import FlowUniPCMultistepScheduler -from ...utils import ( - logging, - replace_example_docstring, -) +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline -from .pipeline_skyreels_v2_text_to_video import SkyReelsV2PipelineOutput +from .pipeline_output import SkyReelsV2PipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name +if is_ftfy_available(): + import ftfy EXAMPLE_DOC_STRING = """ Examples: - ```py + ```python >>> import torch - >>> import PIL.Image - >>> from diffusers import SkyReelsV2ImageToVideoPipeline - >>> from diffusers.utils import load_image, export_to_video - - >>> # Load the pipeline + >>> import numpy as np + >>> from diffusers import AutoencoderKLWan, SkyReelsV2ImageToVideoPipeline + >>> from diffusers.utils import export_to_video, load_image + >>> from transformers import CLIPVisionModel + + >>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers + >>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers" + >>> image_encoder = CLIPVisionModel.from_pretrained( + ... model_id, subfolder="image_encoder", torch_dtype=torch.float32 + ... ) + >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) >>> pipe = SkyReelsV2ImageToVideoPipeline.from_pretrained( - ... "HF_placeholder/SkyReels-V2-I2V-14B-540P", torch_dtype=torch.float16 + ... 
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16 ... ) - >>> pipe = pipe.to("cuda") - - >>> # Load the conditioning image - >>> image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png" # Example image - >>> image = load_image(image_url) + >>> pipe.to("cuda") - >>> prompt = "A cat running across the grass" - >>> video_frames = pipe(prompt=prompt, image=image, num_frames=97).frames - >>> export_to_video(video_frames, "skyreels_v2_i2v.mp4") + >>> image = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ... ) + >>> max_area = 480 * 832 + >>> aspect_ratio = image.height / image.width + >>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + >>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + >>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + >>> image = image.resize((width, height)) + >>> prompt = ( + ... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in " + ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." + ... ) + >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + + >>> output = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=height, + ... width=width, + ... num_frames=81, + ... guidance_scale=5.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=16) ``` """ -class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline): - """ - Pipeline for image-to-video generation using SkyReels-V2. +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods - implemented for all pipelines (downloading, saving, running on a specific device, etc.). - The pipeline is based on the Wan 2.1 architecture (WanTransformer3DModel, AutoencoderKLWan). It uses a - `CLIPVisionModelWithProjection` to encode the conditioning image. It expects checkpoints saved in the standard - diffusers format, typically including subfolders: `vae`, `text_encoder`, `tokenizer`, `image_encoder`, - `transformer`, `scheduler`. 
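Note: a quick numeric check of the resolution rounding in the new example docstring above, assuming a spatial VAE factor of 8 and a transformer patch size of 2 (so `mod_value = 16`); the input image size is hypothetical:

```python
import numpy as np

max_area = 480 * 832
aspect_ratio = 1024 / 768  # hypothetical 768 x 1024 (width x height) input image
mod_value = 8 * 2          # vae_scale_factor_spatial * transformer patch size

height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
print(height, width)  # 720 544 -> both divisible by 16, area stays within 480 * 832
```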
+def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): + r""" + Pipeline for image-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). Args: - vae ([`AutoencoderKLWan`]): - Variational Auto-Encoder (VAE) model capable of encoding and decoding videos in latent space. - text_encoder ([`~transformers.CLIPTextModel`]): - Frozen text-encoder (e.g., CLIP). - tokenizer ([`~transformers.CLIPTokenizer`]): - Tokenizer corresponding to the `text_encoder`. - image_encoder ([`~transformers.CLIPVisionModelWithProjection`]): - Frozen image encoder (e.g., CLIP Vision Model) to encode the conditioning image. - image_processor ([`~transformers.CLIPImageProcessor`]): - Image processor corresponding to the `image_encoder`. + tokenizer ([`T5Tokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + image_encoder ([`CLIPVisionModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically + the + [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) + variant. transformer ([`WanTransformer3DModel`]): - The core diffusion transformer model that denoises latents based on text and image conditioning. + Conditional Transformer to denoise the input latents. scheduler ([`FlowUniPCMultistepScheduler`]): - A scheduler compatible with the Flow Matching framework used by SkyReels-V2. - video_processor ([`VideoProcessor`]): - Processor for converting VAE output latents to standard video formats (np, tensor, pil). + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
""" model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "image_embeds"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, - vae: AutoencoderKLWan, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - image_encoder: CLIPVisionModelWithProjection, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + image_encoder: CLIPVisionModel, image_processor: CLIPImageProcessor, transformer: WanTransformer3DModel, + vae: AutoencoderKLWan, scheduler: FlowUniPCMultistepScheduler, - video_processor: VideoProcessor, ): super().__init__() + self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, image_encoder=image_encoder, - image_processor=image_processor, transformer=transformer, scheduler=scheduler, - video_processor=video_processor, + image_processor=image_processor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - # VaeImageProcessor is not needed here as CLIPImageProcessor handles image preprocessing. - - def _encode_image( - self, - image: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]], - device: torch.device, - num_videos_per_prompt: int, - do_classifier_free_guidance: bool, - dtype: torch.dtype, - ) -> torch.Tensor: - """ - Encodes the input image using the image encoder. - - Args: - image (`torch.Tensor`, `PIL.Image.Image`, `List[PIL.Image.Image]`): - Image or batch of images to encode. - device (`torch.device`): Target device. - num_videos_per_prompt (`int`): Number of videos per prompt (for repeating embeddings). - do_classifier_free_guidance (`bool`): Whether to generate negative embeddings. - dtype (`torch.dtype`): Target data type for embeddings. - - Returns: - `torch.Tensor`: Encoded image embeddings. - """ - if isinstance(image, PIL.Image.Image): - image = [image] # Processor expects a list - - # Preprocess image(s) - image_pixels = self.image_processor(image, return_tensors="pt").pixel_values - image_pixels = image_pixels.to(device=device, dtype=dtype) - - # Get image embeddings - image_embeds = self.image_encoder(image_pixels).image_embeds # [batch_size, seq_len, embed_dim] - # Duplicate image embeddings for each generation per prompt - image_embeds = image_embeds.repeat_interleave(num_videos_per_prompt, dim=0) + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.image_processor = image_processor - # Get negative embeddings for CFG - if do_classifier_free_guidance: - negative_image_embeds = torch.zeros_like(image_embeds) - image_embeds = torch.cat([negative_image_embeds, image_embeds]) + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype - return image_embeds + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. 
When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.vae.enable_tiling() + return prompt_embeds - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() + def encode_image( + self, + image: PipelineImageInput, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + image = self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], - device: torch.device, - num_videos_per_prompt: int, - do_classifier_free_guidance: bool, negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - max_sequence_length: Optional[int] = None, - lora_scale: Optional[float] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): - """ + r""" Encodes the prompt into text encoder hidden states. Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide image generation. - device: (`torch.device`): - The torch device to place the resulting embeddings on. - num_videos_per_prompt (`int`): - The number of videos that should be generated per prompt. - do_classifier_free_guidance (`bool`): - Whether to use classifier-free guidance or not. 
+ prompt (`str` or `List[str]`, *optional*): + prompt to be encoded negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - provide `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if - `guidance_scale` is less than 1). + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, text embeddings are generated from the `prompt` input argument. + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - max_sequence_length (`int`, *optional*): - Maximum sequence length for input text when generating embeddings. If not provided, defaults to 77. - lora_scale (`float`, *optional*): - Scale for LoRA-based text embeddings. + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. 
+ device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype """ - # Set LoRA scale - lora_scale = lora_scale or self.lora_scale + device = device or self._execution_device - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] - if max_sequence_length is None: - max_sequence_length = self.tokenizer.model_max_length - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_tensors="pt", + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, ) - text_input_ids = text_inputs.input_ids - attention_mask = text_inputs.attention_mask - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask.to(device), - ) - prompt_embeds = prompt_embeds[0] + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - bs_embed, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) - - if do_classifier_free_guidance: - if negative_prompt_embeds is None: - if negative_prompt is None: - negative_prompt = "" - if isinstance(negative_prompt, str) and negative_prompt == "": - negative_prompt = [negative_prompt] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] - if isinstance(negative_prompt, list) and batch_size != len(negative_prompt): - raise ValueError("Negative prompt batch size mismatch") - uncond_input = self.tokenizer( - negative_prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_tensors="pt", + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." ) - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device) - )[0] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - bs_embed, seq_len, _ = negative_prompt_embeds.shape - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - return prompt_embeds - def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor: - """ - Decode the generated latent sample using the VAE to produce video frames. 
+ negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) - Args: - latents (`torch.Tensor`): - Generated latent samples of shape (batch, channels, latent_frames, height, width). + return prompt_embeds, negative_prompt_embeds - Returns: - `torch.Tensor`: Decoded video frames of shape (batch, frames, channels, height, width) as a float tensor in - range [0, 1]. - """ - # AutoencoderKLWan expects B, C, F, H, W latents directly - video = self.vae.decode(latents).sample - video = video.permute(0, 2, 1, 3, 4) # B, F, C, H, W - video = (video / 2 + 0.5).clamp(0, 1) - return video + def check_inputs( + self, + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") def prepare_latents( self, + image: PipelineImageInput, batch_size: int, - num_channels_latents: int, - num_frames: int, - height: int, - width: int, - dtype: torch.dtype, - device: torch.device, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Prepare latent variables from noise for the diffusion process. - - Args: - batch_size (`int`): - Number of samples to generate. - num_channels_latents (`int`): - Number of channels in the latent space. - num_frames (`int`): - Number of video frames to generate. - height (`int`): - Height of the generated video in pixels. - width (`int`): - Width of the generated video in pixels. - dtype (`torch.dtype`): - Data type of the latent variables. - device (`torch.device`): - Device to generate the latents on. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video - generation. Can be used to tweak the same generation with different prompts. If not provided, a random - noisy latent is generated. - - Returns: - `torch.Tensor`: Prepared initial latent variables. - """ - vae_scale_factor = self.vae_scale_factor - shape_spatial = (batch_size, num_channels_latents, height // vae_scale_factor, width // vae_scale_factor) - shape_spatial = ( - batch_size, - num_channels_latents, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - - if hasattr(self.vae.config, "temperal_downsample") and self.vae.config.temperal_downsample is not None: - num_true_temporal_downsamples = sum(1 for td in self.vae.config.temperal_downsample if td) - temporal_downsample_factor = 2**num_true_temporal_downsamples - else: - temporal_downsample_factor = 4 - logger.warning( - "VAE config does not have 'temperal_downsample'. Using default temporal_downsample_factor=4." 
- ) - - num_latent_frames = (num_frames - 1) // temporal_downsample_factor + 1 - shape = (shape_spatial[0], shape_spatial[1], num_latent_frames, shape_spatial[2], shape_spatial[3]) + last_image: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -378,227 +396,348 @@ def prepare_latents( if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: - latents = latents.to(device) + latents = latents.to(device=device, dtype=dtype) - latents = latents * self.scheduler.init_noise_sigma - return latents + image = image.unsqueeze(2) + if last_image is None: + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + ) + else: + last_image = last_image.unsqueeze(2) + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], + dim=2, + ) + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + ] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + + latent_condition = latent_condition.to(dtype) + latent_condition = (latent_condition - latents_mean) * latents_std + + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + + if last_image is None: + mask_lat_size[:, :, list(range(1, num_frames))] = 0 + else: + mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(latent_condition.device) + + return latents, torch.concat([mask_lat_size, latent_condition], dim=1) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() 
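To make the conditioning-mask construction in `prepare_latents` above easier to follow, here is a stand-alone reproduction of the same steps with tiny, made-up dimensions (it mirrors the logic but is not the pipeline code itself):

```py
import torch

batch_size, num_frames = 1, 9            # made-up values; the real default is 81 frames
vae_scale_factor_temporal = 4            # Wan-style VAE compresses time by 4x
latent_height, latent_width = 6, 10      # tiny spatial grid for illustration

# 1 = frame is provided as conditioning, 0 = frame must be generated.
mask = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
mask[:, :, 1:] = 0                       # first-frame conditioning (no `last_image`)

# The first frame is repeated so that, after temporal compression,
# it still occupies a full latent frame on its own.
first = mask[:, :, 0:1].repeat_interleave(vae_scale_factor_temporal, dim=2)
mask = torch.cat([first, mask[:, :, 1:]], dim=2)

# Group every `vae_scale_factor_temporal` frames into one latent frame.
mask = mask.view(batch_size, -1, vae_scale_factor_temporal, latent_height, latent_width)
mask = mask.transpose(1, 2)

num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
print(mask.shape, num_latent_frames)     # torch.Size([1, 4, 3, 6, 10]) 3
```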
@replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, + image: PipelineImageInput, prompt: Union[str, List[str]] = None, - image: Optional[Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]]] = None, - num_frames: int = 97, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 30, - guidance_scale: float = 6.0, - negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, image_embeds: Optional[torch.Tensor] = None, - max_sequence_length: Optional[int] = None, + last_image: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - custom_shift: Optional[float] = 8.0, - ) -> Union[SkyReelsV2PipelineOutput, Tuple]: - """ + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" The call function to the pipeline for generation. Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*): - The image to use as the starting point for the video generation. - num_frames (`int`, *optional*, defaults to 97): - The number of video frames to generate. - height (`int`, *optional*, defaults to None): - The height in pixels of the generated video frames. If not provided, height is automatically determined - from the model configuration. - width (`int`, *optional*, defaults to None): - The width in pixels of the generated video frames. If not provided, width is automatically determined - from the model configuration. - num_inference_steps (`int`, *optional*, defaults to 30): - The number of denoising steps. More denoising steps usually lead to higher quality videos at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 6.0): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, defaults to `480`): + The height of the generated video. + width (`int`, defaults to `832`): + The width of the generated video. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. num_videos_per_prompt (`int`, *optional*, defaults to 1): - The number of videos to generate per prompt. + The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If - not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. - max_sequence_length (`int`, *optional*): - Maximum sequence length for input text when generating embeddings. If not provided, defaults to 77. + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `negative_prompt` input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, + image embeddings are generated from the `image` input argument. output_type (`str`, *optional*, defaults to `"np"`): - The output format of the generated video. Choose between `tensor` and `np` for `torch.Tensor` or - `numpy.array` output respectively. + The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that calls every `callback_steps` steps during inference. The function is called with the - following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function is called. 
If not specified, the callback is called at - every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined in - [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - custom_shift (`float`, *optional*): - Custom shifting factor to use in the flow matching framework. - + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `512`): + The maximum sequence length of the prompt. + shift (`float`, *optional*, defaults to `5.0`): + The shift of the flow. + autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): + The dtype to use for the torch.amp.autocast. Examples: - ```py - >>> import torch - >>> import PIL.Image - >>> from diffusers import SkyReelsV2ImageToVideoPipeline - - >>> pipe = SkyReelsV2ImageToVideoPipeline.from_pretrained( - ... "SkyworkAI/SkyReels-V2-DiffusionForcing-4.0B", torch_dtype=torch.float16 - ... ) - >>> pipe = pipe.to("cuda") - - >>> image = PIL.Image.open("input_image.jpg").convert("RGB") - >>> prompt = "A beautiful view of mountains" - >>> video_frames = pipe(prompt, image=image, num_frames=16).frames[0] - ``` Returns: - [`~pipelines.SkyReelsV2PipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.SkyReelsV2PipelineOutput`] is returned, otherwise a tuple is - returned where the first element is a list with the generated frames. + [`~SkyReelsV2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ - # 0. Default height and width to transformer dimensions - height = height or self.transformer.config.patch_size[1] * 112 # Default from SkyReels-V2: 224 - width = width or self.transformer.config.patch_size[2] * 112 # Default from SkyReels-V2: 224 - # 1. Check inputs + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. 
Raise error if not correct self.check_inputs( prompt, + negative_prompt, image, height, - num_frames, - callback_steps, - negative_prompt, + width, prompt_embeds, negative_prompt_embeds, + image_embeds, + callback_on_step_end_tensor_inputs, ) - if image is None: - raise ValueError("For image-to-video generation, an input image is required.") + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) - elif image is not None: - if isinstance(image, PIL.Image.Image): - batch_size = 1 - elif isinstance(image, list) and all(isinstance(i, PIL.Image.Image) for i in image): - batch_size = len(image) - elif isinstance(image, torch.Tensor): - batch_size = image.shape[0] - else: - # Fallback or error if image type is not recognized for batch size inference - raise ValueError("Cannot determine batch size from the provided image type.") - elif prompt_embeds is not None: - batch_size = prompt_embeds.shape[0] else: - raise ValueError("Either `prompt`, `image`, or `prompt_embeds` must be provided.") - - device = self._execution_device - do_classifier_free_guidance = guidance_scale > 1.0 + batch_size = prompt_embeds.shape[0] # 3. Encode input prompt - prompt_embeds = self.encode_prompt( - prompt, - device, - num_videos_per_prompt, - do_classifier_free_guidance, - negative_prompt, + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, max_sequence_length=max_sequence_length, + device=device, ) - # 4. Encode input image - if image is None: - # This case should ideally be caught by check_inputs or initial ValueError - raise ValueError("`image` is a required argument for SkyReelsV2ImageToVideoPipeline.") + # Encode image embedding + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) - image_embeds = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance) + if image_embeds is None: + if last_image is None: + image_embeds = self.encode_image(image, device) + else: + image_embeds = self.encode_image([image, last_image], device) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) - # 5. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device, shift=custom_shift) + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps - # 6. Prepare latent variables - num_channels_latents = self.vae.config.latent_channels - latents = self.prepare_latents( + # 5. 
Prepare latent variables + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + if last_image is not None: + last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( + device, dtype=torch.float32 + ) + latents, condition = self.prepare_latents( + image, batch_size * num_videos_per_prompt, num_channels_latents, - num_frames, height, width, - prompt_embeds.dtype, # Use prompt_embeds.dtype, image_embeds could be different + num_frames, + torch.float32, device, generator, - latents=latents, + latents, + last_image, ) - # 7. Denoising loop + # 6. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) - model_pred = self.transformer( - latent_model_input, - timestep=t, + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, encoder_hidden_states=prompt_embeds, encoder_hidden_states_image=image_embeds, - attention_kwargs=cross_attention_kwargs, - ).sample - - if do_classifier_free_guidance: - model_pred_uncond, model_pred_text = model_pred.chunk(2) - model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond) - - latents = self.scheduler.step(model_pred, t, latents).prev_sample + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) - # 8. 
Post-processing - video_tensor = self.decode_latents(latents) - video = self.video_processor.postprocess_video(video_tensor, output_type=output_type) + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + # Offload all models self.maybe_free_model_hooks() if not return_dict: diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_text_to_video.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_text_to_video.py deleted file mode 100644 index 2bf9169308c1..000000000000 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_text_to_video.py +++ /dev/null @@ -1,551 +0,0 @@ -# Copyright 2024 The SkyReels-V2 Authors and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import PIL.Image -import torch -from transformers import CLIPTextModel, CLIPTokenizer - -from ...image_processor import VideoProcessor -from ...models import AutoencoderKLWan, WanTransformer3DModel -from ...schedulers import FlowUniPCMultistepScheduler -from ...utils import ( - BaseOutput, - logging, - replace_example_docstring, -) -from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import torch - >>> from diffusers import SkyReelsV2TextToVideoPipeline - >>> from diffusers.utils import export_to_video - - >>> # Load the pipeline - >>> pipe = SkyReelsV2TextToVideoPipeline.from_pretrained( - ... "HF_placeholder/SkyReels-V2-T2V-14B-540P", torch_dtype=torch.float16 - ... ) - >>> pipe = pipe.to("cuda") - - >>> prompt = "A panda eating bamboo on a rock, 4k, detailed" - >>> video_frames = pipe(prompt, num_frames=97).frames # Default num_frames is often higher for video - >>> export_to_video(video_frames, "skyreels_v2_t2v.mp4") - ``` -""" - - -@dataclass -class SkyReelsV2PipelineOutput(BaseOutput): - """ - Output class for SkyReels-V2 pipelines. - - Args: - frames (`List[np.ndarray]` or `torch.Tensor`): - List of video frames generated by the pipeline. Format depends on `output_type` argument. `np.ndarray` list - is default. For `output_type="np"`: list of `np.ndarray` of shape `(num_frames, height, width, - num_channels)` with values in [0, 255]. 
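A compact illustration of the latent de-normalization performed right before `vae.decode` above; the mean/std values below are placeholders, whereas the pipeline reads them from `vae.config.latents_mean` and `vae.config.latents_std`:

```py
import torch

z_dim = 16
# Placeholder statistics; the real ones come from the VAE config.
latents_mean = torch.zeros(1, z_dim, 1, 1, 1)
latents_std = 1.0 / torch.full((1, z_dim, 1, 1, 1), 2.0)  # stored as a reciprocal, as above

latents = torch.randn(1, z_dim, 21, 60, 104)

# Undo the normalization applied when the condition was encoded:
# (x - mean) * (1/std) on the way in, so x / (1/std) + mean on the way out.
latents = latents / latents_std + latents_mean
```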
For `output_type="tensor"`: `torch.Tensor` of shape `(batch_size, - num_frames, channels, height, width)` with values in [0, 1]. For `output_type="pil"`: list of - `PIL.Image.Image`. - """ - - frames: Union[List[np.ndarray], torch.Tensor, List[PIL.Image.Image]] - - -class SkyReelsV2TextToVideoPipeline(DiffusionPipeline): - """ - Pipeline for text-to-video generation using SkyReels-V2. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods - implemented for all pipelines (downloading, saving, running on a specific device, etc.). - - The pipeline is based on the Wan 2.1 architecture (WanTransformer3DModel, AutoencoderKLWan). It expects checkpoints - saved in the standard diffusers format, typically including subfolders: `vae`, `text_encoder`, `tokenizer`, - `transformer`, `scheduler`. - - Args: - vae ([`AutoencoderKLWan`]): - Variational Auto-Encoder (VAE) model capable of encoding and decoding videos in latent space. Expected to - handle 3D inputs (temporal dimension). - text_encoder ([`~transformers.CLIPTextModel`]): - Frozen text-encoder. SkyReels-V2 typically uses CLIP (e.g., `openai/clip-vit-large-patch14`). - tokenizer ([`~transformers.CLIPTokenizer`]): - Tokenizer corresponding to the `text_encoder`. - transformer ([`WanTransformer3DModel`]): - The core diffusion transformer model that denoises latents based on text conditioning. - scheduler ([`FlowUniPCMultistepScheduler`]): - A scheduler compatible with the Flow Matching framework used by SkyReels-V2. - video_processor ([`VideoProcessor`]): - Processor for converting VAE output latents to standard video formats (np, tensor, pil). - """ - - model_cpu_offload_seq = "text_encoder->transformer->vae" - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] - - def __init__( - self, - vae: AutoencoderKLWan, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - transformer: WanTransformer3DModel, - scheduler: FlowUniPCMultistepScheduler, - video_processor: VideoProcessor, - ): - super().__init__() - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - transformer=transformer, - scheduler=scheduler, - video_processor=video_processor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.vae.enable_tiling() - - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. 
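A small numeric illustration of the `vae_scale_factor` computed in the `__init__` shown above, assuming a hypothetical VAE with four `block_out_channels` entries:

```py
# Illustrative config values; the real ones come from `vae.config.block_out_channels`.
block_out_channels = [96, 192, 384, 384]

vae_scale_factor = 2 ** (len(block_out_channels) - 1)
print(vae_scale_factor)  # 8 -> a 480x832 video maps to a 60x104 latent grid
```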
- """ - self.vae.disable_tiling() - - def encode_prompt( - self, - prompt: Union[str, List[str]], - device: torch.device, - num_videos_per_prompt: int, - do_classifier_free_guidance: bool, - negative_prompt: Optional[Union[str, List[str]]] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - max_sequence_length: Optional[int] = None, - lora_scale: Optional[float] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide video generation. - device: (`torch.device`): - torch device. - num_videos_per_prompt (`int`): - Number of videos to generate per prompt. - do_classifier_free_guidance (`bool`): - Whether to use classifier-free guidance. - negative_prompt (`str` or `List[str]`, *optional*): - The negative prompt or prompts. Ignored if `do_classifier_free_guidance` is `False`. - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Higher priority than `prompt`. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Higher priority than `negative_prompt`. - max_sequence_length (`int`, *optional*): - Maximum sequence length for tokenizer. Defaults to `self.tokenizer.model_max_length`. - lora_scale (`float`, *optional*): - A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ - # Set LoRA scale - lora_scale = lora_scale or self.lora_scale - - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - # Define tokenizer parameters - if max_sequence_length is None: - max_sequence_length = self.tokenizer.model_max_length - - # Get prompt text embeddings - if prompt_embeds is None: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=text_inputs.attention_mask.to(device), - output_hidden_states=False, - )[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) - - # Get negative prompt embeddings - if do_classifier_free_guidance: - if negative_prompt_embeds is None: - if negative_prompt is None: - negative_prompt = "" # Use empty string - # Do not repeat for multiple prompts if it is empty string - if isinstance(negative_prompt, str) and negative_prompt == "": - negative_prompt = [negative_prompt] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] # Already handled single string case - - if isinstance(negative_prompt, list) and batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`: {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches batch size of `prompt`." 
- ) - - uncond_input = self.tokenizer( - negative_prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - return_tensors="pt", - ) - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=uncond_input.attention_mask.to(device), - output_hidden_states=False, - )[0] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - # duplicate negative embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = negative_prompt_embeds.shape - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: - """ - Decode video latents using the VAE. - - Args: - latents (`torch.Tensor`): Latents of shape (batch, channels, latent_frames, height, width). - - Returns: - `torch.Tensor`: Decoded video frames of shape (batch, frames, channels, height, width) as float tensor [0, - 1]. - """ - # AutoencoderKLWan expects B, C, F, H, W latents directly - video = self.vae.decode(latents).sample - - # Output is likely B, C, F, H, W in range [-1, 1] - # Convert to B, F, C, H, W and range [0, 1] - video = video.permute(0, 2, 1, 3, 4) # B, F, C, H, W - video = (video / 2 + 0.5).clamp(0, 1) - return video - - def prepare_latents( - self, - batch_size: int, - num_channels_latents: int, - num_frames: int, - height: int, - width: int, - dtype: torch.dtype, - device: torch.device, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Prepare latent variables from noise for the diffusion process. - - Args: - batch_size (`int`): - Number of samples to generate. - num_channels_latents (`int`): - Number of channels in the latent space (e.g., `self.vae.config.latent_channels`). - num_frames (`int`): - Number of video frames *in the final output*. - height (`int`): - Height of the generated video in pixels. - width (`int`): - Width of the generated video in pixels. - dtype (`torch.dtype`): - Data type for the latents. - device (`torch.device`): - Device for the latents. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - PyTorch Generator object(s). - latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents. - - Returns: - `torch.Tensor`: Initial latent variables (noise) scaled by the scheduler's `init_noise_sigma`. - """ - vae_scale_factor = self.vae_scale_factor - shape_spatial = ( - batch_size, - num_channels_latents, - height // vae_scale_factor, - width // vae_scale_factor, - ) - - # Calculate temporal downsampling factor from VAE config - if hasattr(self.vae.config, "temperal_downsample") and self.vae.config.temperal_downsample is not None: - num_true_temporal_downsamples = sum(1 for td in self.vae.config.temperal_downsample if td) - temporal_downsample_factor = 2**num_true_temporal_downsamples - else: - temporal_downsample_factor = 4 # Default from original SkyReels - logger.warning( - "VAE config does not have 'temperal_downsample'. Using default temporal_downsample_factor=4." 
- ) - - # Calculate the number of latent frames - num_latent_frames = (num_frames - 1) // temporal_downsample_factor + 1 - shape = (shape_spatial[0], shape_spatial[1], num_latent_frames, shape_spatial[2], shape_spatial[3]) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape: {latents.shape}. Expected {shape}.") - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - num_frames: int = 97, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 30, - guidance_scale: float = 6.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_videos_per_prompt: Optional[int] = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - max_sequence_length: Optional[int] = None, - output_type: Optional[str] = "np", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - custom_shift: Optional[float] = 8.0, - ) -> Union[SkyReelsV2PipelineOutput, Tuple]: - """ - The call function to the pipeline for text-to-video generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt(s) to guide video generation. - num_frames (`int`, *optional*, defaults to 97): - The number of frames to generate. - height (`int`, *optional*): - The height in pixels of the generated video. Defaults to VAE output size. - width (`int`, *optional*): - The width in pixels of the generated video. Defaults to VAE output size. - num_inference_steps (`int`, *optional*, defaults to 30): - The number of denoising steps. - guidance_scale (`float`, *optional*, defaults to 6.0): - Scale for classifier-free guidance. `guidance_scale <= 1` disables CFG. - negative_prompt (`str` or `List[str]`, *optional*): - The negative prompt(s) for CFG. - num_videos_per_prompt (`int`, *optional*, defaults to 1): - Number of videos to generate per prompt. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - Generator(s) for deterministic generation. - latents (`torch.Tensor`, *optional*): - Pre-generated noisy latents. Shape should match `(batch_size * num_videos_per_prompt, C, F_latent, - H_latent, W_latent)`. - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Shape `(batch_size * num_videos_per_prompt, seq_len, embed_dim)`. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Shape `(batch_size * num_videos_per_prompt, seq_len, - embed_dim)`. - max_sequence_length (`int`, *optional*): - Maximum sequence length for tokenizer. Defaults to `self.tokenizer.model_max_length`. 
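A quick numeric check of the latent frame count derived in `prepare_latents` above (the flags are illustrative; in the pipeline they come from the VAE config, whose `temperal_downsample` spelling is kept as-is):

```py
# Illustrative values: a VAE config with two "True" temporal downsample stages.
temperal_downsample = [False, True, True]   # spelling follows the VAE config key
num_frames = 97                             # this pipeline's default output length

temporal_downsample_factor = 2 ** sum(1 for td in temperal_downsample if td)  # -> 4
num_latent_frames = (num_frames - 1) // temporal_downsample_factor + 1        # -> 25

print(temporal_downsample_factor, num_latent_frames)  # 4 25
```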
- output_type (`str`, *optional*, defaults to `"np"`): - The output format of the generated video. Choose between `"np"` (list of np.ndarray), `"tensor"` - (torch.Tensor), or `"pil"` (list of PIL.Image.Image). - return_dict (`bool`, *optional*, defaults to `True`): - Whether to return a [`~pipelines.skyreels_v2.SkyReelsV2PipelineOutput`] or a tuple. - callback (`Callable`, *optional*): - A function called every `callback_steps` steps during inference. - callback_steps (`int`, *optional*, defaults to 1): - Frequency of callback calls. - cross_attention_kwargs (`dict`, *optional*): - Arguments passed to the attention processor. - custom_shift (`float`, *optional*, defaults to 8.0): - The "shift" parameter for the `FlowUniPCMultistepScheduler`. Controls emphasis on diffusion trajectory - parts. Corresponds to `shift` in the original SkyReels repository. - - Examples: - ```py - >>> # Example usage is included in the EXAMPLE_DOC_STRING variable - ``` - - Returns: - [`~pipelines.skyreels_v2.SkyReelsV2PipelineOutput`] or `tuple`: - If `return_dict` is `True`, returns [`~pipelines.skyreels_v2.SkyReelsV2PipelineOutput`]. Otherwise - returns a tuple `(video,)` where `video` format depends on `output_type`. - """ - # 0. Default height and width to VAE constraints - # Note: WanTransformer3DModel config doesn't have a standard 'sample_size'. - # Relying on VAE scale factor might be sufficient if input H/W are provided or inferred. - # Let's keep the defaults based on user input or raise error if not determinable. - if height is None or width is None: - # Height and width are required for this pipeline. - raise ValueError("Please provide `height` and `width` for video generation.") - - # Ensure height and width are multiples of VAE scale factor - height = height - height % self.vae_scale_factor - width = width - width % self.vae_scale_factor - if height == 0 or width == 0: - raise ValueError("Provided height and width are too small.") - - # 1. Check inputs - self.check_inputs( - prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # Implement LoRA scale handling - requires PeftAdapterMixin setup if LoRA is used - lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - do_classifier_free_guidance = guidance_scale > 1.0 - - # 3. Encode input prompt - prompt_embeds = self.encode_prompt( - prompt, - device, - num_videos_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - max_sequence_length=max_sequence_length, - lora_scale=lora_scale, - ) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device, shift=custom_shift) - timesteps = self.scheduler.timesteps - - # 5. Prepare latent variables - num_channels_latents = self.vae.config.latent_channels - latents = self.prepare_latents( - batch_size * num_videos_per_prompt, - num_channels_latents, - num_frames, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents=latents, - ) - - # 6. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - model_pred = self.transformer( - latent_model_input, - timestep=t, - encoder_hidden_states=prompt_embeds, - attention_kwargs=cross_attention_kwargs, - ).sample - - # perform guidance - if do_classifier_free_guidance: - model_pred_uncond, model_pred_text = model_pred.chunk(2) - model_pred = model_pred_uncond + guidance_scale * (model_pred_text - model_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(model_pred, t, latents).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) - - # 7. Post-processing - video_tensor = self.decode_latents(latents) # B, F, C, H, W float [0,1] - - # 8. Process video output - video = self.video_processor.postprocess_video(video_tensor, output_type=output_type) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (video,) - - return SkyReelsV2PipelineOutput(frames=video) From e781084e93c2298fff81b5e0f816709d6ec0826f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 9 May 2025 18:19:49 +0300 Subject: [PATCH 007/264] upup --- .../skyreels_v2/pipeline_skyreels_v2.py | 8 +- .../pipeline_skyreels_v2_diffusion_forcing.py | 6 +- .../pipeline_skyreels_v2_image_to_video.py | 6 +- src/diffusers/schedulers/__init__.py | 4 +- .../scheduling_flow_match_unipc_multistep.py | 778 ++++++++++++++++++ .../scheduling_flow_unipc_multistep.py | 721 ---------------- 6 files changed, 790 insertions(+), 733 deletions(-) create mode 100644 src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py delete mode 100644 src/diffusers/schedulers/scheduling_flow_unipc_multistep.py diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py index 2563b16c2afc..fbb50280ed22 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py @@ -22,7 +22,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import WanLoraLoaderMixin from ...models import AutoencoderKLWan, WanTransformer3DModel -from ...schedulers import FlowUniPCMultistepScheduler +from ...schedulers import FlowMatchUniPCMultistepScheduler from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor @@ -56,7 +56,7 @@ >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) >>> pipe = SkyReelsV2Pipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) >>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P - >>> pipe.scheduler = FlowUniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) + >>> pipe.scheduler = 
FlowMatchUniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) >>> pipe.to("cuda") >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." @@ -108,7 +108,7 @@ class SkyReelsV2Pipeline(DiffusionPipeline, WanLoraLoaderMixin): the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. transformer ([`WanTransformer3DModel`]): Conditional Transformer to denoise the input latents. - scheduler ([`FlowUniPCMultistepScheduler`]): + scheduler ([`FlowMatchUniPCMultistepScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. @@ -123,7 +123,7 @@ def __init__( text_encoder: UMT5EncoderModel, transformer: WanTransformer3DModel, vae: AutoencoderKLWan, - scheduler: FlowUniPCMultistepScheduler, + scheduler: FlowMatchUniPCMultistepScheduler, ): super().__init__() diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index abb4e7cabd1b..473d6b96af94 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -25,7 +25,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...video_processor import VideoProcessor from ...models import AutoencoderKLWan, WanTransformer3DModel -from ...schedulers import FlowUniPCMultistepScheduler +from ...schedulers import FlowMatchUniPCMultistepScheduler from ...utils import ( logging, replace_example_docstring, @@ -112,7 +112,7 @@ class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): variant. transformer ([`WanTransformer3DModel`]): Conditional Transformer to denoise the encoded image latents. - scheduler ([`FlowUniPCMultistepScheduler`]): + scheduler ([`FlowMatchUniPCMultistepScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
@@ -129,7 +129,7 @@ def __init__( image_processor: CLIPImageProcessor, transformer: WanTransformer3DModel, vae: AutoencoderKLWan, - scheduler: FlowUniPCMultistepScheduler, + scheduler: FlowMatchUniPCMultistepScheduler, ): super().__init__() diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py index 11e2aea8a5d7..3511ca02c139 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py @@ -24,7 +24,7 @@ from ...image_processor import PipelineImageInput from ...loaders import WanLoraLoaderMixin from ...models import AutoencoderKLWan, WanTransformer3DModel -from ...schedulers import FlowUniPCMultistepScheduler +from ...schedulers import FlowMatchUniPCMultistepScheduler from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor @@ -145,7 +145,7 @@ class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): variant. transformer ([`WanTransformer3DModel`]): Conditional Transformer to denoise the input latents. - scheduler ([`FlowUniPCMultistepScheduler`]): + scheduler ([`FlowMatchUniPCMultistepScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. @@ -162,7 +162,7 @@ def __init__( image_processor: CLIPImageProcessor, transformer: WanTransformer3DModel, vae: AutoencoderKLWan, - scheduler: FlowUniPCMultistepScheduler, + scheduler: FlowMatchUniPCMultistepScheduler, ): super().__init__() diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 7be31bd0821d..3a16e7f96a3d 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -61,7 +61,7 @@ _import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"] _import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"] _import_structure["scheduling_flow_match_lcm"] = ["FlowMatchLCMScheduler"] - _import_structure["scheduling_flow_unipc_multistep"] = ["FlowUniPCMultistepScheduler"] + _import_structure["scheduling_flow_match_unipc_multistep"] = ["FlowMatchUniPCMultistepScheduler"] _import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"] _import_structure["scheduling_ipndm"] = ["IPNDMScheduler"] _import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"] @@ -164,7 +164,7 @@ from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler from .scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler from .scheduling_flow_match_lcm import FlowMatchLCMScheduler - from .scheduling_flow_unipc_multistep import FlowUniPCMultistepScheduler + from .scheduling_flow_match_unipc_multistep import FlowMatchUniPCMultistepScheduler from .scheduling_heun_discrete import HeunDiscreteScheduler from .scheduling_ipndm import IPNDMScheduler from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler diff --git a/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py b/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py new file mode 100644 index 
000000000000..446ec2f1362d --- /dev/null +++ b/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py @@ -0,0 +1,778 @@ +# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved. +# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py +# Converted unipc for flow matching +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin +from diffusers.configuration_utils import register_to_config +from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.schedulers.scheduling_utils import SchedulerOutput +from diffusers.utils import deprecate + + +class FlowMatchUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `FlowMatchUniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, default `2`): + The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` + due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for + unconditional sampling. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts the + flow of the diffusion process. + shift (`float`, defaults to 1.0): + Scaling factor for time shifting in flow matching. + use_dynamic_shifting (`bool`, defaults to False): + Whether to use dynamic time shifting based on image resolution. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. + solver_type (`str`, default `bh2`): + Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. 
This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` + and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is + usually disabled during the first few steps. + solver_p (`SchedulerMixin`, default `None`): + Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. 
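For readers unfamiliar with the flow-matching `shift` parameter, the transform applied in `__init__` above (and again in `set_timesteps`) is a simple rational warp of the sigma schedule. A minimal sketch with illustrative values, not taken from any SkyReels-V2 config:

```py
import numpy as np

# sigma' = shift * sigma / (1 + (shift - 1) * sigma): for shift > 1 every sigma is pulled
# toward 1.0, so more of the sampling trajectory is spent at high noise levels.
sigmas = np.linspace(0.999, 0.001, 10)
shift = 5.0  # illustrative value only
shifted = shift * sigmas / (1 + (shift - 1) * sigmas)
print(np.round(shifted, 3))  # still monotonically decreasing, but concentrated near 1.0
```

With the default `shift=1.0` the schedule is left unchanged; `use_dynamic_shifting=True` instead defers the warp to `set_timesteps`, which then requires `mu`.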
+ """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for sampling. If `None`, default values are used. + mu (`float`, *optional*): + Value for dynamic shifting based on image resolution. Required when `use_dynamic_shifting=True`. + shift (`float`, *optional*): + Scaling factor for time shifting. Overrides config value if provided. + """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError("you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. 
Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + r""" + Convert the model output to the corresponding type the UniPC algorithm needs. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." 
+ ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of UniP at this timestep (corresponds to the *p* in UniPC-p). + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError(" missing `order` as a required keyward argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = 
torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`torch.Tensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.Tensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.Tensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. + + Returns: + `torch.Tensor`: + The corrected sample tensor at the current timestep. + """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError(" missing`last_sample` as a required keyward argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError(" missing`this_sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError(" missing`order` as a required keyward argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif 
self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + generator=None, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
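The `pos = 1 if len(indices) > 1 else 0` branch in `index_for_timestep` above can look cryptic; here is a small self-contained illustration of the duplicate-timestep case it guards against (the schedule values are made up):

```py
import torch

# If the same timestep occurs more than once (e.g. when denoising starts in the middle of
# the schedule, as in image-to-image), the second occurrence is used so no sigma is skipped.
schedule_timesteps = torch.tensor([999, 800, 800, 600, 400])
indices = (schedule_timesteps == 800).nonzero()
pos = 1 if len(indices) > 1 else 0
print(indices[pos].item())  # 2 -> the later of the two matching indices
```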
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 + and self.step_index - 1 not in self.disable_corrector + and self.last_sample is not None # pyright: ignore + ) + + model_output_convert = self.convert_model_output(model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep # pyright: ignore + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) # pyright: ignore + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. 
+ """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_flow_unipc_multistep.py b/src/diffusers/schedulers/scheduling_flow_unipc_multistep.py deleted file mode 100644 index 3f38ec2de7e2..000000000000 --- a/src/diffusers/schedulers/scheduling_flow_unipc_multistep.py +++ /dev/null @@ -1,721 +0,0 @@ -# Copyright 2024 The SkyReels-V2 Authors and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput - - -class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): - """ - `FlowUniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models, - adapted for flow matching. - - This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic - methods the library implements for all schedulers such as loading and saving. - - Args: - num_train_timesteps (`int`, defaults to 1000): - The number of diffusion steps to train the model. - solver_order (`int`, default `2`): - The UniPC order which can be any positive integer. 
The effective order of accuracy is `solver_order + 1` - due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for - unconditional sampling. - prediction_type (`str`, defaults to "flow_prediction"): - Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts the - flow of the diffusion process. - shift (`float`, defaults to 1.0): - Scaling factor for time shifting in flow matching. - use_dynamic_shifting (`bool`, defaults to False): - Whether to use dynamic time shifting based on image resolution. - thresholding (`bool`, defaults to `False`): - Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such - as Stable Diffusion. - dynamic_thresholding_ratio (`float`, defaults to 0.995): - The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. - sample_max_value (`float`, defaults to 1.0): - The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. - predict_x0 (`bool`, defaults to `True`): - Whether to use the updating algorithm on the predicted x0. - solver_type (`str`, default `bh2`): - Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` - otherwise. - lower_order_final (`bool`, default `True`): - Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can - stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. - disable_corrector (`list`, default `[]`): - Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` - and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is - usually disabled during the first few steps. - solver_p (`SchedulerMixin`, default `None`): - Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. - timestep_spacing (`str`, defaults to `"linspace"`): - The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. - steps_offset (`int`, defaults to 0): - An offset added to the inference steps, as required by some model families. - final_sigmas_type (`str`, defaults to `"zero"`): - The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final - sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. 
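The new scheduler's `add_noise` above relies on the mapping `alpha_t = 1 - sigma`, `sigma_t = sigma` from `_sigma_to_alpha_sigma_t`, so noising is a straight interpolation between data and noise; the deleted scheduler's `add_noise` further below instead returned `original_samples + noise * sigma`. A tiny sketch of what the new convention evaluates to:

```py
import torch

x0 = torch.randn(1, 4, 8, 8)  # clean latent (illustrative shape)
noise = torch.randn_like(x0)
sigma = torch.tensor(0.7)

# alpha_t * original + sigma_t * noise with alpha_t = 1 - sigma and sigma_t = sigma:
x_t = (1 - sigma) * x0 + sigma * noise  # sigma=0 gives back x0, sigma=1 gives pure noise
```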
- """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - solver_order: int = 2, - prediction_type: str = "flow_prediction", - shift: Optional[float] = 1.0, - use_dynamic_shifting=False, - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - sample_max_value: float = 1.0, - predict_x0: bool = True, - solver_type: str = "bh2", - lower_order_final: bool = True, - disable_corrector: List[int] = [], - solver_p: SchedulerMixin = None, - timestep_spacing: str = "linspace", - steps_offset: int = 0, - final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" - ): - if solver_type not in ["bh1", "bh2"]: - if solver_type in ["midpoint", "heun", "logrho"]: - self.register_to_config(solver_type="bh2") - else: - raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") - - self.predict_x0 = predict_x0 - # setable values - self.num_inference_steps = None - alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[::-1].copy() - sigmas = 1.0 - alphas - sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) - - if not use_dynamic_shifting: - # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution - sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) - - self.sigmas = sigmas - self.timesteps = sigmas * num_train_timesteps - - self.model_outputs = [None] * solver_order - self.timestep_list = [None] * solver_order - self.lower_order_nums = 0 - self.disable_corrector = disable_corrector - self.solver_p = solver_p - self.last_sample = None - self._step_index = None - self._begin_index = None - - self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - self.sigma_min = self.sigmas[-1].item() - self.sigma_max = self.sigmas[0].item() - - @property - def step_index(self): - """ - The index counter for current timestep. It will increase 1 after each scheduler step. - """ - return self._step_index - - @property - def begin_index(self): - """ - The index for the first timestep. It should be set from pipeline with `set_begin_index` method. - """ - return self._begin_index - - def set_begin_index(self, begin_index: int = 0): - """ - Sets the begin index for the scheduler. This function should be run from pipeline before the inference. - - Args: - begin_index (`int`): - The begin index for the scheduler. - """ - self._begin_index = begin_index - - def set_timesteps( - self, - num_inference_steps: Union[int, None] = None, - device: Union[str, torch.device] = None, - sigmas: Optional[List[float]] = None, - mu: Optional[Union[float, None]] = None, - shift: Optional[Union[float, None]] = None, - ): - """ - Sets the discrete timesteps used for the diffusion chain (to be run before inference). - - Args: - num_inference_steps (`int`): - Total number of the spacing of the time steps. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - sigmas (`List[float]`, *optional*): - Custom sigmas to use for sampling. If `None`, default values are used. - mu (`float`, *optional*): - Value for dynamic shifting based on image resolution. Required when `use_dynamic_shifting=True`. - shift (`float`, *optional*): - Scaling factor for time shifting. Overrides config value if provided. 
- """ - - if self.config.use_dynamic_shifting and mu is None: - raise ValueError("you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") - - if sigmas is None: - sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] # pyright: ignore - - if self.config.use_dynamic_shifting: - sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore - else: - if shift is None: - shift = self.config.shift - sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore - - if self.config.final_sigmas_type == "sigma_min": - sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 - elif self.config.final_sigmas_type == "zero": - sigma_last = 0 - else: - raise ValueError( - f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" - ) - - timesteps = sigmas * self.config.num_train_timesteps - sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore - - self.sigmas = torch.from_numpy(sigmas) - self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) - - self.num_inference_steps = len(timesteps) - - self.model_outputs = [ - None, - ] * self.config.solver_order - self.lower_order_nums = 0 - self.last_sample = None - if self.solver_p: - self.solver_p.set_timesteps(self.num_inference_steps, device=device) - - # add an index counter for schedulers that allow duplicated timesteps - self._step_index = None - self._begin_index = None - self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - - def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: - """ - "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the - prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by - s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing - pixels from saturation at each step. We find that dynamic thresholding results in significantly better - photorealism as well as better image-text alignment, especially when using very large guidance weights." 
- - https://arxiv.org/abs/2205.11487 - """ - dtype = sample.dtype - batch_size, channels, *remaining_dims = sample.shape - - if dtype not in (torch.float32, torch.float64): - sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half - - # Flatten sample for doing quantile calculation along each image - sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) - - abs_sample = sample.abs() # "a certain percentile absolute pixel value" - - s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) - s = torch.clamp( - s, min=1, max=self.config.sample_max_value - ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] - s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 - sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" - - sample = sample.reshape(batch_size, channels, *remaining_dims) - sample = sample.to(dtype) - - return sample - - def _sigma_to_alpha_sigma_t(self, sigma): - # Compute alpha, sigma_t from sigma - alpha = torch.sigmoid(-sigma) - sigma_t = torch.sqrt((1 - alpha**2) / alpha**2) - return alpha, sigma_t - - def time_shift(self, mu: float, sigma: float, t: torch.Tensor): - return mu * t / (mu + (sigma - mu) * t) - - def convert_model_output( - self, - model_output: torch.Tensor, - *args, - sample: torch.Tensor = None, - **kwargs, - ) -> torch.Tensor: - """ - Convert the model output to the corresponding type the flow matching framework expects. - - Args: - model_output (`torch.Tensor`): direct output from the model. - sample (`torch.Tensor`, *optional*): current instance of sample being created. - - Returns: - `torch.Tensor`: converted model output for the flow matching framework. - """ - # We dynamically set the scheduler to the correct inference steps - if self.config.prediction_type == "flow_prediction": - sigma = self.sigmas[self._step_index] - t = self.timesteps[self._step_index].to(model_output.device, dtype=model_output.dtype) - t = t / self.config.num_train_timesteps - - # Compute alpha, sigma_t from sigma - alpha, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - alpha = alpha.to(model_output.device, dtype=model_output.dtype) - sigma_t = sigma_t.to(model_output.device, dtype=model_output.dtype) - - if self.predict_x0: - if self.config.thresholding: - model_output = self._threshold_sample(model_output) - x0_pred = model_output - derivative = (sample - alpha * x0_pred) / sigma_t - else: - derivative = model_output - return derivative - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be flow_prediction for {self.__class__}" - ) - - def multistep_uni_p_bh_update( - self, - model_output: torch.Tensor, - *args, - sample: torch.Tensor = None, - order: int = None, # pyright: ignore - **kwargs, - ) -> torch.Tensor: - """ - Multistep universal `P` for building the predictor solver. - - Args: - model_output (`torch.Tensor`): - Direct output from the model. - sample (`torch.Tensor`, *optional*): - Current instance of sample being created. - order (`int`, *optional*): - Order of the solver. If `None`, it will be set based on scheduler configuration. - - Returns: - `torch.Tensor`: The predicted sample for the predictor solver. 
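One behavioral difference worth flagging between the removed scheduler and its replacement is dynamic shifting: the deleted `time_shift` above uses a rational form, `mu * t / (mu + (sigma - mu) * t)`, while the new `scheduling_flow_match_unipc_multistep.py` keeps the exponential form `exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`. A quick numeric comparison with made-up inputs:

```py
import math

t, mu, sigma = 0.5, 1.5, 1.0  # illustrative values only
old = mu * t / (mu + (sigma - mu) * t)                      # removed implementation
new = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)  # implementation kept in the new file
print(round(old, 4), round(new, 4))  # 0.6 vs ~0.8176 -- the two curves are not equivalent
```

This only matters when `use_dynamic_shifting=True`; with the default static shift, both files apply the same rational warp of the sigmas.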
- """ - if order is None: - order = self.config.solver_order - - model_output = self.convert_model_output(model_output, sample=sample) - - # For P(x_{t+1}, x_t, x, x_{t-1} + univeral_coeff * ds_{t-1}) - # DPMSolver only need to save x_{t-1}, x_t, and ds_{t-1} and the higher order. - # We reuse the concept of UniPC and DPMSolver in the `uni_p*_update` function - self.model_outputs.append(model_output.to(sample.dtype)) - self.model_outputs.pop(0) - - time_step = self.timesteps[self._step_index].to(sample.device, model_output.dtype) - prev_time_step = self.timesteps[self._step_index + 1].to(sample.device, model_output.dtype) - - if self._step_index >= len(self.timesteps): - raise ValueError("Requested prediction step cannot advance any further. You cannot advance further.") - - # current_sigma = self.sigmas[self._step_index].to(sample.device, model_output.dtype) - dt = prev_time_step - time_step - - # 1. P(x_{t+1}, x_t, ds_t) - # Define discretized time and compute the time difference - model_output_dagger = model_output - # time_step_array = torch.tensor([1.0, time_step, time_step**2, time_step**3]) - # prev_time_step_array = torch.tensor([1.0, prev_time_step, prev_time_step**2, prev_time_step**3]) - - if order == 1: # predictor with euler steps - if self.config.solver_type == "bh1": - x_t = sample + dt * model_output_dagger - elif self.config.solver_type == "bh2": - x_t = sample + dt * model_output_dagger - else: - self.timestep_list.append(time_step) - self.timestep_list.pop(0) - - # Order matching the UniPC - if 2 <= order <= 3: - current_model_output = model_output_dagger - prev_model_output = self.model_outputs[-2] - - time_coe = dt - - rhos = self.sigmas[self._step_index - 1] / self.sigmas[self._step_index] - rhos = rhos.to(sample.device, model_output.dtype) - - # t -> t + 1 - if order == 2: - # Bh1 - if self.config.solver_type == "bh1": - # Taylor expansion - h_tau = time_coe - h_phi = time_coe - - # 2nd order expansion - x_t = ( - sample - + h_phi * current_model_output - + 0.5 * h_phi**2 * (current_model_output - prev_model_output) / dt - ) - elif self.config.solver_type == "bh2": - # Original DPM Solver ++: https://github.com/LuChengTHU/dpm-solver - h_t, h_t_1 = time_step, self.timestep_list[-2] - # r = rhos - - # prediction: 2nd order expansion from UniPC paper - h = h_t - h_t_1 - x_t = ( - sample - + h * current_model_output - - 0.5 * h**2 * (current_model_output - prev_model_output) / (h_t - h_t_1) - ) - elif order == 3: - prev_prev_model_output = self.model_outputs[-3] - h_t, h_t_1, h_t_2 = time_step, self.timestep_list[-2], self.timestep_list[-3] - # r, r_1 = rhos, self.sigmas[self._step_index - 2] / self.sigmas[self._step_index - 1] - _, r_1 = rhos, self.sigmas[self._step_index - 2] / self.sigmas[self._step_index - 1] - r_1 = r_1.to(sample.device, model_output.dtype) - - # Original DPM Solver ++: https://github.com/LuChengTHU/dpm-solver - if self.config.solver_type == "bh1": - # Taylor expansion - h_tau = time_coe - h_phi = time_coe - h = h_t_1 - h_t_2 - derivative2 = (current_model_output - prev_model_output) / (h_t - h_t_1) - derivative3 = ( - derivative2 - (prev_model_output - prev_prev_model_output) / (h_t_1 - h_t_2) - ) / (h_t - h_t_2) - x_t = ( - sample - + h_tau * current_model_output - + 0.5 * h_tau**2 * derivative2 - + (1.0 / 6.0) * h_tau**3 * derivative3 - ) - elif self.config.solver_type == "bh2": - # From UniC paper: https://github.com/wl-zhao/UniPC - h1 = h_t - h_t_1 - h2 = h_t_1 - h_t_2 - h_left_01 = h_t - h_t_1 - h_left_12 = h_t_1 - h_t_2 - h_left_02 = h_t - 
h_t_2 - taylor1 = current_model_output - taylor2 = (current_model_output - prev_model_output) / h_left_01 - taylor3 = (taylor2 - (prev_model_output - prev_prev_model_output) / h_left_12) / h_left_02 - x_t = sample + h1 * taylor1 + h1**2 * taylor2 / 2 + h1**2 * h2 * taylor3 / 6 - - else: - raise NotImplementedError(f"Multistep UniCI predict with order {order} is not implemented yet.") - - # The format of predictor solvers in DPM-Solver. - return x_t - - def multistep_uni_c_bh_update( - self, - this_model_output: torch.Tensor, - *args, - last_sample: torch.Tensor = None, - this_sample: torch.Tensor = None, - order: int = None, # pyright: ignore - **kwargs, - ) -> torch.Tensor: - """ - Multistep universal `C` for updating the corrector solver. - - Args: - this_model_output (`torch.Tensor`): - Direct output from the model of the current scale. - last_sample (`torch.Tensor`, *optional*): - Sample from the previous scale. - this_sample (`torch.Tensor`, *optional*): - Current sample. - order (`int`, *optional*): - Order of the solver. If `None`, it will be set based on scheduler configuration. - - Returns: - `torch.Tensor`: The updated sample for the corrector solver. - """ - if order is None: - order = self.config.solver_order - # Similar structure as the universal P - # Convert to flow matching format - this_model_output = self.convert_model_output(this_model_output, sample=this_sample) - - if self._step_index > self.num_inference_steps - 1: - prev_time_step = torch.tensor(0.0) - else: - prev_time_step = self.timesteps[self._step_index + 1].to(this_sample.device, this_model_output.dtype) - - time_step = self.timesteps[self._step_index].to(this_sample.device, this_model_output.dtype) - dt = prev_time_step - time_step - - if order == 1: - model_output_processor = this_model_output - # Model output is scaled if we used noise with multiscale - # Corrector - if self.config.solver_type == "bh1": - # Normal euler step to compute corrector (UniC) - x_t = last_sample + dt * model_output_processor - elif self.config.solver_type == "bh2": - # Midpoint method for Heun's 2nd order method - midpoint_model_output = 0.5 * (model_output_processor + this_model_output) - # Runge-Kutta 2nd order - x_t = last_sample + dt * midpoint_model_output - else: # order > 1: - self.timestep_list.append(time_step) - self.timestep_list.pop(0) - self.model_outputs.append(this_model_output.to(last_sample.dtype)) - self.model_outputs.pop(0) - - current_model_output = this_model_output - prev_model_output = self.model_outputs[-2] - - time_coe = dt - - rhos = self.sigmas[self._step_index - 1] / self.sigmas[self._step_index] - rhos = rhos.to(last_sample.device, last_sample.dtype) - - # t -> t + 1 - if order == 2: - # Bh1 - if self.config.solver_type == "bh1": - # Taylor expansion - h_tau = time_coe - h_phi = time_coe - - # 2nd order expansion - x_t = ( - last_sample - + h_phi * current_model_output - + 0.5 * h_phi**2 * (current_model_output - prev_model_output) / dt - ) - elif self.config.solver_type == "bh2": - h_t, h_t_1 = time_step, self.timestep_list[-2] - # r = rhos - h = h_t - h_t_1 - x_t = ( - last_sample - + h * current_model_output - - 0.5 * h**2 * (current_model_output - prev_model_output) / (h_t - h_t_1) - ) - elif order == 3: - prev_prev_model_output = self.model_outputs[-3] - h_t, h_t_1, h_t_2 = time_step, self.timestep_list[-2], self.timestep_list[-3] - # r, r_1 = rhos, self.sigmas[self._step_index - 2] / self.sigmas[self._step_index - 1] - _, r_1 = rhos, self.sigmas[self._step_index - 2] / self.sigmas[self._step_index 
- 1] - r_1 = r_1.to(last_sample.device, last_sample.dtype) - - # Original DPM Solver ++: https://github.com/LuChengTHU/dpm-solver - if self.config.solver_type == "bh1": - # Taylor expansion - h_tau = time_coe - h_phi = time_coe - h = h_t_1 - h_t_2 - derivative2 = (current_model_output - prev_model_output) / (h_t - h_t_1) - derivative3 = (derivative2 - (prev_model_output - prev_prev_model_output) / (h_t_1 - h_t_2)) / ( - h_t - h_t_2 - ) - x_t = ( - last_sample - + h_tau * current_model_output - + 0.5 * h_tau**2 * derivative2 - + (1.0 / 6.0) * h_tau**3 * derivative3 - ) - elif self.config.solver_type == "bh2": - # From UniC paper: https://github.com/wl-zhao/UniPC - h1 = h_t - h_t_1 - h2 = h_t_1 - h_t_2 - h_left_01 = h_t - h_t_1 - h_left_12 = h_t_1 - h_t_2 - h_left_02 = h_t - h_t_2 - taylor1 = current_model_output - taylor2 = (current_model_output - prev_model_output) / h_left_01 - taylor3 = (taylor2 - (prev_model_output - prev_prev_model_output) / h_left_12) / h_left_02 - x_t = last_sample + h1 * taylor1 + h1**2 * taylor2 / 2 + h1**2 * h2 * taylor3 / 6 - else: - raise NotImplementedError(f"Multistep UniCI predict with order {order} is not implemented yet.") - - # The format of corrector solvers in DPM-Solver. - return x_t - - def index_for_timestep(self, timestep, schedule_timesteps=None): - if schedule_timesteps is None: - schedule_timesteps = self.timesteps - - indices = (schedule_timesteps == timestep).nonzero() - - if self.begin_index is not None: - indices = indices[self.begin_index :] - - if len(indices) == 0: - raise ValueError( - f"could not find timestep {timestep} from self.timesteps, Currently, self.timesteps have shape {self.timesteps.shape}, " - f"and set scale to {self.config.set_scale}" - ) - return indices[0].item() - - def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = self.index_for_timestep(timestep) - self._step_index = step_index - - def step( - self, - model_output: torch.Tensor, - timestep: Union[int, torch.Tensor], - sample: torch.Tensor, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.Tensor`): - The direct output from learned diffusion model. - timestep (`torch.Tensor` or `int`): - The discrete timestep in the diffusion chain. - sample (`torch.Tensor`): - A current instance of a sample created by the diffusion process. - return_dict (`bool`): - Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple. - - Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a - tuple is returned where the first element is the sample tensor. 
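For orientation, the order-2 branches above follow the familiar multistep pattern of reusing the previous model output to cancel the leading error term. The general idea (not the exact coefficients used in this file) is the two-step Adams-Bashforth predictor, sketched here on a scalar ODE:

```py
import math

# Two-step Adams-Bashforth on dy/dt = -y, y(0) = 1: the previous slope f_prev refines the
# update, analogous to how `prev_model_output` enters the order-2 branches above.
h, y, f_prev = 0.1, 1.0, None
for _ in range(10):
    f = -y
    y = y + h * f if f_prev is None else y + h * (1.5 * f - 0.5 * f_prev)  # Euler warmup, then AB2
    f_prev = f
print(round(y, 4), "vs exact", round(math.exp(-1.0), 4))
```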
- """ - - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - # Initialize the step_index if performing the first step - if self._step_index is None: - if self.begin_index is None: - self._step_index = 0 - else: - self._step_index = self.begin_index - self._init_step_index(timestep) - - # Upcast sample and model_output to float32 for self.sigmas - sample = sample.to(self.sigmas.dtype) - model_output = model_output.to(self.sigmas.dtype) - - # Apply predicctor (P): x_t -> x_t-1 - if self.config.lower_order_final and self._step_index > self.num_inference_steps - 4: - # For DPM-Solver++(2S), we use lower order solver for the final steps to stabilize the long time inference - # it is equivalent to use our coefficients but change the order - target_order = min(int(self.config.solver_order - 1), 2) - - # 3rd order method + 2nd order + 1st order - if self.config.solver_order > 2 and self._step_index > self.num_inference_steps - 2: - # set order to 1 for the final step - target_order = min(int(target_order - 1), 2) - - # Switch to lower order for DPM-Solver++(2S) in the final steps to stabilize the long time inference - lower_order_predict = self.multistep_uni_p_bh_update( - model_output=model_output, sample=sample, order=target_order - ) - next_sample = lower_order_predict - else: - this_predict = self.multistep_uni_p_bh_update( - model_output=model_output, sample=sample, order=self.config.solver_order - ) - next_sample = this_predict - - # Apply a corrector - if self._step_index not in self.config.disable_corrector: - # UniCPC - # predictor: x_1 -> x_t-1, corrector: x_1 -> x_t-1 - if self.solver_p: - # solver_p_output = self.solver_p.step(model_output, timestep, sample, return_dict=False)[0] - _ = self.solver_p.step(model_output, timestep, sample, return_dict=False)[0] - next_sample = self.multistep_uni_c_bh_update( - this_model_output=model_output, - last_sample=next_sample, - this_sample=sample, - order=self.config.solver_order, - ) - - # update step index - self._step_index += 1 - self.last_sample = sample - - if not return_dict: - return (next_sample,) - - return SchedulerOutput(prev_sample=next_sample) - - def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - sample (`torch.Tensor`): - The input sample. - - Returns: - `torch.Tensor`: - A scaled input sample. 
- """ - return sample - - def add_noise( - self, - original_samples: torch.Tensor, - noise: torch.Tensor, - timesteps: torch.IntTensor, - ) -> torch.Tensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if timesteps.device != sigmas.device: - timesteps = timesteps.to(sigmas.device) - if timesteps.dtype != torch.int64: - timesteps = timesteps.to(torch.int64) - - schedule_timesteps = self.timesteps - - step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - noisy_samples = original_samples + noise * sigma - - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps From 4806660df257ae3cc4c5106ed13a6e289d3c3a34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 9 May 2025 18:30:07 +0300 Subject: [PATCH 008/264] style --- .../skyreels_v2/pipeline_skyreels_v2.py | 8 ++-- .../pipeline_skyreels_v2_diffusion_forcing.py | 45 ++++++++++--------- .../pipeline_skyreels_v2_image_to_video.py | 5 ++- .../scheduling_flow_match_unipc_multistep.py | 18 +++----- 4 files changed, 37 insertions(+), 39 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py index fbb50280ed22..5182e10e06a6 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py @@ -420,7 +420,7 @@ def __call__( output_type (`str`, *optional*, defaults to `"np"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`SkyReelsOutput`] instead of a plain tuple. + Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple. attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in @@ -440,9 +440,9 @@ def __call__( Examples: Returns: - [`~SkyReelsOutput`] or `tuple`: - If `return_dict` is `True`, [`SkyReelsOutput`] is returned, otherwise a `tuple` is returned where - the first element is a list with the generated images and the second element is a list of `bool`s + [`~SkyReelsV2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned where the + first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 473d6b96af94..0576aeb9b979 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -12,34 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import html +import re from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import numpy as np +import ftfy import PIL.Image import torch -import ftfy -import html -import re -from transformers import CLIPVisionModel, CLIPImageProcessor, AutoTokenizer, UMT5EncoderModel +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...video_processor import VideoProcessor +from ...image_processor import PipelineImageInput +from ...loaders import WanLoraLoaderMixin from ...models import AutoencoderKLWan, WanTransformer3DModel from ...schedulers import FlowMatchUniPCMultistepScheduler -from ...utils import ( - logging, - replace_example_docstring, -) +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline -from ...loaders import WanLoraLoaderMixin -from ...image_processor import PipelineImageInput from .pipeline_output import SkyReelsV2PipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name +if is_ftfy_available(): + import ftfy + EXAMPLE_DOC_STRING = """\ Examples: @@ -55,7 +60,7 @@ ... ) >>> pipe = pipe.to("cuda") - ... ... + >>> ... ``` """ @@ -76,6 +81,7 @@ def prompt_clean(text): text = whitespace_clean(basic_clean(text)) return text + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" @@ -92,9 +98,9 @@ def retrieve_latents( class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): """ - Pipeline for video generation with diffusion forcing using SkyReels-V2. - This pipeline supports two main tasks: Text-to-Video (t2v) and Image-to-Video (i2v) - + Pipeline for video generation with diffusion forcing using SkyReels-V2. This pipeline supports two main tasks: + Text-to-Video (t2v) and Image-to-Video (i2v) + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a specific device, etc.). @@ -437,7 +443,7 @@ def interrupt(self): @property def attention_kwargs(self): return self._attention_kwargs - + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -546,12 +552,11 @@ def __call__( Returns: [`~SkyReelsV2PipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned where - the first element is a list with the generated images and the second element is a list of `bool`s + If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. 
""" - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py index 3511ca02c139..d6592b5d5b52 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py @@ -44,6 +44,7 @@ if is_ftfy_available(): import ftfy + EXAMPLE_DOC_STRING = """ Examples: ```python @@ -572,8 +573,8 @@ def __call__( Returns: [`~SkyReelsV2PipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned where - the first element is a list with the generated images and the second element is a list of `bool`s + If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ diff --git a/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py b/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py index 446ec2f1362d..275fb65f0f90 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py @@ -15,18 +15,13 @@ # limitations under the License. import math -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union +from typing import List, Optional, Tuple, Union import numpy as np import torch -from diffusers.configuration_utils import ConfigMixin -from diffusers.configuration_utils import register_to_config -from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers -from diffusers.schedulers.scheduling_utils import SchedulerMixin -from diffusers.schedulers.scheduling_utils import SchedulerOutput + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput from diffusers.utils import deprecate @@ -105,7 +100,6 @@ def __init__( steps_offset: int = 0, final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" ): - if solver_type not in ["bh1", "bh2"]: if solver_type in ["midpoint", "heun", "logrho"]: self.register_to_config(solver_type="bh2") @@ -677,9 +671,7 @@ def step( self._init_step_index(timestep) use_corrector = ( - self.step_index > 0 - and self.step_index - 1 not in self.disable_corrector - and self.last_sample is not None # pyright: ignore + self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None # pyright: ignore ) model_output_convert = self.convert_model_output(model_output, sample=sample) From 0986e81823f33ac9a1305d467599b29cda0626f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 9 May 2025 18:30:39 +0300 Subject: [PATCH 009/264] up --- src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py index 5182e10e06a6..4e64b7bab4e3 100644 --- 
a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py @@ -441,8 +441,8 @@ def __call__( Returns: [`~SkyReelsV2PipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned where the - first element is a list with the generated images and the second element is a list of `bool`s + If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ From 6a300f56ce5aba9533009629502833e8a2e90e56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 9 May 2025 18:34:05 +0300 Subject: [PATCH 010/264] up --- .../scheduling_flow_match_unipc_multistep.py | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py b/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py index 275fb65f0f90..6b78a2b5785c 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py @@ -115,7 +115,7 @@ def __init__( if not use_dynamic_shifting: # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution - sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) self.sigmas = sigmas self.timesteps = sigmas * num_train_timesteps @@ -158,7 +158,6 @@ def set_begin_index(self, begin_index: int = 0): """ self._begin_index = begin_index - # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps def set_timesteps( self, num_inference_steps: Union[int, None] = None, @@ -187,14 +186,14 @@ def set_timesteps( raise ValueError("you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") if sigmas is None: - sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] # pyright: ignore + sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] if self.config.use_dynamic_shifting: - sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + sigmas = self.time_shift(mu, 1.0, sigmas) else: if shift is None: shift = self.config.shift - sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) if self.config.final_sigmas_type == "sigma_min": sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 @@ -206,7 +205,7 @@ def set_timesteps( ) timesteps = sigmas * self.config.num_train_timesteps - sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas) self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) @@ -346,7 +345,7 @@ def multistep_uni_p_bh_update( model_output: torch.Tensor, *args, sample: torch.Tensor = None, - order: int = None, # pyright: ignore + order: int = None, **kwargs, ) -> torch.Tensor: """ @@ -393,7 +392,7 @@ def multistep_uni_p_bh_update( x_t = self.solver_p.step(model_output, s0, x).prev_sample return x_t - sigma_t, 
sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] # pyright: ignore + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) @@ -406,13 +405,13 @@ def multistep_uni_p_bh_update( rks = [] D1s = [] for i in range(1, order): - si = self.step_index - i # pyright: ignore + si = self.step_index - i mi = model_output_list[-(i + 1)] alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) lambda_si = torch.log(alpha_si) - torch.log(sigma_si) rk = (lambda_si - lambda_s0) / h rks.append(rk) - D1s.append((mi - m0) / rk) # pyright: ignore + D1s.append((mi - m0) / rk) rks.append(1.0) rks = torch.tensor(rks, device=device) @@ -455,14 +454,14 @@ def multistep_uni_p_bh_update( if self.predict_x0: x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 if D1s is not None: - pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) else: pred_res = 0 x_t = x_t_ - alpha_t * B_h * pred_res else: x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 if D1s is not None: - pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) # pyright: ignore + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) else: pred_res = 0 x_t = x_t_ - sigma_t * B_h * pred_res @@ -476,7 +475,7 @@ def multistep_uni_c_bh_update( *args, last_sample: torch.Tensor = None, this_sample: torch.Tensor = None, - order: int = None, # pyright: ignore + order: int = None, **kwargs, ) -> torch.Tensor: """ @@ -528,7 +527,7 @@ def multistep_uni_c_bh_update( x_t = this_sample model_t = this_model_output - sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] # pyright: ignore + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1] alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) @@ -541,13 +540,13 @@ def multistep_uni_c_bh_update( rks = [] D1s = [] for i in range(1, order): - si = self.step_index - (i + 1) # pyright: ignore + si = self.step_index - (i + 1) mi = model_output_list[-(i + 1)] alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) lambda_si = torch.log(alpha_si) - torch.log(sigma_si) rk = (lambda_si - lambda_s0) / h rks.append(rk) - D1s.append((mi - m0) / rk) # pyright: ignore + D1s.append((mi - m0) / rk) rks.append(1.0) rks = torch.tensor(rks, device=device) @@ -671,7 +670,7 @@ def step( self._init_step_index(timestep) use_corrector = ( - self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None # pyright: ignore + self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None ) model_output_convert = self.convert_model_output(model_output, sample=sample) @@ -688,10 +687,10 @@ def step( self.timestep_list[i] = self.timestep_list[i + 1] self.model_outputs[-1] = model_output_convert - self.timestep_list[-1] = timestep # pyright: ignore + self.timestep_list[-1] = timestep if self.config.lower_order_final: - this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) # pyright: ignore + this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index) else: this_order = self.config.solver_order @@ -709,7 +708,7 @@ def step( self.lower_order_nums += 1 # upon completion increase step index by one - self._step_index += 1 # 
pyright: ignore + self._step_index += 1 if not return_dict: return (prev_sample,) From 45e1680c6fb67a76bb25813899c243c749458323 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 9 May 2025 18:42:43 +0300 Subject: [PATCH 011/264] fix fn name --- .../schedulers/scheduling_flow_match_unipc_multistep.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py b/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py index 6b78a2b5785c..6275975e64cc 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py @@ -189,7 +189,7 @@ def set_timesteps( sigmas = np.linspace(self.sigma_max, self.sigma_min, num_inference_steps + 1).copy()[:-1] if self.config.use_dynamic_shifting: - sigmas = self.time_shift(mu, 1.0, sigmas) + sigmas = self._time_shift_exponential(mu, 1.0, sigmas) else: if shift is None: shift = self.config.shift @@ -266,8 +266,8 @@ def _sigma_to_t(self, sigma): def _sigma_to_alpha_sigma_t(self, sigma): return 1 - sigma, sigma - # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps - def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential + def _time_shift_exponential(self, mu: float, sigma: float, t: torch.Tensor): return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) def convert_model_output( From c8a0c14b349933beb25d3b13fca4248473f845f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 9 May 2025 19:35:31 +0300 Subject: [PATCH 012/264] update import structure for SkyReelsV2 --- src/diffusers/__init__.py | 8 ++++++++ src/diffusers/pipelines/__init__.py | 4 ++-- src/diffusers/pipelines/skyreels_v2/__init__.py | 9 ++++----- .../pipeline_skyreels_v2_diffusion_forcing.py | 2 +- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 9e172dc3e7a1..6d54f041b250 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -272,6 +272,7 @@ "FlowMatchEulerDiscreteScheduler", "FlowMatchHeunDiscreteScheduler", "FlowMatchLCMScheduler", + "FlowMatchUniPCMultistepScheduler", "HeunDiscreteScheduler", "IPNDMScheduler", "KarrasVeScheduler", @@ -444,6 +445,9 @@ "SemanticStableDiffusionPipeline", "ShapEImg2ImgPipeline", "ShapEPipeline", + "SkyreelsV2DiffusionForcingPipeline", + "SkyreelsV2ImageToVideoPipeline", + "SkyreelsV2Pipeline", "StableAudioPipeline", "StableAudioProjectionModel", "StableCascadeCombinedPipeline", @@ -872,6 +876,7 @@ FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler, FlowMatchLCMScheduler, + FlowMatchUniPCMultistepScheduler, HeunDiscreteScheduler, IPNDMScheduler, KarrasVeScheduler, @@ -1025,6 +1030,9 @@ SemanticStableDiffusionPipeline, ShapEImg2ImgPipeline, ShapEPipeline, + SkyreelsV2DiffusionForcingPipeline, + SkyreelsV2ImageToVideoPipeline, + SkyreelsV2Pipeline, StableAudioPipeline, StableAudioProjectionModel, StableCascadeCombinedPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 2a9997f26a6e..9fcc3581f194 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -366,7 +366,7 @@ [ "SkyReelsV2DiffusionForcingPipeline", "SkyReelsV2ImageToVideoPipeline", - "SkyReelsV2TextToVideoPipeline", + 
"SkyReelsV2Pipeline", ] ) try: @@ -834,7 +834,7 @@ from .skyreels_v2 import ( SkyReelsV2DiffusionForcingPipeline, SkyReelsV2ImageToVideoPipeline, - SkyReelsV2TextToVideoPipeline, + SkyReelsV2Pipeline, ) else: diff --git a/src/diffusers/pipelines/skyreels_v2/__init__.py b/src/diffusers/pipelines/skyreels_v2/__init__.py index 9413b0b3339a..2a9d02c7ffcc 100644 --- a/src/diffusers/pipelines/skyreels_v2/__init__.py +++ b/src/diffusers/pipelines/skyreels_v2/__init__.py @@ -13,6 +13,7 @@ _dummy_objects = {} _import_structure = {} + try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() @@ -21,21 +22,20 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: + _import_structure["pipeline_skyreels_v2"] = ["SkyReelsV2Pipeline"] _import_structure["pipeline_skyreels_v2_diffusion_forcing"] = ["SkyReelsV2DiffusionForcingPipeline"] _import_structure["pipeline_skyreels_v2_image_to_video"] = ["SkyReelsV2ImageToVideoPipeline"] - _import_structure["pipeline_skyreels_v2_text_to_video"] = ["SkyReelsV2TextToVideoPipeline"] - if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + from ...utils.dummy_torch_and_transformers_objects import * else: + from .pipeline_skyreels_v2 import SkyReelsV2Pipeline from .pipeline_skyreels_v2_diffusion_forcing import SkyReelsV2DiffusionForcingPipeline from .pipeline_skyreels_v2_image_to_video import SkyReelsV2ImageToVideoPipeline - from .pipeline_skyreels_v2_text_to_video import SkyReelsV2TextToVideoPipeline else: import sys @@ -49,4 +49,3 @@ for name, value in _dummy_objects.items(): setattr(sys.modules[__name__], name, value) - del _dummy_objects diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 0576aeb9b979..87b8ac53ffe5 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -60,7 +60,7 @@ ... ) >>> pipe = pipe.to("cuda") - >>> ... 
+ >>> # TODO ``` """ From 47306b6ae4f2d09aac8ad63c67c3df0f567ef563 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 9 May 2025 19:36:04 +0300 Subject: [PATCH 013/264] add SkyreelsV2 pipeline classes with backend requirements --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++ .../dummy_torch_and_transformers_objects.py | 45 +++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 97bc3f317b32..0eb4a8df4131 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1823,6 +1823,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class FlowMatchUniPCMultistepScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class HeunDiscreteScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index c36c758d8752..da4bd9d640a9 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1667,6 +1667,51 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class SkyreelsV2DiffusionForcingPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class SkyreelsV2ImageToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class SkyreelsV2Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableAudioPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From c5b8da9473e7a2f8ea676fa963c8742d5ccac36f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 10 May 2025 10:27:03 +0300 Subject: [PATCH 014/264] up --- .../schedulers/scheduling_flow_match_unipc_multistep.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py b/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py index 6275975e64cc..ff0f941c8acf 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py +++ 
b/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py @@ -1,6 +1,4 @@ # Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved. -# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py -# Converted unipc for flow matching # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,7 +25,7 @@ class FlowMatchUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): """ - `FlowMatchUniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + `FlowMatchUniPCMultistepScheduler` is a ... This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving. @@ -315,7 +313,7 @@ def convert_model_output( else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," - " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + " `v_prediction` or `flow_prediction` for the FlowMatchUniPCMultistepScheduler." ) if self.config.thresholding: @@ -329,7 +327,7 @@ def convert_model_output( else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," - " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + " `v_prediction` or `flow_prediction` for the FlowMatchUniPCMultistepScheduler." ) if self.config.thresholding: From 5835eaa847b54c8200dd30d25b31db783cfda604 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 10 May 2025 15:07:49 +0300 Subject: [PATCH 015/264] up --- .../pipeline_skyreels_v2_diffusion_forcing.py | 35 +++++-------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 87b8ac53ffe5..919776178b6b 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -353,7 +353,6 @@ def prepare_latents( device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, - last_image: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 latent_height = height // self.vae_scale_factor_spatial @@ -372,16 +371,10 @@ def prepare_latents( latents = latents.to(device=device, dtype=dtype) image = image.unsqueeze(2) - if last_image is None: - video_condition = torch.cat( - [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 - ) - else: - last_image = last_image.unsqueeze(2) - video_condition = torch.cat( - [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], - dim=2, - ) + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + ) + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) latents_mean = ( @@ -407,10 +400,7 @@ def prepare_latents( mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) - if 
last_image is None: - mask_lat_size[:, :, list(range(1, num_frames))] = 0 - else: - mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 + mask_lat_size[:, :, list(range(1, num_frames))] = 0 first_frame_mask = mask_lat_size[:, :, 0:1] first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) @@ -462,7 +452,6 @@ def __call__( prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, image_embeds: Optional[torch.Tensor] = None, - last_image: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -614,24 +603,19 @@ def __call__( negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) if image_embeds is None: - if last_image is None: - image_embeds = self.encode_image(image, device) - else: - image_embeds = self.encode_image([image, last_image], device) + image_embeds = self.encode_image(image, device) + image_embeds = image_embeds.repeat(batch_size, 1, 1) image_embeds = image_embeds.to(transformer_dtype) # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.vae.config.z_dim image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) - if last_image is not None: - last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( - device, dtype=torch.float32 - ) + latents, condition = self.prepare_latents( image, batch_size * num_videos_per_prompt, @@ -643,7 +627,6 @@ def __call__( device, generator, latents, - last_image, ) # 6. Denoising loop From 9d2880e1604e6d44394942af42248450420b0a90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 10 May 2025 15:08:22 +0300 Subject: [PATCH 016/264] add draft transformer_skyreels_v2.py with a custom WanModel and attention mechanisms --- .../transformers/transformer_skyreels_v2.py | 866 ++++++++++++++++++ 1 file changed, 866 insertions(+) create mode 100644 src/diffusers/models/transformers/transformer_skyreels_v2.py diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py new file mode 100644 index 000000000000..190a70537afb --- /dev/null +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -0,0 +1,866 @@ +# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math + +import numpy as np +import torch +import torch.amp as amp +import torch.nn as nn +from torch.backends.cuda import sdp_kernel +from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin +from diffusers.models.modeling_utils import ModelMixin + +from .attention import flash_attention + + +flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune") + +DISABLE_COMPILE = False # get os env + +__all__ = ["WanModel"] + + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position.type(torch.float64) + + # calculation + sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + + +@amp.autocast("cuda", enabled=False) +def rope_params(max_seq_len, dim, theta=10000): + assert dim % 2 == 0 + freqs = torch.outer( + torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)) + ) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + +@amp.autocast("cuda", enabled=False) +def rope_apply(x, grid_sizes, freqs): + n, c = x.size(2), x.size(3) // 2 + bs = x.size(0) + + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + f, h, w = grid_sizes.tolist() + seq_len = f * h * w + + # precompute multipliers + + x = torch.view_as_complex(x.to(torch.float32).reshape(bs, seq_len, n, -1, 2)) + freqs_i = torch.cat( + [ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, -1) + + # apply rotary embedding + x = torch.view_as_real(x * freqs_i).flatten(3) + + return x + + +@torch.compile(dynamic=True, disable=DISABLE_COMPILE) +def fast_rms_norm(x, weight, eps): + x = x.float() + x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) + x = x.type_as(x) * weight + return x + + +class WanRMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return fast_rms_norm(x, self.weight, self.eps) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class WanLayerNorm(nn.LayerNorm): + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return super().forward(x) + + +class WanSelfAttention(nn.Module): + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + self._flag_ar_attention = False + + def 
set_ar_attention(self): + self._flag_ar_attention = True + + def forward(self, x, grid_sizes, freqs, block_mask): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + seq_lens(Tensor): Shape [B] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + x = x.to(self.q.weight.dtype) + q, k, v = qkv_fn(x) + + if not self._flag_ar_attention: + q = rope_apply(q, grid_sizes, freqs) + k = rope_apply(k, grid_sizes, freqs) + x = flash_attention(q=q, k=k, v=v, window_size=self.window_size) + else: + q = rope_apply(q, grid_sizes, freqs) + k = rope_apply(k, grid_sizes, freqs) + q = q.to(torch.bfloat16) + k = k.to(torch.bfloat16) + v = v.to(torch.bfloat16) + + with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + x = ( + torch.nn.functional.scaled_dot_product_attention( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask + ) + .transpose(1, 2) + .contiguous() + ) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanT2VCrossAttention(WanSelfAttention): + def forward(self, x, context): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + + # compute attention + x = flash_attention(q, k, v) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanI2VCrossAttention(WanSelfAttention): + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): + super().__init__(dim, num_heads, window_size, qk_norm, eps) + + self.k_img = nn.Linear(dim, dim) + self.v_img = nn.Linear(dim, dim) + # self.alpha = nn.Parameter(torch.zeros((1, ))) + self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, context): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + context_img = context[:, :257] + context = context[:, 257:] + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d) + v_img = self.v_img(context_img).view(b, -1, n, d) + img_x = flash_attention(q, k_img, v_img) + # compute attention + x = flash_attention(q, k, v) + + # output + x = x.flatten(2) + img_x = img_x.flatten(2) + x = x + img_x + x = self.o(x) + return x + + +WAN_CROSSATTENTION_CLASSES = { + "t2v_cross_attn": WanT2VCrossAttention, + "i2v_cross_attn": WanI2VCrossAttention, +} + + +def mul_add(x, y, z): + return x.float() + y.float() * z.float() + + +def mul_add_add(x, y, z): + return x.float() * (1 + y) + z + + +mul_add_compile = torch.compile(mul_add, dynamic=True, disable=DISABLE_COMPILE) +mul_add_add_compile = torch.compile(mul_add_add, dynamic=True, disable=DISABLE_COMPILE) + + +class WanAttentionBlock(nn.Module): + def __init__( + self, + cross_attn_type, + 
dim, + ffn_dim, + num_heads, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=False, + eps=1e-6, + ): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps) + self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps) + self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, num_heads, (-1, -1), qk_norm, eps) + self.norm2 = WanLayerNorm(dim, eps) + self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def set_ar_attention(self): + self.self_attn.set_ar_attention() + + def forward( + self, + x, + e, + grid_sizes, + freqs, + context, + block_mask, + ): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, 6, C] + seq_lens(Tensor): Shape [B], length of each sequence in batch + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + if e.dim() == 3: + modulation = self.modulation # 1, 6, dim + with amp.autocast("cuda", dtype=torch.float32): + e = (modulation + e).chunk(6, dim=1) + elif e.dim() == 4: + modulation = self.modulation.unsqueeze(2) # 1, 6, 1, dim + with amp.autocast("cuda", dtype=torch.float32): + e = (modulation + e).chunk(6, dim=1) + e = [ei.squeeze(1) for ei in e] + + # self-attention + out = mul_add_add_compile(self.norm1(x), e[1], e[0]) + y = self.self_attn(out, grid_sizes, freqs, block_mask) + with amp.autocast("cuda", dtype=torch.float32): + x = mul_add_compile(x, y, e[2]) + + # cross-attention & ffn function + def cross_attn_ffn(x, context, e): + dtype = context.dtype + x = x + self.cross_attn(self.norm3(x.to(dtype)), context) + y = self.ffn(mul_add_add_compile(self.norm2(x), e[4], e[3]).to(dtype)) + with amp.autocast("cuda", dtype=torch.float32): + x = mul_add_compile(x, y, e[5]) + return x + + x = cross_attn_ffn(x, context, e) + return x.to(torch.bfloat16) + + +class Head(nn.Module): + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, C] + """ + with amp.autocast("cuda", dtype=torch.float32): + if e.dim() == 2: + modulation = self.modulation # 1, 2, dim + e = (modulation + e.unsqueeze(1)).chunk(2, dim=1) + + elif e.dim() == 3: + modulation = self.modulation.unsqueeze(2) # 1, 2, seq, dim + e = (modulation + e.unsqueeze(1)).chunk(2, dim=1) + e = [ei.squeeze(1) for ei in e] + x = self.head(self.norm(x) * (1 + e[1]) + e[0]) + return x + + +class MLPProj(torch.nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.LayerNorm(in_dim), + torch.nn.Linear(in_dim, in_dim), + torch.nn.GELU(), + torch.nn.Linear(in_dim, out_dim), + torch.nn.LayerNorm(out_dim), + ) + + def forward(self, image_embeds): + clip_extra_context_tokens = 
self.proj(image_embeds) + return clip_extra_context_tokens + + +class WanModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"] + _no_split_modules = ["WanAttentionBlock"] + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + model_type="t2v", + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + inject_sample_info=False, + eps=1e-6, + ): + r""" + Initialize the diffusion model backbone. + + Args: + model_type (`str`, *optional*, defaults to 't2v'): + Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) + patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + text_len (`int`, *optional*, defaults to 512): + Fixed length for text embeddings + in_dim (`int`, *optional*, defaults to 16): + Input video channels (C_in) + dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the transformer + ffn_dim (`int`, *optional*, defaults to 8192): + Intermediate dimension in feed-forward network + freq_dim (`int`, *optional*, defaults to 256): + Dimension for sinusoidal time embeddings + text_dim (`int`, *optional*, defaults to 4096): + Input dimension for text embeddings + out_dim (`int`, *optional*, defaults to 16): + Output video channels (C_out) + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads + num_layers (`int`, *optional*, defaults to 32): + Number of transformer blocks + window_size (`tuple`, *optional*, defaults to (-1, -1)): + Window size for local attention (-1 indicates global attention) + qk_norm (`bool`, *optional*, defaults to True): + Enable query/key normalization + cross_attn_norm (`bool`, *optional*, defaults to False): + Enable cross-attention normalization + eps (`float`, *optional*, defaults to 1e-6): + Epsilon value for normalization layers + """ + + super().__init__() + + assert model_type in ["t2v", "i2v"] + self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + self.num_frame_per_block = 1 + self.flag_causal_attention = False + self.block_mask = None + self.enable_teacache = False + + # embeddings + self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)) + + self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + if inject_sample_info: + self.fps_embedding = nn.Embedding(2, dim) + self.fps_projection = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn" + self.blocks = nn.ModuleList( + [ + WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, 
window_size, qk_norm, cross_attn_norm, eps) + for _ in range(num_layers) + ] + ) + + # head + self.head = Head(dim, out_dim, patch_size, eps) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = torch.cat( + [rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))], + dim=1, + ) + + if model_type == "i2v": + self.img_emb = MLPProj(1280, dim) + + self.gradient_checkpointing = False + + self.cpu_offloading = False + + self.inject_sample_info = inject_sample_info + # initialize weights + self.init_weights() + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def zero_init_i2v_cross_attn(self): + print("zero init i2v cross attn") + for i in range(self.num_layers): + self.blocks[i].cross_attn.v_img.weight.data.zero_() + self.blocks[i].cross_attn.v_img.bias.data.zero_() + + @staticmethod + def _prepare_blockwise_causal_attn_mask( + device: torch.device | str, num_frames: int = 21, frame_seqlen: int = 1560, num_frame_per_block=1 + ) -> BlockMask: + """ + we will divide the token sequence into the following format [1 latent frame] [1 latent frame] ... [1 latent + frame] We use flexattention to construct the attention mask + """ + total_length = num_frames * frame_seqlen + + # we do right padding to get to a multiple of 128 + padded_length = math.ceil(total_length / 128) * 128 - total_length + + ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long) + + # Block-wise causal mask will attend to all elements that are before the end of the current chunk + frame_indices = torch.arange(start=0, end=total_length, step=frame_seqlen * num_frame_per_block, device=device) + + for tmp in frame_indices: + ends[tmp : tmp + frame_seqlen * num_frame_per_block] = tmp + frame_seqlen * num_frame_per_block + + def attention_mask(b, h, q_idx, kv_idx): + return (kv_idx < ends[q_idx]) | (q_idx == kv_idx) + # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask + + block_mask = create_block_mask( + attention_mask, + B=None, + H=None, + Q_LEN=total_length + padded_length, + KV_LEN=total_length + padded_length, + _compile=False, + device=device, + ) + + return block_mask + + def initialize_teacache( + self, enable_teacache=True, num_steps=25, teacache_thresh=0.15, use_ret_steps=False, ckpt_dir="" + ): + self.enable_teacache = enable_teacache + print("using teacache") + self.cnt = 0 + self.num_steps = num_steps + self.teacache_thresh = teacache_thresh + self.accumulated_rel_l1_distance_even = 0 + self.accumulated_rel_l1_distance_odd = 0 + self.previous_e0_even = None + self.previous_e0_odd = None + self.previous_residual_even = None + self.previous_residual_odd = None + self.use_ref_steps = use_ret_steps + if "I2V" in ckpt_dir: + if use_ret_steps: + if "540P" in ckpt_dir: + self.coefficients = [2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01] + if "720P" in ckpt_dir: + self.coefficients = [8.10705460e03, 2.13393892e03, -3.72934672e02, 1.66203073e01, -4.17769401e-02] + self.ret_steps = 5 * 2 + self.cutoff_steps = num_steps * 2 + else: + if "540P" in ckpt_dir: + self.coefficients = [-3.02331670e02, 2.23948934e02, -5.25463970e01, 5.87348440e00, -2.01973289e-01] + if "720P" in ckpt_dir: + self.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] + self.ret_steps = 1 * 2 + 
self.cutoff_steps = num_steps * 2 - 2 + else: + if use_ret_steps: + if "1.3B" in ckpt_dir: + self.coefficients = [-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02] + if "14B" in ckpt_dir: + self.coefficients = [-3.03318725e05, 4.90537029e04, -2.65530556e03, 5.87365115e01, -3.15583525e-01] + self.ret_steps = 5 * 2 + self.cutoff_steps = num_steps * 2 + else: + if "1.3B" in ckpt_dir: + self.coefficients = [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01] + if "14B" in ckpt_dir: + self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] + self.ret_steps = 1 * 2 + self.cutoff_steps = num_steps * 2 - 2 + + def forward(self, x, t, context, clip_fea=None, y=None, fps=None): + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + clip_fea (Tensor, *optional*): + CLIP image features for image-to-video mode + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + if self.model_type == "i2v": + assert clip_fea is not None and y is not None + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = torch.cat([x, y], dim=1) + + # embeddings + x = self.patch_embedding(x) + grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long) + x = x.flatten(2).transpose(1, 2) + + if self.flag_causal_attention: + frame_num = grid_sizes[0] + height = grid_sizes[1] + width = grid_sizes[2] + block_num = frame_num // self.num_frame_per_block + range_tensor = torch.arange(block_num).view(-1, 1) + range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten() + casual_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f + casual_mask = casual_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device) + casual_mask = casual_mask.repeat(1, height, width, 1, height, width) + casual_mask = casual_mask.reshape(frame_num * height * width, frame_num * height * width) + self.block_mask = casual_mask.unsqueeze(0).unsqueeze(0) + + # time embeddings + with amp.autocast("cuda", dtype=torch.float32): + if t.dim() == 2: + b, f = t.shape + _flag_df = True + else: + _flag_df = False + + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) + ) # b, dim + e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim + + if self.inject_sample_info: + fps = torch.tensor(fps, dtype=torch.long, device=device) + + fps_emb = self.fps_embedding(fps).float() + if _flag_df: + e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1) + else: + e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)) + + if _flag_df: + e = e.view(b, f, 1, 1, self.dim) + e0 = e0.view(b, f, 1, 1, 6, self.dim) + e = e.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3) + e0 = e0.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3) + e0 = e0.transpose(1, 2).contiguous() + + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context = 
self.text_embedding(context) + + if clip_fea is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.concat([context_clip, context], dim=1) + + # arguments + kwargs = { + "e": e0, + "grid_sizes": grid_sizes, + "freqs": self.freqs, + "context": context, + "block_mask": self.block_mask, + } + if self.enable_teacache: + modulated_inp = e0 if self.use_ref_steps else e + # teacache + if self.cnt % 2 == 0: # even -> condition + self.is_even = True + if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: + should_calc_even = True + self.accumulated_rel_l1_distance_even = 0 + else: + rescale_func = np.poly1d(self.coefficients) + self.accumulated_rel_l1_distance_even += rescale_func( + ((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()) + .cpu() + .item() + ) + if self.accumulated_rel_l1_distance_even < self.teacache_thresh: + should_calc_even = False + else: + should_calc_even = True + self.accumulated_rel_l1_distance_even = 0 + self.previous_e0_even = modulated_inp.clone() + + else: # odd -> unconditon + self.is_even = False + if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: + should_calc_odd = True + self.accumulated_rel_l1_distance_odd = 0 + else: + rescale_func = np.poly1d(self.coefficients) + self.accumulated_rel_l1_distance_odd += rescale_func( + ((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()) + .cpu() + .item() + ) + if self.accumulated_rel_l1_distance_odd < self.teacache_thresh: + should_calc_odd = False + else: + should_calc_odd = True + self.accumulated_rel_l1_distance_odd = 0 + self.previous_e0_odd = modulated_inp.clone() + + if self.enable_teacache: + if self.is_even: + if not should_calc_even: + x += self.previous_residual_even + else: + ori_x = x.clone() + for block in self.blocks: + x = block(x, **kwargs) + self.previous_residual_even = x - ori_x + else: + if not should_calc_odd: + x += self.previous_residual_odd + else: + ori_x = x.clone() + for block in self.blocks: + x = block(x, **kwargs) + self.previous_residual_odd = x - ori_x + + self.cnt += 1 + if self.cnt >= self.num_steps: + self.cnt = 0 + else: + for block in self.blocks: + x = block(x, **kwargs) + + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + + return x.float() + + def unpatchify(self, x, grid_sizes): + r""" + Reconstruct video tensors from patch embeddings. + + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + + c = self.out_dim + bs = x.shape[0] + x = x.view(bs, *grid_sizes, *self.patch_size, c) + x = torch.einsum("bfhwpqrc->bcfphqwr", x) + x = x.reshape(bs, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)]) + + return x + + def set_ar_attention(self, causal_block_size): + self.num_frame_per_block = causal_block_size + self.flag_causal_attention = True + for block in self.blocks: + block.set_ar_attention() + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization. 
+ """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + + if self.inject_sample_info: + nn.init.normal_(self.fps_embedding.weight, std=0.02) + + for m in self.fps_projection.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + + nn.init.zeros_(self.fps_projection[-1].weight) + nn.init.zeros_(self.fps_projection[-1].bias) + + # init output layer + nn.init.zeros_(self.head.head.weight) From 2c0586e319b6005c913a846ed3b47451767a21d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 10 May 2025 15:55:03 +0300 Subject: [PATCH 017/264] up --- src/diffusers/__init__.py | 2 ++ src/diffusers/models/__init__.py | 2 ++ src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_skyreels_v2.py | 30 +++++++++++-------- .../pipeline_skyreels_v2_diffusion_forcing.py | 6 ++-- 5 files changed, 26 insertions(+), 15 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 6d54f041b250..1f15f10883fc 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -215,6 +215,7 @@ "UVit2DModel", "VQModel", "WanTransformer3DModel", + "SkyReelsV2Transformer3DModel", ] ) _import_structure["optimization"] = [ @@ -804,6 +805,7 @@ SD3ControlNetModel, SD3MultiControlNetModel, SD3Transformer2DModel, + SkyReelsV2Transformer3DModel, SparseControlNetModel, StableAudioDiTModel, T2IAdapter, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 58322800332a..c691076beee6 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -89,6 +89,7 @@ _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"] + _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] @@ -173,6 +174,7 @@ PriorTransformer, SanaTransformer2DModel, SD3Transformer2DModel, + SkyReelsV2Transformer3DModel, StableAudioDiTModel, T5FilmDecoder, Transformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 86094104bd1c..c90b8e0ecb95 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -30,5 +30,6 @@ from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel from .transformer_sd3 import SD3Transformer2DModel + from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel from .transformer_temporal import TransformerTemporalModel from .transformer_wan import WanTransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 190a70537afb..d1aeb4218bec 100644 --- 
a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -13,20 +13,25 @@ # limitations under the License. import math +from typing import Any, Dict, Optional, Tuple, Union -import numpy as np import torch -import torch.amp as amp import torch.nn as nn -from torch.backends.cuda import sdp_kernel -from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention +import torch.nn.functional as F -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.loaders import PeftAdapterMixin -from diffusers.models.modeling_utils import ModelMixin +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..attention import FeedForward +from ..attention_processor import Attention +from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import FP32LayerNorm -from .attention import flash_attention +logger = logging.get_logger(__name__) # pylint: disable=invalid-name flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune") @@ -408,15 +413,16 @@ def forward(self, image_embeds): return clip_extra_context_tokens -class WanModel(ModelMixin, ConfigMixin, PeftAdapterMixin): +class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" Wan diffusion backbone supporting both text-to-video and image-to-video. """ - ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"] - _no_split_modules = ["WanAttentionBlock"] - _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] + _no_split_modules = ["WanTransformerBlock"] + _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] + _keys_to_ignore_on_load_unexpected = ["norm_added_q"] @register_to_config def __init__( diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 919776178b6b..a5e12df55a10 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -24,7 +24,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput from ...loaders import WanLoraLoaderMixin -from ...models import AutoencoderKLWan, WanTransformer3DModel +from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel from ...schedulers import FlowMatchUniPCMultistepScheduler from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor @@ -116,7 +116,7 @@ class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): the [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) variant. 
- transformer ([`WanTransformer3DModel`]): + transformer ([`SkyReelsV2Transformer3DModel`]): Conditional Transformer to denoise the encoded image latents. scheduler ([`FlowMatchUniPCMultistepScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. @@ -133,7 +133,7 @@ def __init__( text_encoder: UMT5EncoderModel, image_encoder: CLIPVisionModel, image_processor: CLIPImageProcessor, - transformer: WanTransformer3DModel, + transformer: SkyReelsV2Transformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMatchUniPCMultistepScheduler, ): From 52590ea2ace7eec8a2bc28c6259346431850a8c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 10 May 2025 16:06:27 +0300 Subject: [PATCH 018/264] split i2v and t2v pipes for diffusion forcing --- src/diffusers/__init__.py | 4 +- src/diffusers/models/__init__.py | 2 +- .../transformers/transformer_skyreels_v2.py | 16 +- src/diffusers/pipelines/__init__.py | 2 + .../pipelines/skyreels_v2/__init__.py | 8 +- .../pipeline_skyreels_v2_diffusion_forcing.py | 2 +- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 709 ++++++++++++++++++ ...o_video.py => pipeline_skyreels_v2_i2v.py} | 0 src/diffusers/utils/dummy_pt_objects.py | 15 + .../dummy_torch_and_transformers_objects.py | 15 + 10 files changed, 758 insertions(+), 15 deletions(-) create mode 100644 src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py rename src/diffusers/pipelines/skyreels_v2/{pipeline_skyreels_v2_image_to_video.py => pipeline_skyreels_v2_i2v.py} (100%) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1f15f10883fc..686c34e1daad 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -198,6 +198,7 @@ "SD3ControlNetModel", "SD3MultiControlNetModel", "SD3Transformer2DModel", + "SkyReelsV2Transformer3DModel", "SparseControlNetModel", "StableAudioDiTModel", "StableCascadeUNet", @@ -215,7 +216,6 @@ "UVit2DModel", "VQModel", "WanTransformer3DModel", - "SkyReelsV2Transformer3DModel", ] ) _import_structure["optimization"] = [ @@ -446,6 +446,7 @@ "SemanticStableDiffusionPipeline", "ShapEImg2ImgPipeline", "ShapEPipeline", + "SkyreelsV2DiffusionForcingImageToVideoPipeline", "SkyreelsV2DiffusionForcingPipeline", "SkyreelsV2ImageToVideoPipeline", "SkyreelsV2Pipeline", @@ -1032,6 +1033,7 @@ SemanticStableDiffusionPipeline, ShapEImg2ImgPipeline, ShapEPipeline, + SkyreelsV2DiffusionForcingImageToVideoPipeline, SkyreelsV2DiffusionForcingPipeline, SkyreelsV2ImageToVideoPipeline, SkyreelsV2Pipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index c691076beee6..c59d01200cc4 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -87,9 +87,9 @@ _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] + _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"] - _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] 
diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index d1aeb4218bec..b11d64af65a6 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -13,22 +13,20 @@ # limitations under the License. import math -from typing import Any, Dict, Optional, Tuple, Union +import numpy as np import torch +import torch.amp as amp import torch.nn as nn -import torch.nn.functional as F +from torch.backends.cuda import sdp_kernel +from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -from ..attention import FeedForward -from ..attention_processor import Attention +from ...utils import logging from ..cache_utils import CacheMixin -from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed -from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import FP32LayerNorm +from .attention import flash_attention logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -37,8 +35,6 @@ DISABLE_COMPILE = False # get os env -__all__ = ["WanModel"] - def sinusoidal_embedding_1d(dim, position): # preprocess diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 9fcc3581f194..79d14f1ee2b4 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -365,6 +365,7 @@ _import_structure["skyreels_v2"].extend( [ "SkyReelsV2DiffusionForcingPipeline", + "SkyReelsV2DiffusionForcingImageToVideoPipeline", "SkyReelsV2ImageToVideoPipeline", "SkyReelsV2Pipeline", ] @@ -832,6 +833,7 @@ ) from .skyreels_v2 import ( + SkyReelsV2DiffusionForcingImageToVideoPipeline, SkyReelsV2DiffusionForcingPipeline, SkyReelsV2ImageToVideoPipeline, SkyReelsV2Pipeline, diff --git a/src/diffusers/pipelines/skyreels_v2/__init__.py b/src/diffusers/pipelines/skyreels_v2/__init__.py index 2a9d02c7ffcc..602f034ca1e2 100644 --- a/src/diffusers/pipelines/skyreels_v2/__init__.py +++ b/src/diffusers/pipelines/skyreels_v2/__init__.py @@ -24,7 +24,10 @@ else: _import_structure["pipeline_skyreels_v2"] = ["SkyReelsV2Pipeline"] _import_structure["pipeline_skyreels_v2_diffusion_forcing"] = ["SkyReelsV2DiffusionForcingPipeline"] - _import_structure["pipeline_skyreels_v2_image_to_video"] = ["SkyReelsV2ImageToVideoPipeline"] + _import_structure["pipeline_skyreels_v2_diffusion_forcing_i2v"] = [ + "SkyReelsV2DiffusionForcingImageToVideoPipeline" + ] + _import_structure["pipeline_skyreels_v2_i2v"] = ["SkyReelsV2ImageToVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -35,7 +38,8 @@ else: from .pipeline_skyreels_v2 import SkyReelsV2Pipeline from .pipeline_skyreels_v2_diffusion_forcing import SkyReelsV2DiffusionForcingPipeline - from .pipeline_skyreels_v2_image_to_video import SkyReelsV2ImageToVideoPipeline + from .pipeline_skyreels_v2_diffusion_forcing_i2v import SkyReelsV2DiffusionForcingImageToVideoPipeline + from .pipeline_skyreels_v2_i2v import SkyReelsV2ImageToVideoPipeline else: import sys diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py 
b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index a5e12df55a10..6175ab918882 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -517,7 +517,7 @@ def __call__( output_type (`str`, *optional*, defaults to `"np"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`SkyReelsV2DiffusionForcingPipelineOutput`] instead of a plain tuple. + Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple. attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py new file mode 100644 index 000000000000..094c166eda98 --- /dev/null +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -0,0 +1,709 @@ +# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import re +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import ftfy +import PIL.Image +import torch +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import WanLoraLoaderMixin +from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel +from ...schedulers import FlowMatchUniPCMultistepScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import SkyReelsV2PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """\ + Examples: + ```py + >>> import torch + >>> import PIL.Image + >>> from diffusers import SkyReelsV2DiffusionForcingImageToVideoPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> # Load the pipeline + >>> pipe = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained( + ... "HF_placeholder/SkyReels-V2-DF-1.3B-540P", torch_dtype=torch.float16 + ... 
) + >>> pipe = pipe.to("cuda") + + >>> # TODO + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class SkyReelsV2DiffusionForcingImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): + """ + Pipeline for video generation with diffusion forcing using SkyReels-V2. This pipeline supports two main tasks: + Text-to-Video (t2v) and Image-to-Video (i2v) + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a specific device, etc.). + + Args: + tokenizer ([`AutoTokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + text_encoder ([`UMT5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + image_encoder ([`CLIPVisionModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically + the + [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) + variant. + transformer ([`SkyReelsV2Transformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchUniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
+ """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + image_encoder: CLIPVisionModel, + image_processor: CLIPImageProcessor, + transformer: SkyReelsV2Transformer3DModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchUniPCMultistepScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + image_encoder=image_encoder, + transformer=transformer, + scheduler=scheduler, + image_processor=image_processor, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.image_processor = image_processor + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_image( + self, + image: PipelineImageInput, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + image = self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." 
+ ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + def prepare_latents( + self, + image: PipelineImageInput, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + image = image.unsqueeze(2) + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + ) + + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + ] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + + latent_condition = latent_condition.to(dtype) + latent_condition = (latent_condition - latents_mean) * latents_std + + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + + mask_lat_size[:, :, list(range(1, num_frames))] = 0 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(latent_condition.device) + + return latents, torch.concat([mask_lat_size, latent_condition], dim=1) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + height: int = 480, + width: int = 832, + num_frames: int = 97, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + overlap_history: Optional[int] = None, + shift: float = 1.0, + addnoise_condition: float = 0.0, + base_num_frames: int = 97, + ar_step: int = 5, + causal_block_size: Optional[int] = None, + fps: int = 24, + ): + r""" + The 
call function to the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, defaults to `480`): + The height of the generated video. + width (`int`, defaults to `832`): + The width of the generated video. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `negative_prompt` input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, + image embeddings are generated from the `image` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `512`): + The maximum sequence length of the prompt. + shift (`float`, *optional*, defaults to `5.0`): + The shift of the flow. + autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): + The dtype to use for the torch.amp.autocast. + Examples: + + Returns: + [`~SkyReelsV2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + callback_on_step_end_tensor_inputs, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Encode image embedding + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + if image_embeds is None: + image_embeds = self.encode_image(image, device) + + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + + # 4. 
Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + + latents, condition = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return SkyReelsV2PipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py similarity index 100% rename from src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_image_to_video.py rename to src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 
0eb4a8df4131..d9cd19bbfe2c 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -910,6 +910,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class SkyReelsV2Transformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class SparseControlNetModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index da4bd9d640a9..2b5f3ba86dd2 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1667,6 +1667,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class SkyreelsV2DiffusionForcingImageToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class SkyreelsV2DiffusionForcingPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From c50fcad950a699c693112e849dbd29ef2f2892b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 10 May 2025 18:05:27 +0300 Subject: [PATCH 019/264] Refactors SkyReelsV2 attention and normalizations Replaces custom attention implementations with `SkyReelsV2AttnProcessor2_0` and the standard `Attention` module. Updates `WanAttentionBlock` to use `FP32LayerNorm` and `FeedForward`. Removes the `model_type` parameter, simplifying model architecture and attention block initialization. --- .../transformers/transformer_skyreels_v2.py | 106 ++++++++++-------- 1 file changed, 62 insertions(+), 44 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index b11d64af65a6..4aea78818836 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -13,19 +13,24 @@ # limitations under the License. 
import math +from typing import Optional import numpy as np import torch import torch.amp as amp import torch.nn as nn +import torch.nn.functional as F from torch.backends.cuda import sdp_kernel from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import logging +from ..attention import FeedForward +from ..attention_processor import Attention from ..cache_utils import CacheMixin from ..modeling_utils import ModelMixin +from ..normalization import FP32LayerNorm from .attention import flash_attention @@ -126,24 +131,12 @@ def forward(self, x): return super().forward(x) -class WanSelfAttention(nn.Module): - def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): - assert dim % num_heads == 0 - super().__init__() - self.dim = dim - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.window_size = window_size - self.qk_norm = qk_norm - self.eps = eps - - # layers - self.q = nn.Linear(dim, dim) - self.k = nn.Linear(dim, dim) - self.v = nn.Linear(dim, dim) - self.o = nn.Linear(dim, dim) - self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() - self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() +class SkyReelsV2AttnProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "SkyReelsV2AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0." + ) self._flag_ar_attention = False @@ -196,7 +189,7 @@ def qkv_fn(x): return x -class WanT2VCrossAttention(WanSelfAttention): +class WanT2VCrossAttention(SkyReelsV2AttnProcessor2_0): def forward(self, x, context): r""" Args: @@ -220,7 +213,7 @@ def forward(self, x, context): return x -class WanI2VCrossAttention(WanSelfAttention): +class WanI2VCrossAttention(SkyReelsV2AttnProcessor2_0): def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): super().__init__(dim, num_heads, window_size, qk_norm, eps) @@ -279,14 +272,14 @@ def mul_add_add(x, y, z): class WanAttentionBlock(nn.Module): def __init__( self, - cross_attn_type, - dim, - ffn_dim, - num_heads, + dim: int, + ffn_dim: int, + num_heads: int, window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=False, - eps=1e-6, + qk_norm: str = "rms_norm", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: Optional[int] = None, ): super().__init__() self.dim = dim @@ -297,16 +290,46 @@ def __init__( self.cross_attn_norm = cross_attn_norm self.eps = eps - # layers - self.norm1 = WanLayerNorm(dim, eps) - self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps) - self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() - self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim, num_heads, (-1, -1), qk_norm, eps) - self.norm2 = WanLayerNorm(dim, eps) - self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)) + # 1. 
Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.attn1 = Attention( + window_size=window_size, + query_dim=dim, + heads=num_heads, + kv_heads=num_heads, + dim_head=dim // num_heads, + qk_norm=qk_norm, + eps=eps, + bias=True, + cross_attention_dim=None, + out_bias=True, + processor=SkyReelsV2AttnProcessor2_0(), + ) - # modulation - self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + # 2. Cross-attention + self.attn2 = Attention( + window_size=(-1, -1), + query_dim=dim, + heads=num_heads, + kv_heads=num_heads, + dim_head=dim // num_heads, + qk_norm=qk_norm, + eps=eps, + bias=True, + cross_attention_dim=None, + out_bias=True, + added_kv_proj_dim=added_kv_proj_dim, + added_proj_bias=True, + pre_only=False, + processor=SkyReelsV2AttnProcessor2_0(), + ) + self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=False) + + # 3. Feed-forward + self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") + self.norm3 = FP32LayerNorm(dim, eps) if cross_attn_norm else nn.Identity() + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) def set_ar_attention(self): self.self_attn.set_ar_attention() @@ -423,7 +446,6 @@ class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr @register_to_config def __init__( self, - model_type="t2v", patch_size=(1, 2, 2), text_len=512, in_dim=16, @@ -478,9 +500,6 @@ def __init__( super().__init__() - assert model_type in ["t2v", "i2v"] - self.model_type = model_type - self.patch_size = patch_size self.text_len = text_len self.in_dim = in_dim @@ -512,10 +531,9 @@ def __init__( self.fps_projection = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 6)) # blocks - cross_attn_type = "t2v_cross_attn" if model_type == "t2v" else "i2v_cross_attn" self.blocks = nn.ModuleList( [ - WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) + WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) for _ in range(num_layers) ] ) @@ -531,8 +549,8 @@ def __init__( dim=1, ) - if model_type == "i2v": - self.img_emb = MLPProj(1280, dim) + # if model_type == "i2v": + # self.img_emb = MLPProj(1280, dim) self.gradient_checkpointing = False From d0c71fd87afddc16e5db35aadae91ce1fc49901d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 11 May 2025 17:23:49 +0300 Subject: [PATCH 020/264] Add SkyReelsV2 image and time-text embeddings Introduces new classes `SkyReelsV2ImageEmbedding` and `SkyReelsV2TimeTextImageEmbedding` for enhanced image and time-text processing. Refactors the `SkyReelsV2Transformer3DModel` to integrate these embeddings, updating the constructor parameters for better clarity and functionality. Removes unused classes and methods to streamline the codebase. --- .../transformers/transformer_skyreels_v2.py | 394 +++++++++--------- 1 file changed, 199 insertions(+), 195 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 4aea78818836..154de4443005 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -13,7 +13,7 @@ # limitations under the License. 
import math -from typing import Optional +from typing import Optional, Tuple import numpy as np import torch @@ -31,6 +31,7 @@ from ..cache_utils import CacheMixin from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from .attention import flash_attention @@ -41,28 +42,6 @@ DISABLE_COMPILE = False # get os env -def sinusoidal_embedding_1d(dim, position): - # preprocess - assert dim % 2 == 0 - half = dim // 2 - position = position.type(torch.float64) - - # calculation - sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half))) - x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) - return x - - -@amp.autocast("cuda", enabled=False) -def rope_params(max_seq_len, dim, theta=10000): - assert dim % 2 == 0 - freqs = torch.outer( - torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)) - ) - freqs = torch.polar(torch.ones_like(freqs), freqs) - return freqs - - @amp.autocast("cuda", enabled=False) def rope_apply(x, grid_sizes, freqs): n, c = x.size(2), x.size(3) // 2 @@ -93,44 +72,6 @@ def rope_apply(x, grid_sizes, freqs): return x -@torch.compile(dynamic=True, disable=DISABLE_COMPILE) -def fast_rms_norm(x, weight, eps): - x = x.float() - x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) - x = x.type_as(x) * weight - return x - - -class WanRMSNorm(nn.Module): - def __init__(self, dim, eps=1e-5): - super().__init__() - self.dim = dim - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def forward(self, x): - r""" - Args: - x(Tensor): Shape [B, L, C] - """ - return fast_rms_norm(x, self.weight, self.eps) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) - - -class WanLayerNorm(nn.LayerNorm): - def __init__(self, dim, eps=1e-6, elementwise_affine=False): - super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) - - def forward(self, x): - r""" - Args: - x(Tensor): Shape [B, L, C] - """ - return super().forward(x) - - class SkyReelsV2AttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): @@ -251,12 +192,6 @@ def forward(self, x, context): return x -WAN_CROSSATTENTION_CLASSES = { - "t2v_cross_attn": WanT2VCrossAttention, - "i2v_cross_attn": WanI2VCrossAttention, -} - - def mul_add(x, y, z): return x.float() + y.float() * z.float() @@ -268,8 +203,114 @@ def mul_add_add(x, y, z): mul_add_compile = torch.compile(mul_add, dynamic=True, disable=DISABLE_COMPILE) mul_add_add_compile = torch.compile(mul_add_add, dynamic=True, disable=DISABLE_COMPILE) +class SkyReelsV2ImageEmbedding(torch.nn.Module): + def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None): + super().__init__() + + self.norm1 = FP32LayerNorm(in_features) + self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu") + self.norm2 = FP32LayerNorm(out_features) + if pos_embed_seq_len is not None: + self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features)) + else: + self.pos_embed = None + + def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: + if self.pos_embed is not None: + batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape + encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim) + encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed + 
+ hidden_states = self.norm1(encoder_hidden_states_image) + hidden_states = self.ff(hidden_states) + hidden_states = self.norm2(hidden_states) + return hidden_states + +class SkyReelsV2TimeTextImageEmbedding(nn.Module): + def __init__( + self, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + image_embed_dim: Optional[int] = None, + pos_embed_seq_len: Optional[int] = None, + ): + super().__init__() + + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") + + self.image_embedder = None + if image_embed_dim is not None: + self.image_embedder = SkyReelsV2ImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len) + + def forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + ): + timestep = self.timesteps_proj(timestep) + + time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) + + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + +class SkyReelsV2RotaryPosEmbed(nn.Module): + def __init__( + self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0 + ): + super().__init__() + + self.attention_head_dim = attention_head_dim + self.patch_size = patch_size + self.max_seq_len = max_seq_len + + h_dim = w_dim = 2 * (attention_head_dim // 6) + t_dim = attention_head_dim - h_dim - w_dim + + freqs = [] + for dim in [t_dim, h_dim, w_dim]: + freq = get_1d_rotary_pos_embed( + dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64 + ) + freqs.append(freq) + self.freqs = torch.cat(freqs, dim=1) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w + + freqs = self.freqs.to(hidden_states.device) + freqs = freqs.split_with_sizes( + [ + self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), + self.attention_head_dim // 6, + self.attention_head_dim // 6, + ], + dim=1, + ) + + freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) + return freqs -class WanAttentionBlock(nn.Module): +class SkyReelsV2TransformerBlock(nn.Module): def __init__( self, dim: int, @@ -381,21 +422,6 @@ def cross_attn_ffn(x, context, e): class Head(nn.Module): - def __init__(self, dim, out_dim, patch_size, eps=1e-6): - super().__init__() - self.dim = dim - self.out_dim = out_dim - 
self.patch_size = patch_size - self.eps = eps - - # layers - out_dim = math.prod(patch_size) * out_dim - self.norm = WanLayerNorm(dim, eps) - self.head = nn.Linear(dim, out_dim) - - # modulation - self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) - def forward(self, x, e): r""" Args: @@ -415,26 +441,47 @@ def forward(self, x, e): return x -class MLPProj(torch.nn.Module): - def __init__(self, in_dim, out_dim): - super().__init__() - - self.proj = torch.nn.Sequential( - torch.nn.LayerNorm(in_dim), - torch.nn.Linear(in_dim, in_dim), - torch.nn.GELU(), - torch.nn.Linear(in_dim, out_dim), - torch.nn.LayerNorm(out_dim), - ) - - def forward(self, image_embeds): - clip_extra_context_tokens = self.proj(image_embeds) - return clip_extra_context_tokens - - class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" - Wan diffusion backbone supporting both text-to-video and image-to-video. + A Transformer model for video-like data used in the Wan-based SkyReels-V2 model. + + Args: + patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). + num_attention_heads (`int`, defaults to `16`): + Fixed length for text embeddings. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + text_dim (`int`, defaults to `4096`): + Input dimension for text embeddings. + freq_dim (`int`, defaults to `256`): + Dimension for sinusoidal time embeddings. + ffn_dim (`int`, defaults to `8192`): + Intermediate dimension in feed-forward network. + num_layers (`int`, defaults to `32`): + The number of layers of transformer blocks to use. + window_size (`Tuple[int]`, defaults to `(-1, -1)`): + Window size for local attention (-1 indicates global attention). + cross_attn_norm (`bool`, defaults to `True`): + Enable cross-attention normalization. + qk_norm (`str`, *optional*, defaults to `"rms_norm"`): + Enable query/key normalization. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + inject_sample_info (`bool`, defaults to `False`): + Whether to inject sample information into the model. + image_dim (`int`, *optional*): + The dimension of the image embeddings. + added_kv_proj_dim (`int`, *optional*): + The dimension of the added key/value projection. + rope_max_seq_len (`int`, defaults to `1024`): + The maximum sequence length for the rotary embeddings. + pos_embed_seq_len (`int`, *optional*): + The sequence length for the positional embeddings. 
""" _supports_gradient_checkpointing = True @@ -446,119 +493,76 @@ class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr @register_to_config def __init__( self, - patch_size=(1, 2, 2), - text_len=512, - in_dim=16, - dim=2048, - ffn_dim=8192, - freq_dim=256, - text_dim=4096, - out_dim=16, - num_heads=16, - num_layers=32, - window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=True, - inject_sample_info=False, - eps=1e-6, + patch_size: Tuple[int] = (1, 2, 2), + attention_head_dim: int = 128, + in_channels: int = 16, + ffn_dim: int = 8192, + freq_dim: int = 256, + text_dim: int = 4096, + out_channels: int = 16, + num_attention_heads: int = 16, + num_layers: int = 32, + window_size: Tuple[int, int] = (-1, -1), + qk_norm: Optional[str] = "rms_norm", + cross_attn_norm: bool = True, + inject_sample_info: bool = False, + eps: float = 1e-6, + image_dim: Optional[int] = None, + added_kv_proj_dim: Optional[int] = None, + rope_max_seq_len: int = 1024, + pos_embed_seq_len: Optional[int] = None, ): - r""" - Initialize the diffusion model backbone. - - Args: - model_type (`str`, *optional*, defaults to 't2v'): - Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) - patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): - 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) - text_len (`int`, *optional*, defaults to 512): - Fixed length for text embeddings - in_dim (`int`, *optional*, defaults to 16): - Input video channels (C_in) - dim (`int`, *optional*, defaults to 2048): - Hidden dimension of the transformer - ffn_dim (`int`, *optional*, defaults to 8192): - Intermediate dimension in feed-forward network - freq_dim (`int`, *optional*, defaults to 256): - Dimension for sinusoidal time embeddings - text_dim (`int`, *optional*, defaults to 4096): - Input dimension for text embeddings - out_dim (`int`, *optional*, defaults to 16): - Output video channels (C_out) - num_heads (`int`, *optional*, defaults to 16): - Number of attention heads - num_layers (`int`, *optional*, defaults to 32): - Number of transformer blocks - window_size (`tuple`, *optional*, defaults to (-1, -1)): - Window size for local attention (-1 indicates global attention) - qk_norm (`bool`, *optional*, defaults to True): - Enable query/key normalization - cross_attn_norm (`bool`, *optional*, defaults to False): - Enable cross-attention normalization - eps (`float`, *optional*, defaults to 1e-6): - Epsilon value for normalization layers - """ - super().__init__() - self.patch_size = patch_size - self.text_len = text_len - self.in_dim = in_dim - self.dim = dim - self.ffn_dim = ffn_dim - self.freq_dim = freq_dim - self.text_dim = text_dim - self.out_dim = out_dim - self.num_heads = num_heads - self.num_layers = num_layers - self.window_size = window_size - self.qk_norm = qk_norm - self.cross_attn_norm = cross_attn_norm - self.eps = eps + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + self.num_frame_per_block = 1 self.flag_causal_attention = False self.block_mask = None self.enable_teacache = False + self.inject_sample_info = inject_sample_info - # embeddings - self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) - self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)) - - self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) - self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) - 
- if inject_sample_info: - self.fps_embedding = nn.Embedding(2, dim) - self.fps_projection = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim * 6)) + # 1. Patch & position embedding + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + self.rope = SkyReelsV2RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + + # 2. Condition embeddings + # image_embedding_dim=1280 for I2V model + self.condition_embedder = SkyReelsV2TimeTextImageEmbedding( + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embed_dim=image_dim, + pos_embed_seq_len=pos_embed_seq_len, + ) - # blocks + # 3. Transformer blocks self.blocks = nn.ModuleList( [ - WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) + SkyReelsV2TransformerBlock( + inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, + window_size, # TODO: check + added_kv_proj_dim # TODO: check + ) for _ in range(num_layers) ] ) - # head - self.head = Head(dim, out_dim, patch_size, eps) - - # buffers (don't use register_buffer otherwise dtype will be changed in to()) - assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 - d = dim // num_heads - self.freqs = torch.cat( - [rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))], - dim=1, - ) - - # if model_type == "i2v": - # self.img_emb = MLPProj(1280, dim) + # 4. Output norm & projection + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) self.gradient_checkpointing = False - self.cpu_offloading = False + if inject_sample_info: + self.fps_embedding = nn.Embedding(2, inner_dim) + self.fps_projection = nn.Sequential(nn.Linear(inner_dim, inner_dim), nn.SiLU(), nn.Linear(inner_dim, inner_dim * 6)) - self.inject_sample_info = inject_sample_info - # initialize weights - self.init_weights() + # TODO: Say: Initializing suggested by the original repo? 
+ # self.init_weights() def _set_gradient_checkpointing(self, module, value=False): self.gradient_checkpointing = value From f318efa9bd6416238b909937df919a92072c0dee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 11 May 2025 17:25:17 +0300 Subject: [PATCH 021/264] up --- src/diffusers/models/transformers/transformer_skyreels_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 154de4443005..3be544293368 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -29,9 +29,9 @@ from ..attention import FeedForward from ..attention_processor import Attention from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm -from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from .attention import flash_attention From 9688a829961541f8ed50f9b31e18b683fc11c453 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 11 May 2025 20:23:38 +0300 Subject: [PATCH 022/264] Refactors the `SkyReelsV2Transformer3DModel` by removing unused methods and begin reorganizing the forward pass. --- .../transformers/transformer_skyreels_v2.py | 245 +++++++++--------- 1 file changed, 127 insertions(+), 118 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 3be544293368..22d16f703ea0 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -564,98 +564,6 @@ def __init__( # TODO: Say: Initializing suggested by the original repo? # self.init_weights() - def _set_gradient_checkpointing(self, module, value=False): - self.gradient_checkpointing = value - - def zero_init_i2v_cross_attn(self): - print("zero init i2v cross attn") - for i in range(self.num_layers): - self.blocks[i].cross_attn.v_img.weight.data.zero_() - self.blocks[i].cross_attn.v_img.bias.data.zero_() - - @staticmethod - def _prepare_blockwise_causal_attn_mask( - device: torch.device | str, num_frames: int = 21, frame_seqlen: int = 1560, num_frame_per_block=1 - ) -> BlockMask: - """ - we will divide the token sequence into the following format [1 latent frame] [1 latent frame] ... 
[1 latent - frame] We use flexattention to construct the attention mask - """ - total_length = num_frames * frame_seqlen - - # we do right padding to get to a multiple of 128 - padded_length = math.ceil(total_length / 128) * 128 - total_length - - ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long) - - # Block-wise causal mask will attend to all elements that are before the end of the current chunk - frame_indices = torch.arange(start=0, end=total_length, step=frame_seqlen * num_frame_per_block, device=device) - - for tmp in frame_indices: - ends[tmp : tmp + frame_seqlen * num_frame_per_block] = tmp + frame_seqlen * num_frame_per_block - - def attention_mask(b, h, q_idx, kv_idx): - return (kv_idx < ends[q_idx]) | (q_idx == kv_idx) - # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask - - block_mask = create_block_mask( - attention_mask, - B=None, - H=None, - Q_LEN=total_length + padded_length, - KV_LEN=total_length + padded_length, - _compile=False, - device=device, - ) - - return block_mask - - def initialize_teacache( - self, enable_teacache=True, num_steps=25, teacache_thresh=0.15, use_ret_steps=False, ckpt_dir="" - ): - self.enable_teacache = enable_teacache - print("using teacache") - self.cnt = 0 - self.num_steps = num_steps - self.teacache_thresh = teacache_thresh - self.accumulated_rel_l1_distance_even = 0 - self.accumulated_rel_l1_distance_odd = 0 - self.previous_e0_even = None - self.previous_e0_odd = None - self.previous_residual_even = None - self.previous_residual_odd = None - self.use_ref_steps = use_ret_steps - if "I2V" in ckpt_dir: - if use_ret_steps: - if "540P" in ckpt_dir: - self.coefficients = [2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01] - if "720P" in ckpt_dir: - self.coefficients = [8.10705460e03, 2.13393892e03, -3.72934672e02, 1.66203073e01, -4.17769401e-02] - self.ret_steps = 5 * 2 - self.cutoff_steps = num_steps * 2 - else: - if "540P" in ckpt_dir: - self.coefficients = [-3.02331670e02, 2.23948934e02, -5.25463970e01, 5.87348440e00, -2.01973289e-01] - if "720P" in ckpt_dir: - self.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] - self.ret_steps = 1 * 2 - self.cutoff_steps = num_steps * 2 - 2 - else: - if use_ret_steps: - if "1.3B" in ckpt_dir: - self.coefficients = [-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02] - if "14B" in ckpt_dir: - self.coefficients = [-3.03318725e05, 4.90537029e04, -2.65530556e03, 5.87365115e01, -3.15583525e-01] - self.ret_steps = 5 * 2 - self.cutoff_steps = num_steps * 2 - else: - if "1.3B" in ckpt_dir: - self.coefficients = [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01] - if "14B" in ckpt_dir: - self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] - self.ret_steps = 1 * 2 - self.cutoff_steps = num_steps * 2 - 2 - def forward(self, x, t, context, clip_fea=None, y=None, fps=None): r""" Forward pass through the diffusion model @@ -678,6 +586,28 @@ def forward(self, x, t, context, clip_fea=None, y=None, fps=None): List[Tensor]: List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] """ + + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, 
lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + if self.model_type == "i2v": assert clip_fea is not None and y is not None # params @@ -818,36 +748,26 @@ def forward(self, x, t, context, clip_fea=None, y=None, fps=None): for block in self.blocks: x = block(x, **kwargs) - x = self.head(x, e) - - # unpatchify - x = self.unpatchify(x, grid_sizes) - - return x.float() + # 5. Output norm, projection & unpatchify + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) - def unpatchify(self, x, grid_sizes): - r""" - Reconstruct video tensors from patch embeddings. + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) - Args: - x (List[Tensor]): - List of patchified features, each with shape [L, C_out * prod(patch_size)] - grid_sizes (Tensor): - Original spatial-temporal grid dimensions before patching, - shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) - Returns: - List[Tensor]: - Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] - """ - - c = self.out_dim - bs = x.shape[0] - x = x.view(bs, *grid_sizes, *self.patch_size, c) - x = torch.einsum("bfhwpqrc->bcfphqwr", x) - x = x.reshape(bs, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)]) + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) - return x + return x.float() def set_ar_attention(self, causal_block_size): self.num_frame_per_block = causal_block_size @@ -855,6 +775,89 @@ def set_ar_attention(self, causal_block_size): for block in self.blocks: block.set_ar_attention() + @staticmethod + def _prepare_blockwise_causal_attn_mask( + device: torch.device | str, num_frames: int = 21, frame_seqlen: int = 1560, num_frame_per_block=1 + ) -> BlockMask: + """ + we will divide the token sequence into the following format [1 latent frame] [1 latent frame] ... 
[1 latent + frame] We use flexattention to construct the attention mask + """ + total_length = num_frames * frame_seqlen + + # we do right padding to get to a multiple of 128 + padded_length = math.ceil(total_length / 128) * 128 - total_length + + ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long) + + # Block-wise causal mask will attend to all elements that are before the end of the current chunk + frame_indices = torch.arange(start=0, end=total_length, step=frame_seqlen * num_frame_per_block, device=device) + + for tmp in frame_indices: + ends[tmp : tmp + frame_seqlen * num_frame_per_block] = tmp + frame_seqlen * num_frame_per_block + + def attention_mask(b, h, q_idx, kv_idx): + return (kv_idx < ends[q_idx]) | (q_idx == kv_idx) + # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask + + block_mask = create_block_mask( + attention_mask, + B=None, + H=None, + Q_LEN=total_length + padded_length, + KV_LEN=total_length + padded_length, + _compile=False, + device=device, + ) + + return block_mask + + def initialize_teacache( + self, enable_teacache=True, num_steps=25, teacache_thresh=0.15, use_ret_steps=False, ckpt_dir="" + ): + self.enable_teacache = enable_teacache + print("using teacache") + self.cnt = 0 + self.num_steps = num_steps + self.teacache_thresh = teacache_thresh + self.accumulated_rel_l1_distance_even = 0 + self.accumulated_rel_l1_distance_odd = 0 + self.previous_e0_even = None + self.previous_e0_odd = None + self.previous_residual_even = None + self.previous_residual_odd = None + self.use_ref_steps = use_ret_steps + if "I2V" in ckpt_dir: + if use_ret_steps: + if "540P" in ckpt_dir: + self.coefficients = [2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01] + if "720P" in ckpt_dir: + self.coefficients = [8.10705460e03, 2.13393892e03, -3.72934672e02, 1.66203073e01, -4.17769401e-02] + self.ret_steps = 5 * 2 + self.cutoff_steps = num_steps * 2 + else: + if "540P" in ckpt_dir: + self.coefficients = [-3.02331670e02, 2.23948934e02, -5.25463970e01, 5.87348440e00, -2.01973289e-01] + if "720P" in ckpt_dir: + self.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] + self.ret_steps = 1 * 2 + self.cutoff_steps = num_steps * 2 - 2 + else: + if use_ret_steps: + if "1.3B" in ckpt_dir: + self.coefficients = [-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02] + if "14B" in ckpt_dir: + self.coefficients = [-3.03318725e05, 4.90537029e04, -2.65530556e03, 5.87365115e01, -3.15583525e-01] + self.ret_steps = 5 * 2 + self.cutoff_steps = num_steps * 2 + else: + if "1.3B" in ckpt_dir: + self.coefficients = [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01] + if "14B" in ckpt_dir: + self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] + self.ret_steps = 1 * 2 + self.cutoff_steps = num_steps * 2 - 2 + def init_weights(self): r""" Initialize model parameters using Xavier initialization. 
@@ -888,3 +891,9 @@ def init_weights(self): # init output layer nn.init.zeros_(self.head.head.weight) + + def zero_init_i2v_cross_attn(self): + print("zero init i2v cross attn") + for i in range(self.num_layers): + self.blocks[i].cross_attn.v_img.weight.data.zero_() + self.blocks[i].cross_attn.v_img.bias.data.zero_() \ No newline at end of file From 825c2c1b7829ace815f2d0ed15c83bdf15fb6e46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 12 May 2025 10:36:11 +0300 Subject: [PATCH 023/264] Refactors `SkyReelsV2TransformerBlock` to integrate its `forward()` method --- .../transformers/transformer_skyreels_v2.py | 97 ++++++++----------- 1 file changed, 38 insertions(+), 59 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 22d16f703ea0..cbc35633a6bc 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -192,17 +192,6 @@ def forward(self, x, context): return x -def mul_add(x, y, z): - return x.float() + y.float() * z.float() - - -def mul_add_add(x, y, z): - return x.float() * (1 + y) + z - - -mul_add_compile = torch.compile(mul_add, dynamic=True, disable=DISABLE_COMPILE) -mul_add_add_compile = torch.compile(mul_add_add, dynamic=True, disable=DISABLE_COMPILE) - class SkyReelsV2ImageEmbedding(torch.nn.Module): def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None): super().__init__() @@ -316,20 +305,13 @@ def __init__( dim: int, ffn_dim: int, num_heads: int, - window_size=(-1, -1), + window_size: Tuple[int, int] = (-1, -1), qk_norm: str = "rms_norm", cross_attn_norm: bool = False, eps: float = 1e-6, added_kv_proj_dim: Optional[int] = None, ): super().__init__() - self.dim = dim - self.ffn_dim = ffn_dim - self.num_heads = num_heads - self.window_size = window_size - self.qk_norm = qk_norm - self.cross_attn_norm = cross_attn_norm - self.eps = eps # 1. Self-attention self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) @@ -349,7 +331,7 @@ def __init__( # 2. Cross-attention self.attn2 = Attention( - window_size=(-1, -1), + window_size=window_size, query_dim=dim, heads=num_heads, kv_heads=num_heads, @@ -361,29 +343,26 @@ def __init__( out_bias=True, added_kv_proj_dim=added_kv_proj_dim, added_proj_bias=True, - pre_only=False, + pre_only=False, # TODO: check processor=SkyReelsV2AttnProcessor2_0(), ) - self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.norm2 = FP32LayerNorm(dim, eps) if cross_attn_norm else nn.Identity() # 3. 
Feed-forward self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") - self.norm3 = FP32LayerNorm(dim, eps) if cross_attn_norm else nn.Identity() - + self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) - def set_ar_attention(self): - self.self_attn.set_ar_attention() - def forward( self, - x, - e, - grid_sizes, - freqs, - context, - block_mask, - ): + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb_e: torch.Tensor, + rotary_emb: torch.Tensor, + grid_sizes: torch.Tensor, + freqs: torch.Tensor, + block_mask: torch.Tensor, + ) -> torch.Tensor: r""" Args: x(Tensor): Shape [B, L, C] @@ -392,33 +371,33 @@ def forward( grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ - if e.dim() == 3: - modulation = self.modulation # 1, 6, dim - with amp.autocast("cuda", dtype=torch.float32): - e = (modulation + e).chunk(6, dim=1) - elif e.dim() == 4: - modulation = self.modulation.unsqueeze(2) # 1, 6, 1, dim - with amp.autocast("cuda", dtype=torch.float32): - e = (modulation + e).chunk(6, dim=1) - e = [ei.squeeze(1) for ei in e] + if temb_e.dim() == 3: + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (self.scale_shift_table + temb_e.float()).chunk(6, dim=1) + elif temb_e.dim() == 4: + e = (self.scale_shift_table.unsqueeze(2) + temb_e.float()).chunk(6, dim=1) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e] # self-attention - out = mul_add_add_compile(self.norm1(x), e[1], e[0]) - y = self.self_attn(out, grid_sizes, freqs, block_mask) - with amp.autocast("cuda", dtype=torch.float32): - x = mul_add_compile(x, y, e[2]) - - # cross-attention & ffn function - def cross_attn_ffn(x, context, e): - dtype = context.dtype - x = x + self.cross_attn(self.norm3(x.to(dtype)), context) - y = self.ffn(mul_add_add_compile(self.norm2(x), e[4], e[3]).to(dtype)) - with amp.autocast("cuda", dtype=torch.float32): - x = mul_add_compile(x, y, e[5]) - return x - - x = cross_attn_ffn(x, context, e) - return x.to(torch.bfloat16) + norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + attn_output = self.self_attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb, grid_sizes=grid_sizes, freqs=freqs, block_mask=block_mask) + hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) + + # 2. Cross-attention + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = hidden_states + attn_output + + # 3. 
Feed-forward + norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( + hidden_states + ) + ff_output = self.ffn(norm_hidden_states) + hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + + return hidden_states # TODO: check .to(torch.bfloat16) + + def set_ar_attention(self): + self.self_attn.set_ar_attention() class Head(nn.Module): From d848500553c5109be6b60c777874774d3cd6a8a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 12 May 2025 16:19:16 +0300 Subject: [PATCH 024/264] Refactors `SkyReelsV2AttnProcessor2_0` to enhance the `forward()` method, integrating rotary embeddings and improving attention handling. Removes the deprecated `rope_apply` function and streamlines the attention mechanism for better integration and clarity. --- .../transformers/transformer_skyreels_v2.py | 215 +++++++----------- 1 file changed, 76 insertions(+), 139 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index cbc35633a6bc..e4e4e51a8b2b 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -39,38 +39,6 @@ flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune") -DISABLE_COMPILE = False # get os env - - -@amp.autocast("cuda", enabled=False) -def rope_apply(x, grid_sizes, freqs): - n, c = x.size(2), x.size(3) // 2 - bs = x.size(0) - - # split freqs - freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) - - # loop over samples - f, h, w = grid_sizes.tolist() - seq_len = f * h * w - - # precompute multipliers - - x = torch.view_as_complex(x.to(torch.float32).reshape(bs, seq_len, n, -1, 2)) - freqs_i = torch.cat( - [ - freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), - freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), - freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), - ], - dim=-1, - ).reshape(seq_len, 1, -1) - - # apply rotary embedding - x = torch.view_as_real(x * freqs_i).flatten(3) - - return x - class SkyReelsV2AttnProcessor2_0: def __init__(self): @@ -81,115 +49,86 @@ def __init__(self): self._flag_ar_attention = False - def set_ar_attention(self): - self._flag_ar_attention = True - - def forward(self, x, grid_sizes, freqs, block_mask): - r""" - Args: - x(Tensor): Shape [B, L, num_heads, C / num_heads] - seq_lens(Tensor): Shape [B] - grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) - freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] - """ - b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim - - # query, key, value function - def qkv_fn(x): - q = self.norm_q(self.q(x)).view(b, s, n, d) - k = self.norm_k(self.k(x)).view(b, s, n, d) - v = self.v(x).view(b, s, n, d) - return q, k, v - - x = x.to(self.q.weight.dtype) - q, k, v = qkv_fn(x) - - if not self._flag_ar_attention: - q = rope_apply(q, grid_sizes, freqs) - k = rope_apply(k, grid_sizes, freqs) - x = flash_attention(q=q, k=k, v=v, window_size=self.window_size) + def forward( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + encoder_hidden_states_img = None + if attn.add_k_proj is not None: + # image_context_length is hardcoded for now like in the original code + image_context_length = 257 + 
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] + encoder_hidden_states = encoder_hidden_states[:, image_context_length:] + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if rotary_emb is not None: + + def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): + x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2))) + x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) + return x_out.type_as(hidden_states) + + query = apply_rotary_emb(query, rotary_emb) + key = apply_rotary_emb(key, rotary_emb) + + # I2V task + hidden_states_img = None + if encoder_hidden_states_img is not None: + key_img = attn.add_k_proj(encoder_hidden_states_img) + key_img = attn.norm_added_k(key_img) + value_img = attn.add_v_proj(encoder_hidden_states_img) + + key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + hidden_states_img = F.scaled_dot_product_attention( + query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False + ) + hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) + hidden_states_img = hidden_states_img.type_as(query) + + if self._flag_ar_attention: + is_self_attention = encoder_hidden_states == hidden_states + hidden_states = F.scaled_dot_product_attention( + query.to(torch.bfloat16) if is_self_attention else query, + key.to(torch.bfloat16) if is_self_attention else key, + value.to(torch.bfloat16) if is_self_attention else value, + attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) else: - q = rope_apply(q, grid_sizes, freqs) - k = rope_apply(k, grid_sizes, freqs) - q = q.to(torch.bfloat16) - k = k.to(torch.bfloat16) - v = v.to(torch.bfloat16) - - with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): - x = ( - torch.nn.functional.scaled_dot_product_attention( - q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_mask=block_mask - ) - .transpose(1, 2) - .contiguous() - ) - - # output - x = x.flatten(2) - x = self.o(x) - return x - - -class WanT2VCrossAttention(SkyReelsV2AttnProcessor2_0): - def forward(self, x, context): - r""" - Args: - x(Tensor): Shape [B, L1, C] - context(Tensor): Shape [B, L2, C] - context_lens(Tensor): Shape [B] - """ - b, n, d = x.size(0), self.num_heads, self.head_dim + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) - # compute query, key, value - q = self.norm_q(self.q(x)).view(b, -1, n, d) - k = self.norm_k(self.k(context)).view(b, -1, n, d) - v = self.v(context).view(b, -1, n, d) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.type_as(query) - # compute attention - x = flash_attention(q, k, v) + if hidden_states_img is not None: + hidden_states = hidden_states + hidden_states_img - # output - x = x.flatten(2) - x = self.o(x) - return x - - -class WanI2VCrossAttention(SkyReelsV2AttnProcessor2_0): - def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6): - 
super().__init__(dim, num_heads, window_size, qk_norm, eps) + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states - self.k_img = nn.Linear(dim, dim) - self.v_img = nn.Linear(dim, dim) - # self.alpha = nn.Parameter(torch.zeros((1, ))) - self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + def set_ar_attention(self): + self._flag_ar_attention = True - def forward(self, x, context): - r""" - Args: - x(Tensor): Shape [B, L1, C] - context(Tensor): Shape [B, L2, C] - context_lens(Tensor): Shape [B] - """ - context_img = context[:, :257] - context = context[:, 257:] - b, n, d = x.size(0), self.num_heads, self.head_dim - - # compute query, key, value - q = self.norm_q(self.q(x)).view(b, -1, n, d) - k = self.norm_k(self.k(context)).view(b, -1, n, d) - v = self.v(context).view(b, -1, n, d) - k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d) - v_img = self.v_img(context_img).view(b, -1, n, d) - img_x = flash_attention(q, k_img, v_img) - # compute attention - x = flash_attention(q, k, v) - - # output - x = x.flatten(2) - img_x = img_x.flatten(2) - x = x + img_x - x = self.o(x) - return x class SkyReelsV2ImageEmbedding(torch.nn.Module): @@ -342,8 +281,6 @@ def __init__( cross_attention_dim=None, out_bias=True, added_kv_proj_dim=added_kv_proj_dim, - added_proj_bias=True, - pre_only=False, # TODO: check processor=SkyReelsV2AttnProcessor2_0(), ) self.norm2 = FP32LayerNorm(dim, eps) if cross_attn_norm else nn.Identity() @@ -379,7 +316,7 @@ def forward( # self-attention norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) - attn_output = self.self_attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb, grid_sizes=grid_sizes, freqs=freqs, block_mask=block_mask) + attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb, grid_sizes=grid_sizes, freqs=freqs, block_mask=block_mask) hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) # 2. Cross-attention @@ -397,7 +334,7 @@ def forward( return hidden_states # TODO: check .to(torch.bfloat16) def set_ar_attention(self): - self.self_attn.set_ar_attention() + self.attn1.processor.set_ar_attention() class Head(nn.Module): From 2f5a4e2494e3a29800551d7c2a419f1f4e55b6db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 12 May 2025 18:04:35 +0300 Subject: [PATCH 025/264] Refactors `SkyReelsV2Transformer3DModel` to enhance the `forward()` method by updating parameter names for clarity, integrating attention masks, and improving the handling of encoder hidden states. --- .../transformers/transformer_skyreels_v2.py | 75 +++++++++---------- 1 file changed, 34 insertions(+), 41 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index e4e4e51a8b2b..34fb73713ba2 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -13,29 +13,27 @@ # limitations under the License. 
import math -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict, Any, Union import numpy as np import torch import torch.amp as amp import torch.nn as nn import torch.nn.functional as F -from torch.backends.cuda import sdp_kernel from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import logging +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward from ..attention_processor import Attention from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm -from .attention import flash_attention -logger = logging.get_logger(__name__) # pylint: disable=invalid-name +logger = logging.get_logger(__name__) flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune") @@ -130,7 +128,6 @@ def set_ar_attention(self): self._flag_ar_attention = True - class SkyReelsV2ImageEmbedding(torch.nn.Module): def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None): super().__init__() @@ -296,27 +293,19 @@ def forward( encoder_hidden_states: torch.Tensor, temb_e: torch.Tensor, rotary_emb: torch.Tensor, - grid_sizes: torch.Tensor, - freqs: torch.Tensor, - block_mask: torch.Tensor, + attention_mask: torch.Tensor, ) -> torch.Tensor: - r""" - Args: - x(Tensor): Shape [B, L, C] - e(Tensor): Shape [B, 6, C] - seq_lens(Tensor): Shape [B], length of each sequence in batch - grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) - freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] - """ if temb_e.dim() == 3: - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (self.scale_shift_table + temb_e.float()).chunk(6, dim=1) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table + temb_e.float() + ).chunk(6, dim=1) elif temb_e.dim() == 4: e = (self.scale_shift_table.unsqueeze(2) + temb_e.float()).chunk(6, dim=1) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e] # self-attention norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) - attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb, grid_sizes=grid_sizes, freqs=freqs, block_mask=block_mask) + attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask) hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) # 2. Cross-attention @@ -418,7 +407,6 @@ def __init__( out_channels: int = 16, num_attention_heads: int = 16, num_layers: int = 32, - window_size: Tuple[int, int] = (-1, -1), qk_norm: Optional[str] = "rms_norm", cross_attn_norm: bool = True, inject_sample_info: bool = False, @@ -435,9 +423,7 @@ def __init__( self.num_frame_per_block = 1 self.flag_causal_attention = False - self.block_mask = None self.enable_teacache = False - self.inject_sample_info = inject_sample_info # 1. 
Patch & position embedding self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) @@ -459,8 +445,7 @@ def __init__( [ SkyReelsV2TransformerBlock( inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, - window_size, # TODO: check - added_kv_proj_dim # TODO: check + added_kv_proj_dim=inner_dim ) for _ in range(num_layers) ] @@ -480,10 +465,19 @@ def __init__( # TODO: Say: Initializing suggested by the original repo? # self.init_weights() - def forward(self, x, t, context, clip_fea=None, y=None, fps=None): + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: r""" - Forward pass through the diffusion model - Args: x (List[Tensor]): List of input video tensors, each with shape [C_in, F, H, W] @@ -524,8 +518,8 @@ def forward(self, x, t, context, clip_fea=None, y=None, fps=None): post_patch_height = height // p_h post_patch_width = width // p_w - if self.model_type == "i2v": - assert clip_fea is not None and y is not None + rotary_emb = self.rope(hidden_states) + # params device = self.patch_embedding.weight.device if self.freqs.device != device: @@ -550,18 +544,18 @@ def forward(self, x, t, context, clip_fea=None, y=None, fps=None): casual_mask = casual_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device) casual_mask = casual_mask.repeat(1, height, width, 1, height, width) casual_mask = casual_mask.reshape(frame_num * height * width, frame_num * height * width) - self.block_mask = casual_mask.unsqueeze(0).unsqueeze(0) + block_mask = casual_mask.unsqueeze(0).unsqueeze(0) # time embeddings with amp.autocast("cuda", dtype=torch.float32): - if t.dim() == 2: - b, f = t.shape + if timestep.dim() == 2: + b, f = timestep.shape _flag_df = True else: _flag_df = False e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype) + sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(self.patch_embedding.weight.dtype) ) # b, dim e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim @@ -570,7 +564,7 @@ def forward(self, x, t, context, clip_fea=None, y=None, fps=None): fps_emb = self.fps_embedding(fps).float() if _flag_df: - e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1) + e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(timestep.shape[1], 1, 1) else: e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)) @@ -584,19 +578,18 @@ def forward(self, x, t, context, clip_fea=None, y=None, fps=None): assert e.dtype == torch.float32 and e0.dtype == torch.float32 # context - context = self.text_embedding(context) + encoder_hidden_states = self.text_embedding(encoder_hidden_states) - if clip_fea is not None: - context_clip = self.img_emb(clip_fea) # bs x 257 x dim - context = torch.concat([context_clip, context], dim=1) + if encoder_hidden_states_image is not None: + context_clip = self.img_emb(encoder_hidden_states_image) # bs x 257 x dim + encoder_hidden_states = torch.concat([context_clip, encoder_hidden_states], dim=1) # arguments kwargs = { "e": e0, - "grid_sizes": grid_sizes, "freqs": self.freqs, - "context": 
context, - "block_mask": self.block_mask, + "encoder_hidden_states": encoder_hidden_states, + "attention_mask": block_mask, } if self.enable_teacache: modulated_inp = e0 if self.use_ref_steps else e From e5870dd2407e714e0005d372d9f30ecce3d5f261 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 12 May 2025 18:24:51 +0300 Subject: [PATCH 026/264] Refactors `SkyReelsV2Transformer3DModel` to improve the `forward()` method by enhancing the handling of time embeddings and encoder hidden states. Updates parameter names for clarity and integrates rotary embeddings, ensuring better compatibility with the model's architecture. --- .../transformers/transformer_skyreels_v2.py | 121 ++++++++++-------- 1 file changed, 68 insertions(+), 53 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 34fb73713ba2..6828d5440a03 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Optional, Tuple, Dict, Any, Union +from typing import Any, Dict, Optional, Tuple, Union import numpy as np import torch @@ -29,6 +29,7 @@ from ..attention_processor import Attention from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm @@ -109,7 +110,9 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): query.to(torch.bfloat16) if is_self_attention else query, key.to(torch.bfloat16) if is_self_attention else key, value.to(torch.bfloat16) if is_self_attention else value, - attn_mask=attention_mask, dropout_p=0.0, is_causal=False + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, ) else: hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) @@ -151,6 +154,7 @@ def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: hidden_states = self.norm2(hidden_states) return hidden_states + class SkyReelsV2TimeTextImageEmbedding(nn.Module): def __init__( self, @@ -193,6 +197,7 @@ def forward( return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + class SkyReelsV2RotaryPosEmbed(nn.Module): def __init__( self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0 @@ -235,6 +240,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1) return freqs + class SkyReelsV2TransformerBlock(nn.Module): def __init__( self, @@ -298,14 +304,16 @@ def forward( if temb_e.dim() == 3: shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( self.scale_shift_table + temb_e.float() - ).chunk(6, dim=1) + ).chunk(6, dim=1) elif temb_e.dim() == 4: e = (self.scale_shift_table.unsqueeze(2) + temb_e.float()).chunk(6, dim=1) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e] # self-attention norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) - attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask) + attn_output = 
self.attn1( + hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask + ) hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) # 2. Cross-attention @@ -444,8 +452,7 @@ def __init__( self.blocks = nn.ModuleList( [ SkyReelsV2TransformerBlock( - inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, - added_kv_proj_dim=inner_dim + inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim=inner_dim ) for _ in range(num_layers) ] @@ -460,7 +467,9 @@ def __init__( if inject_sample_info: self.fps_embedding = nn.Embedding(2, inner_dim) - self.fps_projection = nn.Sequential(nn.Linear(inner_dim, inner_dim), nn.SiLU(), nn.Linear(inner_dim, inner_dim * 6)) + self.fps_projection = nn.Sequential( + nn.Linear(inner_dim, inner_dim), nn.SiLU(), nn.Linear(inner_dim, inner_dim * 6) + ) # TODO: Say: Initializing suggested by the original repo? # self.init_weights() @@ -526,12 +535,12 @@ def forward( self.freqs = self.freqs.to(device) if y is not None: - x = torch.cat([x, y], dim=1) + hidden_states = torch.cat([hidden_states, y], dim=1) # embeddings - x = self.patch_embedding(x) - grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long) - x = x.flatten(2).transpose(1, 2) + hidden_states = self.patch_embedding(hidden_states) + grid_sizes = torch.tensor(hidden_states.shape[2:], dtype=torch.long) + hidden_states = hidden_states.flatten(2).transpose(1, 2) if self.flag_causal_attention: frame_num = grid_sizes[0] @@ -541,58 +550,57 @@ def forward( range_tensor = torch.arange(block_num).view(-1, 1) range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten() casual_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f - casual_mask = casual_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(x.device) + casual_mask = casual_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(hidden_states.device) casual_mask = casual_mask.repeat(1, height, width, 1, height, width) casual_mask = casual_mask.reshape(frame_num * height * width, frame_num * height * width) block_mask = casual_mask.unsqueeze(0).unsqueeze(0) # time embeddings - with amp.autocast("cuda", dtype=torch.float32): - if timestep.dim() == 2: - b, f = timestep.shape - _flag_df = True - else: - _flag_df = False - - e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(self.patch_embedding.weight.dtype) - ) # b, dim - e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim + if timestep.dim() == 2: + b, f = timestep.shape + _flag_df = True + else: + _flag_df = False - if self.inject_sample_info: - fps = torch.tensor(fps, dtype=torch.long, device=device) + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image + ) + timestep_proj = timestep_proj.unflatten(1, (6, -1)) - fps_emb = self.fps_embedding(fps).float() - if _flag_df: - e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(timestep.shape[1], 1, 1) - else: - e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)) + if self.inject_sample_info: + fps = torch.tensor(fps, dtype=torch.long, device=device) + fps_emb = self.fps_embedding(fps).float() if _flag_df: - e = e.view(b, f, 1, 1, self.dim) - e0 = e0.view(b, f, 1, 1, 6, self.dim) - e = e.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3) - e0 = e0.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3) - e0 = e0.transpose(1, 
2).contiguous() + timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat( + timestep.shape[1], 1, 1 + ) + else: + timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)) - assert e.dtype == torch.float32 and e0.dtype == torch.float32 + if _flag_df: + temb = temb.view(b, f, 1, 1, self.dim) + timestep_proj = timestep_proj.view(b, f, 1, 1, 6, self.dim) + temb = temb.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3) + timestep_proj = timestep_proj.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3) + timestep_proj = timestep_proj.transpose(1, 2).contiguous() # context encoder_hidden_states = self.text_embedding(encoder_hidden_states) if encoder_hidden_states_image is not None: - context_clip = self.img_emb(encoder_hidden_states_image) # bs x 257 x dim - encoder_hidden_states = torch.concat([context_clip, encoder_hidden_states], dim=1) + encoder_hidden_states_image = self.img_emb(encoder_hidden_states_image) # bs x 257 x dim + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) # arguments kwargs = { - "e": e0, - "freqs": self.freqs, + "temb": timestep_proj, + "rotary_emb": rotary_emb, "encoder_hidden_states": encoder_hidden_states, "attention_mask": block_mask, } if self.enable_teacache: - modulated_inp = e0 if self.use_ref_steps else e + modulated_inp = timestep_proj if self.use_ref_steps else temb # teacache if self.cnt % 2 == 0: # even -> condition self.is_even = True @@ -635,27 +643,27 @@ def forward( if self.enable_teacache: if self.is_even: if not should_calc_even: - x += self.previous_residual_even + hidden_states += self.previous_residual_even else: - ori_x = x.clone() + ori_hidden_states = hidden_states.clone() for block in self.blocks: - x = block(x, **kwargs) - self.previous_residual_even = x - ori_x + hidden_states = block(hidden_states, **kwargs) + self.previous_residual_even = hidden_states - ori_hidden_states else: if not should_calc_odd: - x += self.previous_residual_odd + hidden_states += self.previous_residual_odd else: - ori_x = x.clone() + ori_hidden_states = hidden_states.clone() for block in self.blocks: - x = block(x, **kwargs) - self.previous_residual_odd = x - ori_x + hidden_states = block(hidden_states, **kwargs) + self.previous_residual_odd = hidden_states - ori_hidden_states self.cnt += 1 if self.cnt >= self.num_steps: self.cnt = 0 else: for block in self.blocks: - x = block(x, **kwargs) + hidden_states = block(hidden_states, **kwargs) # 5. 
Output norm, projection & unpatchify shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) @@ -674,9 +682,16 @@ def forward( batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 ) hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) - output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3).float() + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) - return x.float() + return Transformer2DModelOutput(sample=output) def set_ar_attention(self, causal_block_size): self.num_frame_per_block = causal_block_size @@ -805,4 +820,4 @@ def zero_init_i2v_cross_attn(self): print("zero init i2v cross attn") for i in range(self.num_layers): self.blocks[i].cross_attn.v_img.weight.data.zero_() - self.blocks[i].cross_attn.v_img.bias.data.zero_() \ No newline at end of file + self.blocks[i].cross_attn.v_img.bias.data.zero_() From d54e3e1e51e7027801ee119c3098b1051bd13c9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 13 May 2025 10:19:35 +0300 Subject: [PATCH 027/264] Refactors `SkyReelsV2Transformer3DModel` forward pass --- .../transformers/transformer_skyreels_v2.py | 57 +++++++------------ 1 file changed, 21 insertions(+), 36 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 6828d5440a03..af24ceab79fb 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -420,7 +420,6 @@ def __init__( inject_sample_info: bool = False, eps: float = 1e-6, image_dim: Optional[int] = None, - added_kv_proj_dim: Optional[int] = None, rope_max_seq_len: int = 1024, pos_embed_seq_len: Optional[int] = None, ): @@ -480,7 +479,6 @@ def forward( timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, fps: Optional[torch.Tensor] = None, return_dict: bool = True, @@ -529,31 +527,25 @@ def forward( rotary_emb = self.rope(hidden_states) - # params - device = self.patch_embedding.weight.device - if self.freqs.device != device: - self.freqs = self.freqs.to(device) - + # TODO: check here if y is not None: hidden_states = torch.cat([hidden_states, y], dim=1) - # embeddings hidden_states = self.patch_embedding(hidden_states) grid_sizes = torch.tensor(hidden_states.shape[2:], dtype=torch.long) - hidden_states = hidden_states.flatten(2).transpose(1, 2) if self.flag_causal_attention: - frame_num = grid_sizes[0] - height = grid_sizes[1] - width = grid_sizes[2] + frame_num, height, width = grid_sizes block_num = frame_num // self.num_frame_per_block - range_tensor = torch.arange(block_num).view(-1, 1) + range_tensor = torch.arange(block_num, device=hidden_states.device).view(-1, 1) range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten() - casual_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f - casual_mask = casual_mask.view(frame_num, 1, 1, frame_num, 1, 1).to(hidden_states.device) - casual_mask = casual_mask.repeat(1, height, width, 1, height, width) - casual_mask = casual_mask.reshape(frame_num * height * width, frame_num * height * width) - block_mask = casual_mask.unsqueeze(0).unsqueeze(0) + 
causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f + causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1) + causal_mask = causal_mask.repeat(1, height, width, 1, height, width) + causal_mask = causal_mask.reshape(frame_num * height * width, frame_num * height * width) + causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) + + hidden_states = hidden_states.flatten(2).transpose(1, 2) # time embeddings if timestep.dim() == 2: @@ -567,8 +559,11 @@ def forward( ) timestep_proj = timestep_proj.unflatten(1, (6, -1)) + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + if self.inject_sample_info: - fps = torch.tensor(fps, dtype=torch.long, device=device) + fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device) fps_emb = self.fps_embedding(fps).float() if _flag_df: @@ -585,20 +580,6 @@ def forward( timestep_proj = timestep_proj.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3) timestep_proj = timestep_proj.transpose(1, 2).contiguous() - # context - encoder_hidden_states = self.text_embedding(encoder_hidden_states) - - if encoder_hidden_states_image is not None: - encoder_hidden_states_image = self.img_emb(encoder_hidden_states_image) # bs x 257 x dim - encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) - - # arguments - kwargs = { - "temb": timestep_proj, - "rotary_emb": rotary_emb, - "encoder_hidden_states": encoder_hidden_states, - "attention_mask": block_mask, - } if self.enable_teacache: modulated_inp = timestep_proj if self.use_ref_steps else temb # teacache @@ -647,7 +628,9 @@ def forward( else: ori_hidden_states = hidden_states.clone() for block in self.blocks: - hidden_states = block(hidden_states, **kwargs) + hidden_states = block( + hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask + ) self.previous_residual_even = hidden_states - ori_hidden_states else: if not should_calc_odd: @@ -655,7 +638,9 @@ def forward( else: ori_hidden_states = hidden_states.clone() for block in self.blocks: - hidden_states = block(hidden_states, **kwargs) + hidden_states = block( + hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask + ) self.previous_residual_odd = hidden_states - ori_hidden_states self.cnt += 1 @@ -663,7 +648,7 @@ def forward( self.cnt = 0 else: for block in self.blocks: - hidden_states = block(hidden_states, **kwargs) + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask) # 5. Output norm, projection & unpatchify shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) From 10d74808966d436bb6bb1c4433ec10737db49bb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 13 May 2025 10:56:58 +0300 Subject: [PATCH 028/264] Add DF inference template. 
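
The template drives denoising with a per-frame timestep matrix instead of a single shared timestep: `generate_timestep_matrix` staggers the schedule so that later frame blocks lag earlier ones by `ar_step` denoising steps, `step_update_mask` marks which latent frames an iteration actually advances, and each latent frame keeps its own `FlowMatchUniPCMultistepScheduler` instance. A minimal sketch of the staggering idea only; the helper name, signature, and defaults below are illustrative placeholders, not the pipeline's final API:

    import torch

    def toy_step_matrix(num_latent_frames: int, timesteps: torch.Tensor, ar_step: int = 5):
        # Frame f enters the schedule `ar_step * f` iterations after frame 0, so
        # earlier frames are always cleaner than later ones (diffusion forcing).
        num_steps = timesteps.shape[0]
        num_iters = num_steps + ar_step * (num_latent_frames - 1)
        step_matrix = torch.full((num_iters, num_latent_frames), timesteps[0].item(), dtype=timesteps.dtype)
        update_mask = torch.zeros(num_iters, num_latent_frames, dtype=torch.bool)
        for it in range(num_iters):
            for f in range(num_latent_frames):
                local = it - ar_step * f  # this frame's own position in the schedule
                if 0 <= local < num_steps:
                    step_matrix[it, f] = timesteps[local]
                    update_mask[it, f] = True  # advance this frame with its own scheduler
                elif local >= num_steps:
                    step_matrix[it, f] = 0  # treat the frame as already fully denoised
        return step_matrix, update_mask

    # e.g. ts = torch.linspace(1000.0, 1.0, steps=30); m, mask = toy_step_matrix(4, ts)
    # Row i of `m` plays the role of the 2D per-frame `timestep` fed to the transformer,
    # and `mask[i]` selects which frames take a scheduler step at iteration i.
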
--- .../transformers/transformer_skyreels_v2.py | 2 +- .../pipeline_skyreels_v2_diffusion_forcing.py | 344 ++++++++++++------ 2 files changed, 227 insertions(+), 119 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index af24ceab79fb..135ece77d55f 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -547,7 +547,7 @@ def forward( hidden_states = hidden_states.flatten(2).transpose(1, 2) - # time embeddings + # TODO: check here if timestep.dim() == 2: b, f = timestep.shape _flag_df = True diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 6175ab918882..22a1184435a7 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -15,6 +15,8 @@ import html import re from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np +from tqdm import tqdm import ftfy import PIL.Image @@ -584,126 +586,232 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - # 3. Encode input prompt - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, - num_videos_per_prompt=num_videos_per_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - max_sequence_length=max_sequence_length, - device=device, - ) - - # Encode image embedding - transformer_dtype = self.transformer.dtype - prompt_embeds = prompt_embeds.to(transformer_dtype) - if negative_prompt_embeds is not None: - negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) - - if image_embeds is None: - image_embeds = self.encode_image(image, device) - - image_embeds = image_embeds.repeat(batch_size, 1, 1) - image_embeds = image_embeds.to(transformer_dtype) - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) - timesteps = self.scheduler.timesteps - - # 5. Prepare latent variables - num_channels_latents = self.vae.config.z_dim - image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) - - latents, condition = self.prepare_latents( - image, - batch_size * num_videos_per_prompt, - num_channels_latents, - height, - width, - num_frames, - torch.float32, - device, - generator, - latents, - ) - - # 6. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - self._num_timesteps = len(timesteps) - - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - self._current_timestep = t - latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) - timestep = t.expand(latents.shape[0]) - - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - encoder_hidden_states_image=image_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if self.do_classifier_free_guidance: - noise_uncond = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - encoder_hidden_states_image=image_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latent_height = height // 8 + latent_width = width // 8 + latent_length = (num_frames - 1) // 4 + 1 - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - self._current_timestep = None + self._guidance_scale = guidance_scale - if not output_type == "latent": - latents = latents.to(self.vae.dtype) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) + i2v_extra_kwrags = {} + prefix_video = None + predix_video_latent_length = 0 + if image: + prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames) + + self.text_encoder.to(self.device) + prompt_embeds = self.text_encoder.encode(prompt).to(self.transformer.dtype) + if self.do_classifier_free_guidance: + negative_prompt_embeds = self.text_encoder.encode(negative_prompt).to(self.transformer.dtype) + if self.offload: + self.text_encoder.cpu() + torch.cuda.empty_cache() + + self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + init_timesteps = self.scheduler.timesteps + if causal_block_size is None: + causal_block_size = self.transformer.num_frame_per_block + fps_embeds = [fps] * prompt_embeds.shape[0] + fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] + transformer_dtype = self.transformer.dtype + # with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad(): + if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: + # short video generation + latent_shape = [16, latent_length, latent_height, latent_width] + latents = self.prepare_latents( + latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, 
self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype + latents = [latents] + if prefix_video is not None: + latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) + base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + latent_length, init_timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size ) - latents = latents / latents_std + latents_mean - video = self.vae.decode(latents, return_dict=False)[0] - video = self.video_processor.postprocess_video(video, output_type=output_type) + sample_schedulers = [] + for _ in range(latent_length): + sample_scheduler = FlowMatchUniPCMultistepScheduler( + num_train_timesteps=1000, shift=1, use_dynamic_shifting=False + ) + sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + sample_schedulers.append(sample_scheduler) + sample_schedulers_counter = [0] * latent_length + self.transformer.to(self.device) + for i, timestep_i in enumerate(tqdm(step_matrix)): + update_mask_i = step_update_mask[i] + valid_interval_i = valid_interval[i] + valid_interval_start, valid_interval_end = valid_interval_i + timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] + if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + timestep_for_noised_condition = addnoise_condition + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor) + + torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length]) + * noise_factor + ) + timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + if not self.do_classifier_free_guidance: + noise_pred = self.transformer( + torch.stack([latent_model_input[0]]), + t=timestep, + context=prompt_embeds, + fps=fps_embeds, + **i2v_extra_kwrags, + )[0] + else: + noise_pred_cond = self.transformer( + torch.stack([latent_model_input[0]]), + t=timestep, + context=prompt_embeds, + fps=fps_embeds, + **i2v_extra_kwrags, + )[0] + noise_pred_uncond = self.transformer( + torch.stack([latent_model_input[0]]), + t=timestep, + context=negative_prompt_embeds, + fps=fps_embeds, + **i2v_extra_kwrags, + )[0] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[0][:, idx] = sample_schedulers[idx].step( + noise_pred[:, idx - valid_interval_start], + timestep_i[idx], + latents[0][:, idx], + return_dict=False, + generator=generator, + )[0] + sample_schedulers_counter[idx] += 1 + if self.offload: + self.transformer.cpu() + torch.cuda.empty_cache() + x0 = latents[0].unsqueeze(0) + videos = self.vae.decode(x0) + videos = (videos / 2 + 0.5).clamp(0, 1) + videos = [video for video in videos] + videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] + videos = [video.cpu().numpy().astype(np.uint8) for video in videos] + return videos else: - video = latents - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (video,) - - return SkyReelsV2PipelineOutput(frames=video) + # long video generation + 
base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length + overlap_history_frames = (overlap_history - 1) // 4 + 1 + n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 + print(f"n_iter:{n_iter}") + output_video = None + for i in range(n_iter): + if output_video is not None: # i !=0 + prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device) + prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] + if prefix_video[0].shape[1] % causal_block_size != 0: + truncate_len = prefix_video[0].shape[1] % causal_block_size + print("the length of prefix video is truncated for the casual block size alignment.") + prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] + predix_video_latent_length = prefix_video[0].shape[1] + finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames + left_frame_num = latent_length - finished_frame_num + base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) + if ar_step > 0 and self.transformer.enable_teacache: + num_steps = num_inference_steps + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step + self.transformer.num_steps = num_steps + else: # i == 0 + base_num_frames_iter = base_num_frames + latent_shape = [16, base_num_frames_iter, latent_height, latent_width] + latents = self.prepare_latents( + latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator + ) + latents = [latents] + if prefix_video is not None: + latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + base_num_frames_iter, + init_timesteps, + base_num_frames_iter, + ar_step, + predix_video_latent_length, + causal_block_size, + ) + sample_schedulers = [] + for _ in range(base_num_frames_iter): + sample_scheduler = FlowMatchUniPCMultistepScheduler( + num_train_timesteps=1000, shift=1, use_dynamic_shifting=False + ) + sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + sample_schedulers.append(sample_scheduler) + sample_schedulers_counter = [0] * base_num_frames_iter + self.transformer.to(self.device) + for i, timestep_i in enumerate(tqdm(step_matrix)): + update_mask_i = step_update_mask[i] + valid_interval_i = valid_interval[i] + valid_interval_start, valid_interval_end = valid_interval_i + timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] + if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + timestep_for_noised_condition = addnoise_condition + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] + * (1.0 - noise_factor) + + torch.randn_like( + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] + ) + * noise_factor + ) + timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + if not self.do_classifier_free_guidance: + noise_pred = self.transformer( + torch.stack([latent_model_input[0]]), + t=timestep, + context=prompt_embeds, + fps=fps_embeds, + **i2v_extra_kwrags, + )[0] + else: + noise_pred_cond = 
self.transformer( + torch.stack([latent_model_input[0]]), + t=timestep, + context=prompt_embeds, + fps=fps_embeds, + **i2v_extra_kwrags, + )[0] + noise_pred_uncond = self.transformer( + torch.stack([latent_model_input[0]]), + t=timestep, + context=negative_prompt_embeds, + fps=fps_embeds, + **i2v_extra_kwrags, + )[0] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[0][:, idx] = sample_schedulers[idx].step( + noise_pred[:, idx - valid_interval_start], + timestep_i[idx], + latents[0][:, idx], + return_dict=False, + generator=generator, + )[0] + sample_schedulers_counter[idx] += 1 + if self.offload: + self.transformer.cpu() + torch.cuda.empty_cache() + x0 = latents[0].unsqueeze(0) + videos = [self.vae.decode(x0)[0]] + if output_video is None: + output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w + else: + output_video = torch.cat( + [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1 + ) # c, f, h, w + output_video = [(output_video / 2 + 0.5).clamp(0, 1)] + output_video = [video for video in output_video] + output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video] + output_video = [video.cpu().numpy().astype(np.uint8) for video in output_video] + + if not return_dict: + return (output_video,) + + return SkyReelsV2PipelineOutput(frames=output_video) From fc68bf3fa3a7c14f9fe14d2213b19deaa511a0ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 13 May 2025 13:58:27 +0300 Subject: [PATCH 029/264] style --- .../transformers/transformer_skyreels_v2.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 135ece77d55f..687da6353dd4 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -48,7 +48,7 @@ def __init__(self): self._flag_ar_attention = False - def forward( + def __call__( self, attn: Attention, hidden_states: torch.Tensor, @@ -247,7 +247,6 @@ def __init__( dim: int, ffn_dim: int, num_heads: int, - window_size: Tuple[int, int] = (-1, -1), qk_norm: str = "rms_norm", cross_attn_norm: bool = False, eps: float = 1e-6, @@ -258,7 +257,6 @@ def __init__( # 1. Self-attention self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) self.attn1 = Attention( - window_size=window_size, query_dim=dim, heads=num_heads, kv_heads=num_heads, @@ -273,7 +271,6 @@ def __init__( # 2. Cross-attention self.attn2 = Attention( - window_size=window_size, query_dim=dim, heads=num_heads, kv_heads=num_heads, @@ -284,32 +281,34 @@ def __init__( cross_attention_dim=None, out_bias=True, added_kv_proj_dim=added_kv_proj_dim, + added_proj_bias=True, processor=SkyReelsV2AttnProcessor2_0(), ) - self.norm2 = FP32LayerNorm(dim, eps) if cross_attn_norm else nn.Identity() + self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() # 3. 
Feed-forward self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, - temb_e: torch.Tensor, + temb: torch.Tensor, rotary_emb: torch.Tensor, attention_mask: torch.Tensor, ) -> torch.Tensor: - if temb_e.dim() == 3: + if temb.dim() == 3: shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( - self.scale_shift_table + temb_e.float() + self.scale_shift_table + temb.float() ).chunk(6, dim=1) - elif temb_e.dim() == 4: - e = (self.scale_shift_table.unsqueeze(2) + temb_e.float()).chunk(6, dim=1) + elif temb.dim() == 4: + e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e] - # self-attention + # 1. Self-attention norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) attn_output = self.attn1( hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask @@ -407,22 +406,23 @@ class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr def __init__( self, patch_size: Tuple[int] = (1, 2, 2), + num_attention_heads: int = 16, attention_head_dim: int = 128, in_channels: int = 16, - ffn_dim: int = 8192, - freq_dim: int = 256, - text_dim: int = 4096, out_channels: int = 16, - num_attention_heads: int = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 8192, num_layers: int = 32, - qk_norm: Optional[str] = "rms_norm", cross_attn_norm: bool = True, - inject_sample_info: bool = False, + qk_norm: Optional[str] = "rms_norm", eps: float = 1e-6, image_dim: Optional[int] = None, + added_kv_proj_dim: Optional[int] = None, rope_max_seq_len: int = 1024, pos_embed_seq_len: Optional[int] = None, - ): + inject_sample_info: bool = False, + ) -> None: super().__init__() inner_dim = num_attention_heads * attention_head_dim @@ -433,8 +433,8 @@ def __init__( self.enable_teacache = False # 1. Patch & position embedding - self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) self.rope = SkyReelsV2RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) # 2. Condition embeddings # image_embedding_dim=1280 for I2V model @@ -667,7 +667,7 @@ def forward( batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 ) hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) - output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3).float() + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer From 1cb6a9e94e25213a3a40b42494dd3f9cdc6ea8c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 13 May 2025 16:17:11 +0300 Subject: [PATCH 030/264] Refactor `SkyReelsV2DiffusionForcingPipeline` to remove image processing components and streamline the text-to-video generation process. Updates class documentation and adjusts parameter handling for improved clarity and functionality. 
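With this refactor the diffusion-forcing pipeline is text-conditioned only; image conditioning moves to the dedicated i2v variant. A rough usage sketch of the resulting API, assuming the class is exported from the top-level `diffusers` namespace; the checkpoint id is a placeholder and the argument values are only illustrative:

```py
import torch
from diffusers import SkyReelsV2DiffusionForcingPipeline

# Placeholder repo id; substitute a real SkyReels-V2 diffusion-forcing checkpoint.
pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
    "<skyreels-v2-df-checkpoint>", torch_dtype=torch.bfloat16
).to("cuda")

# Short-video path: `overlap_history` stays None, so a single window of
# `num_frames` frames is denoised with the asynchronous timestep schedule.
video = pipe(
    prompt="A person walking in the park",
    num_frames=97,
    num_inference_steps=50,
    ar_step=5,  # step offset between causal frame blocks
).frames[0]
```

Passing `overlap_history` together with a `base_num_frames` smaller than `num_frames` switches `__call__` to the long-video branch, which re-encodes the last `overlap_history` frames of each finished window as prefix latents for the next window.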
--- .../pipeline_skyreels_v2_diffusion_forcing.py | 166 +++++------------- 1 file changed, 45 insertions(+), 121 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 22a1184435a7..9a0adda188e4 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -100,8 +100,7 @@ def retrieve_latents( class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): """ - Pipeline for video generation with diffusion forcing using SkyReels-V2. This pipeline supports two main tasks: - Text-to-Video (t2v) and Image-to-Video (i2v) + Pipeline for text-to-video generation using SkyReels-V2 with diffusion forcing. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a specific device, etc.). @@ -113,11 +112,6 @@ class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): text_encoder ([`UMT5EncoderModel`]): [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. - image_encoder ([`CLIPVisionModel`]): - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically - the - [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) - variant. transformer ([`SkyReelsV2Transformer3DModel`]): Conditional Transformer to denoise the encoded image latents. scheduler ([`FlowMatchUniPCMultistepScheduler`]): @@ -126,15 +120,13 @@ class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
""" - model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + model_cpu_offload_seq = "text_encoder->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, - image_encoder: CLIPVisionModel, - image_processor: CLIPImageProcessor, transformer: SkyReelsV2Transformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMatchUniPCMultistepScheduler, @@ -145,17 +137,15 @@ def __init__( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, - image_encoder=image_encoder, transformer=transformer, scheduler=scheduler, - image_processor=image_processor, ) self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - self.image_processor = image_processor + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, @@ -197,16 +187,6 @@ def _get_t5_prompt_embeds( return prompt_embeds - def encode_image( - self, - image: PipelineImageInput, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - image = self.image_processor(images=image, return_tensors="pt").to(device) - image_embeds = self.image_encoder(**image, output_hidden_states=True) - return image_embeds.hidden_states[-2] - # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt def encode_prompt( self, @@ -289,29 +269,17 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.check_inputs def check_inputs( self, prompt, negative_prompt, - image, height, width, prompt_embeds=None, negative_prompt_embeds=None, - image_embeds=None, callback_on_step_end_tensor_inputs=None, ): - if image is not None and image_embeds is not None: - raise ValueError( - f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" - " only forward one of the two." - ) - if image is None and image_embeds is None: - raise ValueError( - "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." 
- ) - if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): - raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -343,9 +311,9 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents def prepare_latents( self, - image: PipelineImageInput, batch_size: int, num_channels_latents: int = 16, height: int = 480, @@ -355,62 +323,26 @@ def prepare_latents( device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - latent_height = height // self.vae_scale_factor_spatial - latent_width = width // self.vae_scale_factor_spatial + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) - shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device=device, dtype=dtype) - - image = image.unsqueeze(2) - video_condition = torch.cat( - [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 - ) - - video_condition = video_condition.to(device=device, dtype=self.vae.dtype) - - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype - ) - - if isinstance(generator, list): - latent_condition = [ - retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator - ] - latent_condition = torch.cat(latent_condition) - else: - latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") - latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) - - latent_condition = latent_condition.to(dtype) - latent_condition = (latent_condition - latents_mean) * latents_std - - mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) - - mask_lat_size[:, :, list(range(1, num_frames))] = 0 - first_frame_mask = mask_lat_size[:, :, 0:1] - first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) - mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) - mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) - mask_lat_size = mask_lat_size.transpose(1, 2) - mask_lat_size = mask_lat_size.to(latent_condition.device) - - return latents, torch.concat([mask_lat_size, latent_condition], dim=1) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents @property def guidance_scale(self): @@ -418,7 +350,7 @@ def guidance_scale(self): @property def do_classifier_free_guidance(self): - return self._guidance_scale > 1 + return self._guidance_scale > 1.0 @property def num_timesteps(self): @@ -442,7 +374,6 @@ def __call__( self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] = None, - image: PipelineImageInput = None, height: int = 480, width: int = 832, num_frames: int = 97, @@ -453,7 +384,6 @@ def __call__( latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - image_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -461,7 +391,7 @@ def __call__( Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 512, + max_sequence_length: int = 512, # TODO: check overlap_history: Optional[int] = None, shift: float = 1.0, addnoise_condition: float = 0.0, @@ -474,8 +404,6 @@ def __call__( The call function to the pipeline for generation. Args: - image (`PipelineImageInput`): - The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. 
@@ -487,7 +415,7 @@ def __call__( The height of the generated video. width (`int`, defaults to `832`): The width of the generated video. - num_frames (`int`, defaults to `81`): + num_frames (`int`, defaults to `97`): The number of frames in the generated video. num_inference_steps (`int`, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -513,9 +441,6 @@ def __call__( negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `negative_prompt` input argument. - image_embeds (`torch.Tensor`, *optional*): - Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, - image embeddings are generated from the `image` input argument. output_type (`str`, *optional*, defaults to `"np"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -539,6 +464,7 @@ def __call__( The shift of the flow. autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): The dtype to use for the torch.amp.autocast. + Examples: Returns: @@ -555,12 +481,10 @@ def __call__( self.check_inputs( prompt, negative_prompt, - image, height, width, prompt_embeds, negative_prompt_embeds, - image_embeds, callback_on_step_end_tensor_inputs, ) @@ -586,28 +510,34 @@ def __call__( else: batch_size = prompt_embeds.shape[0] + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + latent_height = height // 8 latent_width = width // 8 latent_length = (num_frames - 1) // 4 + 1 - self._guidance_scale = guidance_scale + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. 
Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps - i2v_extra_kwrags = {} prefix_video = None predix_video_latent_length = 0 - if image: - prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames) - - self.text_encoder.to(self.device) - prompt_embeds = self.text_encoder.encode(prompt).to(self.transformer.dtype) - if self.do_classifier_free_guidance: - negative_prompt_embeds = self.text_encoder.encode(negative_prompt).to(self.transformer.dtype) - if self.offload: - self.text_encoder.cpu() - torch.cuda.empty_cache() - - self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) - init_timesteps = self.scheduler.timesteps + if causal_block_size is None: causal_block_size = self.transformer.num_frame_per_block fps_embeds = [fps] * prompt_embeds.shape[0] @@ -625,7 +555,7 @@ def __call__( latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - latent_length, init_timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size + latent_length, timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size ) sample_schedulers = [] for _ in range(latent_length): @@ -657,7 +587,6 @@ def __call__( t=timestep, context=prompt_embeds, fps=fps_embeds, - **i2v_extra_kwrags, )[0] else: noise_pred_cond = self.transformer( @@ -665,14 +594,12 @@ def __call__( t=timestep, context=prompt_embeds, fps=fps_embeds, - **i2v_extra_kwrags, )[0] noise_pred_uncond = self.transformer( torch.stack([latent_model_input[0]]), t=timestep, context=negative_prompt_embeds, fps=fps_embeds, - **i2v_extra_kwrags, )[0] noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) for idx in range(valid_interval_start, valid_interval_end): @@ -728,7 +655,7 @@ def __call__( latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( base_num_frames_iter, - init_timesteps, + timesteps, base_num_frames_iter, ar_step, predix_video_latent_length, @@ -767,7 +694,6 @@ def __call__( t=timestep, context=prompt_embeds, fps=fps_embeds, - **i2v_extra_kwrags, )[0] else: noise_pred_cond = self.transformer( @@ -775,14 +701,12 @@ def __call__( t=timestep, context=prompt_embeds, fps=fps_embeds, - **i2v_extra_kwrags, )[0] noise_pred_uncond = self.transformer( torch.stack([latent_model_input[0]]), t=timestep, context=negative_prompt_embeds, fps=fps_embeds, - **i2v_extra_kwrags, )[0] noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) for idx in range(valid_interval_start, valid_interval_end): From ded93bc220db0c953c555a518edd242179750efd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 13 May 2025 17:04:07 +0300 Subject: [PATCH 031/264] Enhance `SkyReelsV2DiffusionForcingImageToVideoPipeline` by refining parameter handling and improving integration. 
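This keeps the image-to-video variant in sync with the text-only pipeline: `image` becomes the first positional argument of `__call__`, and the same short/long window logic drives generation. A rough usage sketch; the import path, checkpoint id, and input file below are placeholders rather than confirmed names:

```py
import torch
from diffusers import SkyReelsV2DiffusionForcingImageToVideoPipeline
from diffusers.utils import load_image

# Placeholder repo id; substitute a real SkyReels-V2 DF image-to-video checkpoint.
pipe = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained(
    "<skyreels-v2-df-i2v-checkpoint>", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("first_frame.png")  # conditioning frame for the video

# `image` is now the first positional argument; the remaining arguments mirror
# the text-only diffusion-forcing pipeline.
video = pipe(
    image,
    prompt="A person walking in the park",
    num_frames=97,
    num_inference_steps=50,
).frames[0]
```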
--- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 292 ++++++++++++------ 1 file changed, 202 insertions(+), 90 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 094c166eda98..859573f7bd06 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -40,7 +40,7 @@ else: XLA_AVAILABLE = False -logger = logging.get_logger(__name__) # pylint: disable=invalid-name +logger = logging.get_logger(__name__) if is_ftfy_available(): import ftfy @@ -438,9 +438,9 @@ def attention_kwargs(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, + image: PipelineImageInput, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] = None, - image: PipelineImageInput = None, height: int = 480, width: int = 832, num_frames: int = 97, @@ -459,7 +459,7 @@ def __call__( Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 512, + max_sequence_length: int = 512, # TODO: check overlap_history: Optional[int] = None, shift: float = 1.0, addnoise_condition: float = 0.0, @@ -485,7 +485,7 @@ def __call__( The height of the generated video. width (`int`, defaults to `832`): The width of the generated video. - num_frames (`int`, defaults to `81`): + num_frames (`int`, defaults to `97`): The number of frames in the generated video. num_inference_steps (`int`, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -576,6 +576,15 @@ def __call__( device = self._execution_device + latent_height = height // 8 + latent_width = width // 8 + latent_length = (num_frames - 1) // 4 + 1 + + if causal_block_size is None: + causal_block_size = self.transformer.num_frame_per_block + fps_embeds = [fps] * prompt_embeds.shape[0] + fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -585,6 +594,7 @@ def __call__( batch_size = prompt_embeds.shape[0] # 3. Encode input prompt + #prompt_embeds = self.text_encoder.encode(prompt).to(self.transformer.dtype) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, @@ -596,15 +606,15 @@ def __call__( device=device, ) - # Encode image embedding transformer_dtype = self.transformer.dtype prompt_embeds = prompt_embeds.to(transformer_dtype) if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + # Encode image embedding + #prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames) if image_embeds is None: image_embeds = self.encode_image(image, device) - image_embeds = image_embeds.repeat(batch_size, 1, 1) image_embeds = image_embeds.to(transformer_dtype) @@ -612,93 +622,195 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) timesteps = self.scheduler.timesteps - # 5. 
Prepare latent variables - num_channels_latents = self.vae.config.z_dim - image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) - - latents, condition = self.prepare_latents( - image, - batch_size * num_videos_per_prompt, - num_channels_latents, - height, - width, - num_frames, - torch.float32, - device, - generator, - latents, - ) - - # 6. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - self._num_timesteps = len(timesteps) - - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - self._current_timestep = t - latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) - timestep = t.expand(latents.shape[0]) - - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - encoder_hidden_states_image=image_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if self.do_classifier_free_guidance: - noise_uncond = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - encoder_hidden_states_image=image_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - self._current_timestep = None - - if not output_type == "latent": - latents = latents.to(self.vae.dtype) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) + if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: + # short video generation + latent_shape = [16, latent_length, latent_height, latent_width] + latents = self.prepare_latents( + latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype + latents = [latents] + if prefix_video is not None: + latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) + base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + latent_length, timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size ) - latents = latents / latents_std + latents_mean - video = self.vae.decode(latents, return_dict=False)[0] - video = self.video_processor.postprocess_video(video, output_type=output_type) + sample_schedulers = [] + for _ in 
range(latent_length): + sample_scheduler = FlowMatchUniPCMultistepScheduler( + num_train_timesteps=1000, shift=1, use_dynamic_shifting=False + ) + sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + sample_schedulers.append(sample_scheduler) + sample_schedulers_counter = [0] * latent_length + self.transformer.to(self.device) + for i, timestep_i in enumerate(tqdm(step_matrix)): + update_mask_i = step_update_mask[i] + valid_interval_i = valid_interval[i] + valid_interval_start, valid_interval_end = valid_interval_i + timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] + if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + timestep_for_noised_condition = addnoise_condition + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor) + + torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length]) + * noise_factor + ) + timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + if not self.do_classifier_free_guidance: + noise_pred = self.transformer( + torch.stack([latent_model_input[0]]), + t=timestep, + context=prompt_embeds, + fps=fps_embeds, + )[0] + else: + noise_pred_cond = self.transformer( + torch.stack([latent_model_input[0]]), + t=timestep, + context=prompt_embeds, + fps=fps_embeds, + )[0] + noise_pred_uncond = self.transformer( + torch.stack([latent_model_input[0]]), + t=timestep, + context=negative_prompt_embeds, + fps=fps_embeds, + )[0] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[0][:, idx] = sample_schedulers[idx].step( + noise_pred[:, idx - valid_interval_start], + timestep_i[idx], + latents[0][:, idx], + return_dict=False, + generator=generator, + )[0] + sample_schedulers_counter[idx] += 1 + if self.offload: + self.transformer.cpu() + torch.cuda.empty_cache() + x0 = latents[0].unsqueeze(0) + videos = self.vae.decode(x0) + videos = (videos / 2 + 0.5).clamp(0, 1) + videos = [video for video in videos] + videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] + video = [video.cpu().numpy().astype(np.uint8) for video in videos] else: - video = latents + # long video generation + base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length + overlap_history_frames = (overlap_history - 1) // 4 + 1 + n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 + print(f"n_iter:{n_iter}") + output_video = None + for i in range(n_iter): + if output_video is not None: # i !=0 + prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device) + prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] + if prefix_video[0].shape[1] % causal_block_size != 0: + truncate_len = prefix_video[0].shape[1] % causal_block_size + print("the length of prefix video is truncated for the casual block size alignment.") + prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] + predix_video_latent_length = prefix_video[0].shape[1] + finished_frame_num = i * (base_num_frames - overlap_history_frames) + 
overlap_history_frames + left_frame_num = latent_length - finished_frame_num + base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) + if ar_step > 0 and self.transformer.enable_teacache: + num_steps = num_inference_steps + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step + self.transformer.num_steps = num_steps + else: # i == 0 + base_num_frames_iter = base_num_frames + latent_shape = [16, base_num_frames_iter, latent_height, latent_width] + latents = self.prepare_latents( + latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator + ) + latents = [latents] + if prefix_video is not None: + latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + base_num_frames_iter, + timesteps, + base_num_frames_iter, + ar_step, + predix_video_latent_length, + causal_block_size, + ) + sample_schedulers = [] + for _ in range(base_num_frames_iter): + sample_scheduler = FlowMatchUniPCMultistepScheduler( + num_train_timesteps=1000, shift=1, use_dynamic_shifting=False + ) + sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + sample_schedulers.append(sample_scheduler) + sample_schedulers_counter = [0] * base_num_frames_iter + self.transformer.to(self.device) + for i, timestep_i in enumerate(tqdm(step_matrix)): + update_mask_i = step_update_mask[i] + valid_interval_i = valid_interval[i] + valid_interval_start, valid_interval_end = valid_interval_i + timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] + if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + timestep_for_noised_condition = addnoise_condition + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] + * (1.0 - noise_factor) + + torch.randn_like( + latent_model_input[0][:, valid_interval_start:predix_video_latent_length] + ) + * noise_factor + ) + timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + if not self.do_classifier_free_guidance: + noise_pred = self.transformer( + torch.stack([latent_model_input[0]]), + t=timestep, + context=prompt_embeds, + fps=fps_embeds, + )[0] + else: + noise_pred_cond = self.transformer( + torch.stack([latent_model_input[0]]), + t=timestep, + context=prompt_embeds, + fps=fps_embeds, + )[0] + noise_pred_uncond = self.transformer( + torch.stack([latent_model_input[0]]), + t=timestep, + context=negative_prompt_embeds, + fps=fps_embeds, + )[0] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[0][:, idx] = sample_schedulers[idx].step( + noise_pred[:, idx - valid_interval_start], + timestep_i[idx], + latents[0][:, idx], + return_dict=False, + generator=generator, + )[0] + sample_schedulers_counter[idx] += 1 + if self.offload: + self.transformer.cpu() + torch.cuda.empty_cache() + x0 = latents[0].unsqueeze(0) + videos = [self.vae.decode(x0)[0]] + if output_video is None: + output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w + else: + output_video = torch.cat( + [output_video, videos[0][:, 
overlap_history:].clamp(-1, 1).cpu()], 1 + ) # c, f, h, w + output_video = [(output_video / 2 + 0.5).clamp(0, 1)] + output_video = [video for video in output_video] + output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video] + video = [video.cpu().numpy().astype(np.uint8) for video in output_video] # Offload all models self.maybe_free_model_hooks() From c9483b2ac2fb977367f01f75dc8a1cd7c9c35649 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 13 May 2025 17:04:29 +0300 Subject: [PATCH 032/264] Remove unused dtype handling in `SkyReelsV2DiffusionForcingPipeline` to streamline video generation process. --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 9a0adda188e4..898a25fe9aac 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -542,8 +542,7 @@ def __call__( causal_block_size = self.transformer.num_frame_per_block fps_embeds = [fps] * prompt_embeds.shape[0] fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] - transformer_dtype = self.transformer.dtype - # with torch.cuda.amp.autocast(dtype=self.transformer.dtype), torch.no_grad(): + if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: # short video generation latent_shape = [16, latent_length, latent_height, latent_width] From f7fed01ce97d636753a1aea9b1f6e22b0b1b3b78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 13 May 2025 17:25:08 +0300 Subject: [PATCH 033/264] up --- .../pipeline_skyreels_v2_diffusion_forcing.py | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 898a25fe9aac..740c6771cc61 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -544,14 +544,12 @@ def __call__( fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: - # short video generation + # Short video generation latent_shape = [16, latent_length, latent_height, latent_width] latents = self.prepare_latents( latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator ) latents = [latents] - if prefix_video is not None: - latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( latent_length, timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size @@ -611,22 +609,17 @@ def __call__( generator=generator, )[0] sample_schedulers_counter[idx] += 1 - if self.offload: - self.transformer.cpu() - torch.cuda.empty_cache() x0 = latents[0].unsqueeze(0) videos = self.vae.decode(x0) videos = (videos / 2 + 0.5).clamp(0, 1) videos = [video for video in videos] videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] - videos = 
[video.cpu().numpy().astype(np.uint8) for video in videos] - return videos + video = [video.cpu().numpy().astype(np.uint8) for video in videos] else: - # long video generation + # Long video generation base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length overlap_history_frames = (overlap_history - 1) // 4 + 1 n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 - print(f"n_iter:{n_iter}") output_video = None for i in range(n_iter): if output_video is not None: # i !=0 @@ -732,9 +725,12 @@ def __call__( output_video = [(output_video / 2 + 0.5).clamp(0, 1)] output_video = [video for video in output_video] output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video] - output_video = [video.cpu().numpy().astype(np.uint8) for video in output_video] + video = [video.cpu().numpy().astype(np.uint8) for video in output_video] - if not return_dict: - return (output_video,) + # Offload all models + self.maybe_free_model_hooks() - return SkyReelsV2PipelineOutput(frames=output_video) + if not return_dict: + return (video,) + + return SkyReelsV2PipelineOutput(frames=video) From 0e7b21da4b720a08273f35c20ea8d98c3000181c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 13 May 2025 18:36:14 +0300 Subject: [PATCH 034/264] up --- .../skyreels_v2/pipeline_skyreels_v2.py | 2 +- .../pipeline_skyreels_v2_diffusion_forcing.py | 20 +++++++++++++------ ...eline_skyreels_v2_diffusion_forcing_i2v.py | 20 ++++++++++--------- .../skyreels_v2/pipeline_skyreels_v2_i2v.py | 6 +++++- 4 files changed, 31 insertions(+), 17 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py index 4e64b7bab4e3..894c3c1f5ab6 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py @@ -94,7 +94,7 @@ def prompt_clean(text): class SkyReelsV2Pipeline(DiffusionPipeline, WanLoraLoaderMixin): r""" - Pipeline for text-to-video generation using Wan. + Pipeline for Text-to-Video (t2v) generation using SkyReels-V2. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 740c6771cc61..2dace5c748f2 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -100,7 +100,7 @@ def retrieve_latents( class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): """ - Pipeline for text-to-video generation using SkyReels-V2 with diffusion forcing. + Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a specific device, etc.). @@ -460,13 +460,21 @@ def __call__( `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int`, *optional*, defaults to `512`): The maximum sequence length of the prompt. - shift (`float`, *optional*, defaults to `5.0`): - The shift of the flow. 
- autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): - The dtype to use for the torch.amp.autocast. + shift (`float`, *optional*, defaults to `1.0`): - Examples: + overlap_history (`int`, *optional*): + + addnoise_condition (`float`, *optional*, defaults to `0.0`): + + base_num_frames (`int`, *optional*, defaults to `97`): + ar_step (`int`, *optional*, defaults to `5`): + The step of the autoregressive steps. + causal_block_size (`int`, *optional*): + + fps (`int`, *optional*, defaults to `24`): + + Examples: Returns: [`~SkyReelsV2PipelineOutput`] or `tuple`: If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 859573f7bd06..a395fd5cf00c 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -98,8 +98,7 @@ def retrieve_latents( class SkyReelsV2DiffusionForcingImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): """ - Pipeline for video generation with diffusion forcing using SkyReels-V2. This pipeline supports two main tasks: - Text-to-Video (t2v) and Image-to-Video (i2v) + Pipeline for Image-to-Video (i2v) generation using SkyReels-V2 with diffusion forcing. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a specific device, etc.). @@ -154,6 +153,7 @@ def __init__( self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) self.image_processor = image_processor + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, @@ -195,6 +195,7 @@ def _get_t5_prompt_embeds( return prompt_embeds + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanPipeline.encode_image def encode_image( self, image: PipelineImageInput, @@ -287,6 +288,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanPipeline.check_inputs def check_inputs( self, prompt, @@ -341,6 +343,7 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanPipeline.prepare_latents def prepare_latents( self, image: PipelineImageInput, @@ -594,7 +597,6 @@ def __call__( batch_size = prompt_embeds.shape[0] # 3. 
Encode input prompt - #prompt_embeds = self.text_encoder.encode(prompt).to(self.transformer.dtype) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, @@ -612,6 +614,8 @@ def __call__( negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) # Encode image embedding + prefix_video = not None + predix_video_latent_length = 0 #prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames) if image_embeds is None: image_embeds = self.encode_image(image, device) @@ -623,14 +627,13 @@ def __call__( timesteps = self.scheduler.timesteps if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: - # short video generation + # Short video generation latent_shape = [16, latent_length, latent_height, latent_width] latents = self.prepare_latents( latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator ) latents = [latents] - if prefix_video is not None: - latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) + latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( latent_length, timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size @@ -700,7 +703,7 @@ def __call__( videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] video = [video.cpu().numpy().astype(np.uint8) for video in videos] else: - # long video generation + # Long video generation base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length overlap_history_frames = (overlap_history - 1) // 4 + 1 n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 @@ -728,8 +731,7 @@ def __call__( latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator ) latents = [latents] - if prefix_video is not None: - latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) + latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( base_num_frames_iter, timesteps, diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py index d6592b5d5b52..478753a69279 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py @@ -127,7 +127,7 @@ def retrieve_latents( class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): r""" - Pipeline for image-to-video generation using Wan. + Pipeline for image-to-video generation using SkyReels-V2. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). 
@@ -182,6 +182,7 @@ def __init__( self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) self.image_processor = image_processor + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanPipeline.encode_prompt def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, @@ -223,6 +224,7 @@ def _get_t5_prompt_embeds( return prompt_embeds + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanPipeline.encode_image def encode_image( self, image: PipelineImageInput, @@ -315,6 +317,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanPipeline.check_inputs def check_inputs( self, prompt, @@ -369,6 +372,7 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanPipeline.prepare_latents def prepare_latents( self, image: PipelineImageInput, From b3698d7ac1db3c350b390efdbd62ff47d3d93e74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 13 May 2025 18:51:39 +0300 Subject: [PATCH 035/264] Update references --- .../pipeline_skyreels_v2_diffusion_forcing.py | 12 +++---- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 36 +++++++++++++------ .../skyreels_v2/pipeline_skyreels_v2_i2v.py | 10 +++--- 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 2dace5c748f2..cdb7a2d2241f 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -14,17 +14,15 @@ import html import re -from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import numpy as np -from tqdm import tqdm +from typing import Any, Callable, Dict, List, Optional, Union import ftfy -import PIL.Image +import numpy as np import torch -from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel +from tqdm import tqdm +from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...image_processor import PipelineImageInput from ...loaders import WanLoraLoaderMixin from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel from ...schedulers import FlowMatchUniPCMultistepScheduler @@ -150,7 +148,7 @@ def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_videos_per_prompt: int = 1, - max_sequence_length: int = 512, + max_sequence_length: int = 226, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index a395fd5cf00c..b916e01618cc 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -153,7 +153,7 @@ def __init__( self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) self.image_processor = image_processor - # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanPipeline._get_t5_prompt_embeds + # Copied from 
diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, @@ -195,7 +195,7 @@ def _get_t5_prompt_embeds( return prompt_embeds - # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanPipeline.encode_image + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image def encode_image( self, image: PipelineImageInput, @@ -206,7 +206,7 @@ def encode_image( image_embeds = self.image_encoder(**image, output_hidden_states=True) return image_embeds.hidden_states[-2] - # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], @@ -288,7 +288,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanPipeline.check_inputs + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.check_inputs def check_inputs( self, prompt, @@ -343,7 +343,7 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") - # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanPipeline.prepare_latents + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.prepare_latents def prepare_latents( self, image: PipelineImageInput, @@ -356,6 +356,7 @@ def prepare_latents( device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, + last_image: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 latent_height = height // self.vae_scale_factor_spatial @@ -374,10 +375,16 @@ def prepare_latents( latents = latents.to(device=device, dtype=dtype) image = image.unsqueeze(2) - video_condition = torch.cat( - [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 - ) - + if last_image is None: + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + ) + else: + last_image = last_image.unsqueeze(2) + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], + dim=2, + ) video_condition = video_condition.to(device=device, dtype=self.vae.dtype) latents_mean = ( @@ -403,7 +410,10 @@ def prepare_latents( mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) - mask_lat_size[:, :, list(range(1, num_frames))] = 0 + if last_image is None: + mask_lat_size[:, :, list(range(1, num_frames))] = 0 + else: + mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 first_frame_mask = mask_lat_size[:, :, 0:1] first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) @@ -455,6 +465,7 @@ def __call__( prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, image_embeds: Optional[torch.Tensor] = None, + last_image: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -618,7 +629,10 @@ def __call__( 
predix_video_latent_length = 0 #prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames) if image_embeds is None: - image_embeds = self.encode_image(image, device) + if last_image is None: + image_embeds = self.encode_image(image, device) + else: + image_embeds = self.encode_image([image, last_image], device) image_embeds = image_embeds.repeat(batch_size, 1, 1) image_embeds = image_embeds.to(transformer_dtype) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py index 478753a69279..832e6a8a73ec 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py @@ -182,7 +182,7 @@ def __init__( self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) self.image_processor = image_processor - # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanPipeline.encode_prompt + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, @@ -224,7 +224,7 @@ def _get_t5_prompt_embeds( return prompt_embeds - # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanPipeline.encode_image + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image def encode_image( self, image: PipelineImageInput, @@ -235,7 +235,7 @@ def encode_image( image_embeds = self.image_encoder(**image, output_hidden_states=True) return image_embeds.hidden_states[-2] - # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], @@ -317,7 +317,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanPipeline.check_inputs + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.check_inputs def check_inputs( self, prompt, @@ -372,7 +372,7 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") - # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanPipeline.prepare_latents + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.prepare_latents def prepare_latents( self, image: PipelineImageInput, From 7e0f0f5441e3d5474e2335b69218a87138a82e0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 14 May 2025 12:06:16 +0300 Subject: [PATCH 036/264] Add `generate_timestep_matrix` method to `SkyReelsV2DiffusionForcingPipeline` for proper timestep management in video generation. Refactor latent variable preparation and update handling for better clarity. 
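For reference, the returned matrix drives the diffusion-forcing schedule: each row corresponds to one forward pass of the transformer, and each frame block is allowed to lag the block before it by `ar_step` scheduler steps, so earlier frames finish denoising before later ones. Below is a minimal, self-contained sketch of just the row-update rule; the helper name `toy_timestep_rows` and the sizes are illustrative and not part of the pipeline, and the real method additionally pads the step template, builds per-row update masks, and computes valid intervals and causal block handling.

```py
import torch


def toy_timestep_rows(num_blocks: int, num_inference_steps: int, ar_step: int) -> torch.Tensor:
    # In the real method the counters index a step template padded with a
    # leading 999 and a trailing 0; here they simply count progress per block.
    num_iterations = num_inference_steps + 1
    pre_row = torch.zeros(num_blocks, dtype=torch.long)
    rows = []
    while not torch.all(pre_row >= num_iterations - 1):
        new_row = torch.zeros(num_blocks, dtype=torch.long)
        for i in range(num_blocks):
            if i == 0 or pre_row[i - 1] >= num_iterations - 1:
                # First block, or the previous block is already fully denoised:
                # advance this block by one step.
                new_row[i] = pre_row[i] + 1
            else:
                # Otherwise this block trails the previous one by `ar_step` steps.
                new_row[i] = new_row[i - 1] - ar_step
        new_row = new_row.clamp(0, num_iterations)
        rows.append(new_row)
        pre_row = new_row
    return torch.stack(rows)


# Each row is one transformer call; each column tracks a frame block's progress.
# Later blocks stay at 0 (pure noise) until the blocks before them are ahead.
print(toy_timestep_rows(num_blocks=3, num_inference_steps=4, ar_step=2))
```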
--- .../pipeline_skyreels_v2_diffusion_forcing.py | 355 ++++++++++++------ 1 file changed, 232 insertions(+), 123 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index cdb7a2d2241f..4bb47a3f9036 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -342,6 +342,82 @@ def prepare_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) return latents + def generate_timestep_matrix( + self, + num_frames, + step_template, + base_num_frames, + ar_step=5, + num_pre_ready=0, + casual_block_size=1, + shrink_interval_with_mask=False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: + step_matrix, step_index = [], [] + update_mask, valid_interval = [], [] + num_iterations = len(step_template) + 1 + num_frames_block = num_frames // casual_block_size + base_num_frames_block = base_num_frames // casual_block_size + if base_num_frames_block < num_frames_block: + infer_step_num = len(step_template) + gen_block = base_num_frames_block + min_ar_step = infer_step_num / gen_block + assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting" + # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block) + step_template = torch.cat( + [ + torch.tensor([999], dtype=torch.int64, device=step_template.device), + step_template.long(), + torch.tensor([0], dtype=torch.int64, device=step_template.device), + ] + ) # to handle the counter in row works starting from 1 + pre_row = torch.zeros(num_frames_block, dtype=torch.long) + if num_pre_ready > 0: + pre_row[: num_pre_ready // casual_block_size] = num_iterations + + while torch.all(pre_row >= (num_iterations - 1)) == False: + new_row = torch.zeros(num_frames_block, dtype=torch.long) + for i in range(num_frames_block): + if i == 0 or pre_row[i - 1] >= ( + num_iterations - 1 + ): # the first frame or the last frame is completely denoised + new_row[i] = pre_row[i] + 1 + else: + new_row[i] = new_row[i - 1] - ar_step + new_row = new_row.clamp(0, num_iterations) + + update_mask.append( + (new_row != pre_row) & (new_row != num_iterations) + ) # False: no need to update, True: need to update + step_index.append(new_row) + step_matrix.append(step_template[new_row]) + pre_row = new_row + + # for long video we split into several sequences, base_num_frames is set to the model max length (for training) + terminal_flag = base_num_frames_block + if shrink_interval_with_mask: + idx_sequence = torch.arange(num_frames_block, dtype=torch.int64) + update_mask = update_mask[0] + update_mask_idx = idx_sequence[update_mask] + last_update_idx = update_mask_idx[-1].item() + terminal_flag = last_update_idx + 1 + # for i in range(0, len(update_mask)): + for curr_mask in update_mask: + if terminal_flag < num_frames_block and curr_mask[terminal_flag]: + terminal_flag += 1 + valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag)) + + step_update_mask = torch.stack(update_mask, dim=0) + step_index = torch.stack(step_index, dim=0) + step_matrix = torch.stack(step_matrix, dim=0) + + if casual_block_size > 1: + step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + step_index = 
step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval] + + return step_matrix, step_index, step_update_mask, valid_interval + @property def guidance_scale(self): return self._guidance_scale @@ -528,10 +604,6 @@ def __call__( device=device, ) - latent_height = height // 8 - latent_width = width // 8 - latent_length = (num_frames - 1) // 4 + 1 - transformer_dtype = self.transformer.dtype prompt_embeds = prompt_embeds.to(transformer_dtype) if negative_prompt_embeds is not None: @@ -551,15 +623,25 @@ def __call__( if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: # Short video generation - latent_shape = [16, latent_length, latent_height, latent_width] + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( - latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, ) - latents = [latents] + latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1 base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( latent_length, timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size ) + sample_schedulers = [] for _ in range(latent_length): sample_scheduler = FlowMatchUniPCMultistepScheduler( @@ -568,59 +650,71 @@ def __call__( sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) sample_schedulers.append(sample_scheduler) sample_schedulers_counter = [0] * latent_length - self.transformer.to(self.device) - for i, timestep_i in enumerate(tqdm(step_matrix)): - update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: - noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor) - + torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length]) - * noise_factor - ) - timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition - if not self.do_classifier_free_guidance: - noise_pred = self.transformer( - torch.stack([latent_model_input[0]]), - t=timestep, - context=prompt_embeds, - fps=fps_embeds, - )[0] - else: - noise_pred_cond = self.transformer( - torch.stack([latent_model_input[0]]), - t=timestep, - context=prompt_embeds, - fps=fps_embeds, - )[0] - noise_pred_uncond = self.transformer( - torch.stack([latent_model_input[0]]), - t=timestep, - context=negative_prompt_embeds, - fps=fps_embeds, - )[0] - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) - for idx in 
range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[0][:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], - timestep_i[idx], - latents[0][:, idx], + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, timestep_i in enumerate(step_matrix): + if self.interrupt: + continue + + self._current_timestep = timestep_i + + update_mask_i = step_update_mask[i] + valid_interval_i = valid_interval[i] + valid_interval_start, valid_interval_end = valid_interval_i + timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + timestep_for_noised_condition = addnoise_condition + latent_model_input[:, valid_interval_start:predix_video_latent_length] = ( + latent_model_input[:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor) + + torch.randn_like(latent_model_input[:, valid_interval_start:predix_video_latent_length]) + * noise_factor + ) + timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + if not self.do_classifier_free_guidance: + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + else: + noise_pred_cond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + fps=fps_embeds, + attention_kwargs=attention_kwargs, return_dict=False, - generator=generator, )[0] - sample_schedulers_counter[idx] += 1 - x0 = latents[0].unsqueeze(0) - videos = self.vae.decode(x0) - videos = (videos / 2 + 0.5).clamp(0, 1) - videos = [video for video in videos] - videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] - video = [video.cpu().numpy().astype(np.uint8) for video in videos] + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[:, idx] = sample_schedulers[idx].step( + noise_pred[:, idx - valid_interval_start], + timestep_i[idx], + latents[:, idx], + return_dict=False, + generator=generator, + )[0] + sample_schedulers_counter[idx] += 1 + x0 = latents.unsqueeze(0) + videos = self.vae.decode(x0) + videos = (videos / 2 + 0.5).clamp(0, 1) + videos = [video for video in videos] + videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] + video = [video.cpu().numpy().astype(np.uint8) for video in videos] else: # Long video generation base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length @@ -646,11 +740,17 @@ def __call__( base_num_frames_iter = base_num_frames latent_shape = [16, base_num_frames_iter, latent_height, latent_width] latents = self.prepare_latents( - latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, 
) - latents = [latents] if prefix_video is not None: - latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) + latents[:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( base_num_frames_iter, timesteps, @@ -659,6 +759,7 @@ def __call__( predix_video_latent_length, causal_block_size, ) + sample_schedulers = [] for _ in range(base_num_frames_iter): sample_scheduler = FlowMatchUniPCMultistepScheduler( @@ -667,71 +768,79 @@ def __call__( sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) sample_schedulers.append(sample_scheduler) sample_schedulers_counter = [0] * base_num_frames_iter - self.transformer.to(self.device) - for i, timestep_i in enumerate(tqdm(step_matrix)): - update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: - noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] - * (1.0 - noise_factor) - + torch.randn_like( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, timestep_i in enumerate(step_matrix): + if self.interrupt: + continue + + self._current_timestep = timestep_i + update_mask_i = step_update_mask[i] + valid_interval_i = valid_interval[i] + valid_interval_start, valid_interval_end = valid_interval_i + timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + timestep_for_noised_condition = addnoise_condition + latent_model_input[:, valid_interval_start:predix_video_latent_length] = ( + latent_model_input[:, valid_interval_start:predix_video_latent_length] + * (1.0 - noise_factor) + + torch.randn_like( + latent_model_input[:, valid_interval_start:predix_video_latent_length] + ) + * noise_factor ) - * noise_factor - ) - timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition - if not self.do_classifier_free_guidance: - noise_pred = self.transformer( - torch.stack([latent_model_input[0]]), - t=timestep, - context=prompt_embeds, - fps=fps_embeds, - )[0] - else: - noise_pred_cond = self.transformer( - torch.stack([latent_model_input[0]]), - t=timestep, - context=prompt_embeds, - fps=fps_embeds, - )[0] - noise_pred_uncond = self.transformer( - torch.stack([latent_model_input[0]]), - t=timestep, - context=negative_prompt_embeds, - fps=fps_embeds, - )[0] - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) - for idx in range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[0][:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], - timestep_i[idx], - latents[0][:, idx], + timestep[:, 
valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + if not self.do_classifier_free_guidance: + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + fps=fps_embeds, + attention_kwargs=attention_kwargs, return_dict=False, - generator=generator, )[0] - sample_schedulers_counter[idx] += 1 - if self.offload: - self.transformer.cpu() - torch.cuda.empty_cache() - x0 = latents[0].unsqueeze(0) - videos = [self.vae.decode(x0)[0]] - if output_video is None: - output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w - else: - output_video = torch.cat( - [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1 - ) # c, f, h, w - output_video = [(output_video / 2 + 0.5).clamp(0, 1)] - output_video = [video for video in output_video] - output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video] - video = [video.cpu().numpy().astype(np.uint8) for video in output_video] + else: + noise_pred_cond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[:, idx] = sample_schedulers[idx].step( + noise_pred[:, idx - valid_interval_start], + timestep_i[idx], + latents[:, idx], + return_dict=False, + generator=generator, + )[0] + sample_schedulers_counter[idx] += 1 + x0 = latents.unsqueeze(0) + videos = [self.vae.decode(x0)[0]] + if output_video is None: + output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w + else: + output_video = torch.cat( + [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1 + ) # c, f, h, w + output_video = [(output_video / 2 + 0.5).clamp(0, 1)] + output_video = [video for video in output_video] + output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video] + video = [video.cpu().numpy().astype(np.uint8) for video in output_video] # Offload all models self.maybe_free_model_hooks() From 8c23208d353750c5da62c713889060e3c545dbf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 14 May 2025 14:52:46 +0300 Subject: [PATCH 037/264] Remove training-related code --- .../transformers/transformer_skyreels_v2.py | 43 ------------------- 1 file changed, 43 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 687da6353dd4..04a38cb4157f 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -470,9 +470,6 @@ def __init__( nn.Linear(inner_dim, inner_dim), nn.SiLU(), nn.Linear(inner_dim, inner_dim * 6) ) - # TODO: Say: Initializing suggested by the original repo? 
- # self.init_weights() - def forward( self, hidden_states: torch.Tensor, @@ -766,43 +763,3 @@ def initialize_teacache( self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] self.ret_steps = 1 * 2 self.cutoff_steps = num_steps * 2 - 2 - - def init_weights(self): - r""" - Initialize model parameters using Xavier initialization. - """ - - # basic init - for m in self.modules(): - if isinstance(m, nn.Linear): - nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - nn.init.zeros_(m.bias) - - # init embeddings - nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) - for m in self.text_embedding.modules(): - if isinstance(m, nn.Linear): - nn.init.normal_(m.weight, std=0.02) - for m in self.time_embedding.modules(): - if isinstance(m, nn.Linear): - nn.init.normal_(m.weight, std=0.02) - - if self.inject_sample_info: - nn.init.normal_(self.fps_embedding.weight, std=0.02) - - for m in self.fps_projection.modules(): - if isinstance(m, nn.Linear): - nn.init.normal_(m.weight, std=0.02) - - nn.init.zeros_(self.fps_projection[-1].weight) - nn.init.zeros_(self.fps_projection[-1].bias) - - # init output layer - nn.init.zeros_(self.head.head.weight) - - def zero_init_i2v_cross_attn(self): - print("zero init i2v cross attn") - for i in range(self.num_layers): - self.blocks[i].cross_attn.v_img.weight.data.zero_() - self.blocks[i].cross_attn.v_img.bias.data.zero_() From 1f8e268de62daa17b9d6760176d7071a62b2dab1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 14 May 2025 17:42:23 +0300 Subject: [PATCH 038/264] Add gradient checkpointing support in `SkyReelsV2Transformer3DModel` for improved memory efficiency during training. --- .../models/transformers/transformer_skyreels_v2.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 04a38cb4157f..6093965b50ef 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -559,6 +559,12 @@ def forward( if encoder_hidden_states_image is not None: encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.blocks: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask + ) if self.inject_sample_info: fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device) From d85352188c0a488d29ba1824fb13b781013c7d8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 14 May 2025 17:58:30 +0300 Subject: [PATCH 039/264] Refactor `SkyReelsV2TransformerBlock` and remove unused `Head` class. Update tensor handling in `SkyReelsV2Transformer3DModel` for improved dimensionality management. Clean up imports in `pipeline_skyreels_v2_diffusion_forcing.py` by removing `tqdm`. 
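The dimensionality change concerns the output-norm modulation: with diffusion forcing every latent frame (token) can carry its own timestep, so the timestep embedding may be 2-D (one per sample) or 3-D (one per token) and the scale/shift table has to broadcast over both cases. A shape-only sketch with toy sizes, not the model code itself:

```py
import torch

batch, seq, dim = 2, 6, 8
scale_shift_table = torch.randn(1, 2, dim)

# Per-sample timestep embedding: (batch, dim)
temb = torch.randn(batch, dim)
shift, scale = (scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
print(shift.shape, scale.shape)  # torch.Size([2, 1, 8]) torch.Size([2, 1, 8])

# Per-token timestep embeddings (diffusion forcing): (batch, seq, dim)
temb = torch.randn(batch, seq, dim)
shift, scale = (scale_shift_table.unsqueeze(2) + temb.unsqueeze(1)).chunk(2, dim=1)
shift, scale = shift.squeeze(1), scale.squeeze(1)
print(shift.shape, scale.shape)  # torch.Size([2, 6, 8]) torch.Size([2, 6, 8])
```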
--- .../transformers/transformer_skyreels_v2.py | 27 ++++--------------- .../pipeline_skyreels_v2_diffusion_forcing.py | 1 - 2 files changed, 5 insertions(+), 23 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 6093965b50ef..f9176a23c498 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -17,7 +17,6 @@ import numpy as np import torch -import torch.amp as amp import torch.nn as nn import torch.nn.functional as F from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention @@ -333,26 +332,6 @@ def set_ar_attention(self): self.attn1.processor.set_ar_attention() -class Head(nn.Module): - def forward(self, x, e): - r""" - Args: - x(Tensor): Shape [B, L1, C] - e(Tensor): Shape [B, C] - """ - with amp.autocast("cuda", dtype=torch.float32): - if e.dim() == 2: - modulation = self.modulation # 1, 2, dim - e = (modulation + e.unsqueeze(1)).chunk(2, dim=1) - - elif e.dim() == 3: - modulation = self.modulation.unsqueeze(2) # 1, 2, seq, dim - e = (modulation + e.unsqueeze(1)).chunk(2, dim=1) - e = [ei.squeeze(1) for ei in e] - x = self.head(self.norm(x) * (1 + e[1]) + e[0]) - return x - - class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin): r""" A Transformer model for video-like data used in the Wan-based SkyReels-V2 model. @@ -654,7 +633,11 @@ def forward( hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask) # 5. Output norm, projection & unpatchify - shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + if temb.dim() == 2: + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + elif temb.dim() == 3: + shift, scale = (self.scale_shift_table.unsqueeze(2) + temb.unsqueeze(1)).chunk(2, dim=1) + shift, scale = shift.squeeze(1), scale.squeeze(1) # Move the shift and scale tensors to the same device as hidden_states. # When using multi-GPU inference via accelerate these will be on the diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 4bb47a3f9036..ef4eec0901df 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -19,7 +19,6 @@ import ftfy import numpy as np import torch -from tqdm import tqdm from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback From 2b79584664b5dc00931d0a3dd32beda8c57f7b32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 14 May 2025 18:04:55 +0300 Subject: [PATCH 040/264] Remove unused parameter `y` and associated documentation from `SkyReelsV2Transformer3DModel`. Clean up code for improved clarity and maintainability. 
--- .../transformers/transformer_skyreels_v2.py | 25 ------------------- 1 file changed, 25 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index f9176a23c498..f9a237fa775a 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -455,31 +455,10 @@ def forward( timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, - y: Optional[torch.Tensor] = None, fps: Optional[torch.Tensor] = None, return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - r""" - Args: - x (List[Tensor]): - List of input video tensors, each with shape [C_in, F, H, W] - t (Tensor): - Diffusion timesteps tensor of shape [B] - context (List[Tensor]): - List of text embeddings each with shape [L, C] - seq_len (`int`): - Maximum sequence length for positional encoding - clip_fea (Tensor, *optional*): - CLIP image features for image-to-video mode - y (List[Tensor], *optional*): - Conditional video inputs for image-to-video mode, same shape as x - - Returns: - List[Tensor]: - List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] - """ - if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) @@ -503,10 +482,6 @@ def forward( rotary_emb = self.rope(hidden_states) - # TODO: check here - if y is not None: - hidden_states = torch.cat([hidden_states, y], dim=1) - hidden_states = self.patch_embedding(hidden_states) grid_sizes = torch.tensor(hidden_states.shape[2:], dtype=torch.long) From 600ced35ce3d45cd12a56933dc53e7e359dd4886 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 14 May 2025 19:33:53 +0300 Subject: [PATCH 041/264] Update context length calculation in `SkyReelsV2AttnProcessor2_0` to reflect the text encoder's context size. 
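In the image-conditioned variants the cross-attention context is the concatenation of image-encoder tokens (257 in the previously hardcoded version) and the 512-token text-encoder context, so the image span can now be recovered by subtraction instead of a hardcoded count. A toy shape example; the tensor sizes below are illustrative only:

```py
import torch

num_image_tokens, text_context_length, hidden_dim = 257, 512, 16
encoder_hidden_states = torch.randn(1, num_image_tokens + text_context_length, hidden_dim)

# The joint context is [image tokens | text tokens]; whatever is left after
# subtracting the 512-token text context is the image span.
image_context_length = encoder_hidden_states.shape[1] - text_context_length
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
encoder_hidden_states_txt = encoder_hidden_states[:, image_context_length:]
print(encoder_hidden_states_img.shape)  # torch.Size([1, 257, 16])
print(encoder_hidden_states_txt.shape)  # torch.Size([1, 512, 16])
```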
--- .../models/transformers/transformer_skyreels_v2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index f9a237fa775a..be033074ce50 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -33,7 +33,7 @@ from ..normalization import FP32LayerNorm -logger = logging.get_logger(__name__) +logger = logging.get_logger(__name__) # pylint: disable=invalid-name flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune") @@ -57,8 +57,8 @@ def __call__( ) -> torch.Tensor: encoder_hidden_states_img = None if attn.add_k_proj is not None: - # image_context_length is hardcoded for now like in the original code - image_context_length = 257 + # 512 is the context length of the text encoder, hardcoded for now + image_context_length = encoder_hidden_states.shape[1] - 512 encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] encoder_hidden_states = encoder_hidden_states[:, image_context_length:] if encoder_hidden_states is None: From 586fe56e3750984f13f853882f05ea93a175e294 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 14 May 2025 19:55:54 +0300 Subject: [PATCH 042/264] Fix comparison logic in `SkyReelsV2AttnProcessor2_0` to correctly determine self-attention by using `is` instead of `==` for encoder hidden states. --- src/diffusers/models/transformers/transformer_skyreels_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index be033074ce50..34317b938d52 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -104,7 +104,7 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): hidden_states_img = hidden_states_img.type_as(query) if self._flag_ar_attention: - is_self_attention = encoder_hidden_states == hidden_states + is_self_attention = encoder_hidden_states is hidden_states hidden_states = F.scaled_dot_product_attention( query.to(torch.bfloat16) if is_self_attention else query, key.to(torch.bfloat16) if is_self_attention else key, From afcaf6e382b444b09d7b9e8c88ac9269260c8f5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 14 May 2025 19:59:47 +0300 Subject: [PATCH 043/264] Remove unused `flex_attention` variable from `transformer_skyreels_v2.py` to clean up code and improve maintainability. 
--- src/diffusers/models/transformers/transformer_skyreels_v2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 34317b938d52..02791f914ba8 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -19,7 +19,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention +from torch.nn.attention.flex_attention import BlockMask, create_block_mask from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin @@ -35,8 +35,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -flex_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune") - class SkyReelsV2AttnProcessor2_0: def __init__(self): From 465df8c8320fe6fb58cfc549fa5ad8190f2ab1c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 15 May 2025 11:02:36 +0300 Subject: [PATCH 044/264] Updates SkyReelsV2 pipeline defaults and docs Adjusts default values for `overlap_history`, `addnoise_condition`, and `causal_block_size`. These changes aim to improve consistency and smoothness in long video generation. Enhances docstrings for several parameters to provide clearer usage guidance. --- .../pipeline_skyreels_v2_diffusion_forcing.py | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index ef4eec0901df..7b22de0733c5 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -464,13 +464,13 @@ def __call__( Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 512, # TODO: check - overlap_history: Optional[int] = None, - shift: float = 1.0, - addnoise_condition: float = 0.0, + max_sequence_length: int = 512, + overlap_history: Optional[int] = 17, + shift: float = 1.0, # TODO: check this + addnoise_condition: float = 20.0, base_num_frames: int = 97, ar_step: int = 5, - causal_block_size: Optional[int] = None, + causal_block_size: Optional[int] = 5, fps: int = 24, ): r""" @@ -534,17 +534,16 @@ def __call__( max_sequence_length (`int`, *optional*, defaults to `512`): The maximum sequence length of the prompt. shift (`float`, *optional*, defaults to `1.0`): - - overlap_history (`int`, *optional*): - - addnoise_condition (`float`, *optional*, defaults to `0.0`): - + overlap_history (`int`, *optional*, defaults to `17`): + Number of frames to overlap for smooth transitions in long videos + addnoise_condition (`float`, *optional*, defaults to `20`): + Improves consistency in long video generation base_num_frames (`int`, *optional*, defaults to `97`): - + 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**) ar_step (`int`, *optional*, defaults to `5`): - The step of the autoregressive steps. 
- causal_block_size (`int`, *optional*): - + Controls asynchronous inference (0 for synchronous mode) + causal_block_size (`int`, *optional*, defaults to `5`): + Recommended when using asynchronous inference (--ar_step > 0) fps (`int`, *optional*, defaults to `24`): Examples: From cad2d387bda0f26dd2e4536582eafc1cc531a267 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 15 May 2025 18:38:49 +0300 Subject: [PATCH 045/264] Remove `enable_teacache` functionality from `SkyReelsV2Transformer3DModel` to streamline the model's processing logic. This change simplifies the code by eliminating conditional checks and related variables, enhancing maintainability and clarity. --- .../transformers/transformer_skyreels_v2.py | 118 +----------------- 1 file changed, 2 insertions(+), 116 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 02791f914ba8..51113df8f7ea 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -407,7 +407,6 @@ def __init__( self.num_frame_per_block = 1 self.flag_causal_attention = False - self.enable_teacache = False # 1. Patch & position embedding self.rope = SkyReelsV2RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) @@ -535,75 +534,8 @@ def forward( timestep_proj = timestep_proj.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3) timestep_proj = timestep_proj.transpose(1, 2).contiguous() - if self.enable_teacache: - modulated_inp = timestep_proj if self.use_ref_steps else temb - # teacache - if self.cnt % 2 == 0: # even -> condition - self.is_even = True - if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: - should_calc_even = True - self.accumulated_rel_l1_distance_even = 0 - else: - rescale_func = np.poly1d(self.coefficients) - self.accumulated_rel_l1_distance_even += rescale_func( - ((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()) - .cpu() - .item() - ) - if self.accumulated_rel_l1_distance_even < self.teacache_thresh: - should_calc_even = False - else: - should_calc_even = True - self.accumulated_rel_l1_distance_even = 0 - self.previous_e0_even = modulated_inp.clone() - - else: # odd -> unconditon - self.is_even = False - if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps: - should_calc_odd = True - self.accumulated_rel_l1_distance_odd = 0 - else: - rescale_func = np.poly1d(self.coefficients) - self.accumulated_rel_l1_distance_odd += rescale_func( - ((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean()) - .cpu() - .item() - ) - if self.accumulated_rel_l1_distance_odd < self.teacache_thresh: - should_calc_odd = False - else: - should_calc_odd = True - self.accumulated_rel_l1_distance_odd = 0 - self.previous_e0_odd = modulated_inp.clone() - - if self.enable_teacache: - if self.is_even: - if not should_calc_even: - hidden_states += self.previous_residual_even - else: - ori_hidden_states = hidden_states.clone() - for block in self.blocks: - hidden_states = block( - hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask - ) - self.previous_residual_even = hidden_states - ori_hidden_states - else: - if not should_calc_odd: - hidden_states += self.previous_residual_odd - else: - ori_hidden_states = hidden_states.clone() - for block in self.blocks: - hidden_states = block( - hidden_states, encoder_hidden_states, 
timestep_proj, rotary_emb, causal_mask - ) - self.previous_residual_odd = hidden_states - ori_hidden_states - - self.cnt += 1 - if self.cnt >= self.num_steps: - self.cnt = 0 - else: - for block in self.blocks: - hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask) + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask) # 5. Output norm, projection & unpatchify if temb.dim() == 2: @@ -679,49 +611,3 @@ def attention_mask(b, h, q_idx, kv_idx): ) return block_mask - - def initialize_teacache( - self, enable_teacache=True, num_steps=25, teacache_thresh=0.15, use_ret_steps=False, ckpt_dir="" - ): - self.enable_teacache = enable_teacache - print("using teacache") - self.cnt = 0 - self.num_steps = num_steps - self.teacache_thresh = teacache_thresh - self.accumulated_rel_l1_distance_even = 0 - self.accumulated_rel_l1_distance_odd = 0 - self.previous_e0_even = None - self.previous_e0_odd = None - self.previous_residual_even = None - self.previous_residual_odd = None - self.use_ref_steps = use_ret_steps - if "I2V" in ckpt_dir: - if use_ret_steps: - if "540P" in ckpt_dir: - self.coefficients = [2.57151496e05, -3.54229917e04, 1.40286849e03, -1.35890334e01, 1.32517977e-01] - if "720P" in ckpt_dir: - self.coefficients = [8.10705460e03, 2.13393892e03, -3.72934672e02, 1.66203073e01, -4.17769401e-02] - self.ret_steps = 5 * 2 - self.cutoff_steps = num_steps * 2 - else: - if "540P" in ckpt_dir: - self.coefficients = [-3.02331670e02, 2.23948934e02, -5.25463970e01, 5.87348440e00, -2.01973289e-01] - if "720P" in ckpt_dir: - self.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683] - self.ret_steps = 1 * 2 - self.cutoff_steps = num_steps * 2 - 2 - else: - if use_ret_steps: - if "1.3B" in ckpt_dir: - self.coefficients = [-5.21862437e04, 9.23041404e03, -5.28275948e02, 1.36987616e01, -4.99875664e-02] - if "14B" in ckpt_dir: - self.coefficients = [-3.03318725e05, 4.90537029e04, -2.65530556e03, 5.87365115e01, -3.15583525e-01] - self.ret_steps = 5 * 2 - self.cutoff_steps = num_steps * 2 - else: - if "1.3B" in ckpt_dir: - self.coefficients = [2.39676752e03, -1.31110545e03, 2.01331979e02, -8.29855975e00, 1.37887774e-01] - if "14B" in ckpt_dir: - self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404] - self.ret_steps = 1 * 2 - self.cutoff_steps = num_steps * 2 - 2 From 1fcdf9838a1360274b17d9e67c143e5035b2584a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 15 May 2025 19:07:02 +0300 Subject: [PATCH 046/264] Refactor `SkyReelsV2Transformer3DModel` to use configuration parameters for `num_frame_per_block` and `flag_causal_attention`. This change enhances flexibility by allowing these settings to be adjusted via the model's configuration, improving code clarity and maintainability. 
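For context, these two config values drive the block-causal attention mask built in the forward pass: a frame may attend within its own block and to earlier blocks, but not to later ones. A toy sketch of the frame-level mask with illustrative frame counts; the actual code additionally expands the mask over the spatial grid and places it on the model's device:

```py
import torch

frame_num, num_frame_per_block = 6, 2
block_num = frame_num // num_frame_per_block

# Assign each frame to its block: [0, 0, 1, 1, 2, 2]
range_tensor = torch.arange(block_num).view(-1, 1).repeat(1, num_frame_per_block).flatten()

# (frame, frame) mask: entry [i, j] is True when frame i may attend to frame j,
# i.e. when frame j's block is not later than frame i's block.
causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1)
print(causal_mask.int())
```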
--- .../transformers/transformer_skyreels_v2.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 51113df8f7ea..d6c5064d0435 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -399,14 +399,14 @@ def __init__( rope_max_seq_len: int = 1024, pos_embed_seq_len: Optional[int] = None, inject_sample_info: bool = False, + num_frame_per_block: int = 1, + flag_causal_attention: bool = False, ) -> None: super().__init__() inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels - self.num_frame_per_block = 1 - self.flag_causal_attention = False # 1. Patch & position embedding self.rope = SkyReelsV2RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) @@ -427,7 +427,7 @@ def __init__( self.blocks = nn.ModuleList( [ SkyReelsV2TransformerBlock( - inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim=inner_dim + inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim ) for _ in range(num_layers) ] @@ -482,11 +482,11 @@ def forward( hidden_states = self.patch_embedding(hidden_states) grid_sizes = torch.tensor(hidden_states.shape[2:], dtype=torch.long) - if self.flag_causal_attention: + if self.config.flag_causal_attention: frame_num, height, width = grid_sizes - block_num = frame_num // self.num_frame_per_block + block_num = frame_num // self.config.num_frame_per_block range_tensor = torch.arange(block_num, device=hidden_states.device).view(-1, 1) - range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten() + range_tensor = range_tensor.repeat(1, self.config.num_frame_per_block).flatten() causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1) causal_mask = causal_mask.repeat(1, height, width, 1, height, width) @@ -516,7 +516,7 @@ def forward( hidden_states = self._gradient_checkpointing_func( block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask ) - if self.inject_sample_info: + if self.config.inject_sample_info: fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device) fps_emb = self.fps_embedding(fps).float() @@ -570,8 +570,8 @@ def forward( return Transformer2DModelOutput(sample=output) def set_ar_attention(self, causal_block_size): - self.num_frame_per_block = causal_block_size - self.flag_causal_attention = True + self.config.num_frame_per_block = causal_block_size + self.config.flag_causal_attention = True for block in self.blocks: block.set_ar_attention() From 6d577255b551dd5b7f93b7d3342f499d15441cd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 15 May 2025 19:54:48 +0300 Subject: [PATCH 047/264] Remove unused import of `numpy` and clean up whitespace in `transformer_skyreels_v2.py` to enhance code clarity. 
--- src/diffusers/models/transformers/transformer_skyreels_v2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index d6c5064d0435..d38ac6376dc6 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -15,7 +15,6 @@ import math from typing import Any, Dict, Optional, Tuple, Union -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -407,7 +406,6 @@ def __init__( inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels - # 1. Patch & position embedding self.rope = SkyReelsV2RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) From c4cec0411c708ec8de12b6ea7c1a68d337be1f70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 15 May 2025 21:40:30 +0300 Subject: [PATCH 048/264] Refactor `SkyReelsV2DiffusionForcingPipeline` to improve error handling and streamline the denoising loop. Replace assertion with a ValueError for `ar_step` validation, enhance sample schedulers preparation, and simplify noise prediction logic. Additionally, add callback functionality for step-end processing and ensure proper handling of latents and video output. These changes enhance code clarity, maintainability, and functionality. --- .../pipeline_skyreels_v2_diffusion_forcing.py | 154 ++++++++++++------ 1 file changed, 101 insertions(+), 53 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 7b22de0733c5..1f2d282753c1 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -18,7 +18,9 @@ import ftfy import numpy as np +import math import torch +from copy import deepcopy from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback @@ -360,7 +362,8 @@ def generate_timestep_matrix( infer_step_num = len(step_template) gen_block = base_num_frames_block min_ar_step = infer_step_num / gen_block - assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting" + if ar_step < min_ar_step: + raise ValueError(f"ar_step should be at least {math.ceil(min_ar_step)} in your setting") # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block) step_template = torch.cat( [ @@ -608,7 +611,6 @@ def __call__( negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) # 4. 
Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) timesteps = self.scheduler.timesteps prefix_video = None @@ -640,15 +642,19 @@ def __call__( latent_length, timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size ) - sample_schedulers = [] - for _ in range(latent_length): - sample_scheduler = FlowMatchUniPCMultistepScheduler( - num_train_timesteps=1000, shift=1, use_dynamic_shifting=False - ) + # Prepare sample schedulers + self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + sample_schedulers = [self.scheduler] + for _ in range(latent_length - 1): + sample_scheduler = deepcopy(self.scheduler) sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) sample_schedulers.append(sample_scheduler) sample_schedulers_counter = [0] * latent_length + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, timestep_i in enumerate(step_matrix): if self.interrupt: @@ -670,25 +676,16 @@ def __call__( * noise_factor ) timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition - if not self.do_classifier_free_guidance: - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - fps=fps_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - else: - noise_pred_cond = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - fps=fps_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - noise_pred_uncond = self.transformer( + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, @@ -696,7 +693,8 @@ def __call__( attention_kwargs=attention_kwargs, return_dict=False, )[0] - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + for idx in range(valid_interval_start, valid_interval_end): if update_mask_i[idx].item(): latents[:, idx] = sample_schedulers[idx].step( @@ -707,6 +705,24 @@ def __call__( generator=generator, )[0] sample_schedulers_counter[idx] += 1 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + x0 = latents.unsqueeze(0) videos = self.vae.decode(x0) videos = (videos / 2 + 0.5).clamp(0, 1) @@ -731,7 +747,7 @@ def __call__( finished_frame_num = i * (base_num_frames - 
overlap_history_frames) + overlap_history_frames left_frame_num = latent_length - finished_frame_num base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) - if ar_step > 0 and self.transformer.enable_teacache: + if ar_step > 0: num_steps = num_inference_steps + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step self.transformer.num_steps = num_steps else: # i == 0 @@ -758,15 +774,19 @@ def __call__( causal_block_size, ) - sample_schedulers = [] - for _ in range(base_num_frames_iter): - sample_scheduler = FlowMatchUniPCMultistepScheduler( - num_train_timesteps=1000, shift=1, use_dynamic_shifting=False - ) + # Prepare sample schedulers + self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + sample_schedulers = [self.scheduler] + for _ in range(base_num_frames_iter - 1): + sample_scheduler = deepcopy(self.scheduler) sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) sample_schedulers.append(sample_scheduler) sample_schedulers_counter = [0] * base_num_frames_iter + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, timestep_i in enumerate(step_matrix): if self.interrupt: @@ -790,25 +810,17 @@ def __call__( * noise_factor ) timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition - if not self.do_classifier_free_guidance: - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - fps=fps_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - else: - noise_pred_cond = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - fps=fps_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - noise_pred_uncond = self.transformer( + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, @@ -816,7 +828,7 @@ def __call__( attention_kwargs=attention_kwargs, return_dict=False, )[0] - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) for idx in range(valid_interval_start, valid_interval_end): if update_mask_i[idx].item(): latents[:, idx] = sample_schedulers[idx].step( @@ -827,6 +839,24 @@ def __call__( generator=generator, )[0] sample_schedulers_counter[idx] += 1 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if 
XLA_AVAILABLE: + xm.mark_step() + x0 = latents.unsqueeze(0) videos = [self.vae.decode(x0)[0]] if output_video is None: @@ -840,6 +870,24 @@ def __call__( output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video] video = [video.cpu().numpy().astype(np.uint8) for video in output_video] + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + # Offload all models self.maybe_free_model_hooks() From 6a85ba11d313535e9aa3e60a42ba9fc1c6027c2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 16 May 2025 13:17:50 +0300 Subject: [PATCH 049/264] Refactor `SkyReelsV2DiffusionForcingPipeline` to enhance sample scheduler preparation and timestep management. Consolidate the logic for setting timesteps and generating the timestep matrix, improving clarity and maintainability. Update callback handling to ensure correct timestep usage during the denoising loop. These changes streamline the video generation process and improve overall code structure. --- .../pipeline_skyreels_v2_diffusion_forcing.py | 85 ++++++++++--------- 1 file changed, 44 insertions(+), 41 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 1f2d282753c1..b4b370810f15 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -610,9 +610,6 @@ def __call__( if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) - # 4. Prepare timesteps - timesteps = self.scheduler.timesteps - prefix_video = None predix_video_latent_length = 0 @@ -623,6 +620,22 @@ def __call__( if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: # Short video generation + + # 4. Prepare sample schedulers and timestep matrix + self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + timesteps = self.scheduler.timesteps + latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + sample_schedulers = [self.scheduler] + for _ in range(latent_length - 1): + sample_scheduler = deepcopy(self.scheduler) + sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + sample_schedulers.append(sample_scheduler) + sample_schedulers_counter = [0] * latent_length + base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + latent_length, timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size + ) + # 5. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( @@ -636,20 +649,6 @@ def __call__( generator, latents, ) - latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length - step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - latent_length, timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size - ) - - # Prepare sample schedulers - self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) - sample_schedulers = [self.scheduler] - for _ in range(latent_length - 1): - sample_scheduler = deepcopy(self.scheduler) - sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) - sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * latent_length # 6. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -710,14 +709,14 @@ def __call__( callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + callback_outputs = callback_on_step_end(self, i, timestep_i, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(step_matrix) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if XLA_AVAILABLE: @@ -752,7 +751,27 @@ def __call__( self.transformer.num_steps = num_steps else: # i == 0 base_num_frames_iter = base_num_frames - latent_shape = [16, base_num_frames_iter, latent_height, latent_width] + + # 4. Prepare sample schedulers and timestep matrix + self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + timesteps = self.scheduler.timesteps + sample_schedulers = [self.scheduler] + for _ in range(base_num_frames_iter - 1): + sample_scheduler = deepcopy(self.scheduler) + sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + sample_schedulers.append(sample_scheduler) + sample_schedulers_counter = [0] * base_num_frames_iter + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + base_num_frames_iter, + timesteps, + base_num_frames_iter, + ar_step, + predix_video_latent_length, + causal_block_size, + ) + + # 5. 
Prepare latent variables + num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, @@ -762,30 +781,14 @@ def __call__( torch.float32, device, generator, + latents, ) if prefix_video is not None: latents[:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) - step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - base_num_frames_iter, - timesteps, - base_num_frames_iter, - ar_step, - predix_video_latent_length, - causal_block_size, - ) - - # Prepare sample schedulers - self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) - sample_schedulers = [self.scheduler] - for _ in range(base_num_frames_iter - 1): - sample_scheduler = deepcopy(self.scheduler) - sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) - sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * base_num_frames_iter # 6. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - self._num_timesteps = len(timesteps) + num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(step_matrix) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, timestep_i in enumerate(step_matrix): @@ -844,14 +847,14 @@ def __call__( callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + callback_outputs = callback_on_step_end(self, i, timestep_i, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(step_matrix) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if XLA_AVAILABLE: From 76af29bda0aeac73e8d98f33e276756e9c306d97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 16 May 2025 13:20:35 +0300 Subject: [PATCH 050/264] update template for df_i2v --- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 645 ++++++++++-------- 1 file changed, 354 insertions(+), 291 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index b916e01618cc..57c4b782929b 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -13,16 +13,17 @@ # limitations under the License. 
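# NOTE (illustrative sketch, not executable on its own): this file implements the diffusion-forcing
# denoising pattern in which every latent frame keeps its own scheduler (a deepcopy of `self.scheduler`),
# because different frames sit at different timesteps within the same iteration. A minimal sketch of the
# per-frame stepping loop that `__call__` builds further below, using the names defined later in this file
# (the transformer call is simplified here):
#
#     for i, timestep_i in enumerate(step_matrix):          # one row of timesteps per iteration
#         start, end = valid_interval[i]                    # frames processed this iteration
#         noise_pred = transformer(latents[:, start:end], timestep_i[None, start:end], ...)
#         for idx in range(start, end):
#             if step_update_mask[i][idx]:                  # only flagged frames advance this iteration
#                 latents[:, idx] = sample_schedulers[idx].step(
#                     noise_pred[:, idx - start], timestep_i[idx], latents[:, idx],
#                     return_dict=False,
#                 )[0]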
import html +import math import re -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Union import ftfy -import PIL.Image +import numpy as np import torch -from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel +from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...image_processor import PipelineImageInput from ...loaders import WanLoraLoaderMixin from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel from ...schedulers import FlowMatchUniPCMultistepScheduler @@ -40,7 +41,7 @@ else: XLA_AVAILABLE = False -logger = logging.get_logger(__name__) +logger = logging.get_logger(__name__) # pylint: disable=invalid-name if is_ftfy_available(): import ftfy @@ -51,11 +52,11 @@ ```py >>> import torch >>> import PIL.Image - >>> from diffusers import SkyReelsV2DiffusionForcingImageToVideoPipeline + >>> from diffusers import SkyReelsV2DiffusionForcingPipeline >>> from diffusers.utils import export_to_video, load_image >>> # Load the pipeline - >>> pipe = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained( + >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained( ... "HF_placeholder/SkyReels-V2-DF-1.3B-540P", torch_dtype=torch.float16 ... ) >>> pipe = pipe.to("cuda") @@ -96,9 +97,9 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class SkyReelsV2DiffusionForcingImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): +class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): """ - Pipeline for Image-to-Video (i2v) generation using SkyReels-V2 with diffusion forcing. + Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a specific device, etc.). @@ -110,11 +111,6 @@ class SkyReelsV2DiffusionForcingImageToVideoPipeline(DiffusionPipeline, WanLoraL text_encoder ([`UMT5EncoderModel`]): [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. - image_encoder ([`CLIPVisionModel`]): - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically - the - [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) - variant. transformer ([`SkyReelsV2Transformer3DModel`]): Conditional Transformer to denoise the encoded image latents. scheduler ([`FlowMatchUniPCMultistepScheduler`]): @@ -123,15 +119,13 @@ class SkyReelsV2DiffusionForcingImageToVideoPipeline(DiffusionPipeline, WanLoraL Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
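    Unlike the plain text-to-video pipeline, diffusion forcing denoises each latent frame on its own
    timestep schedule: frames that are already available are kept as a prefix (optionally lightly
    re-noised, controlled by `addnoise_condition`) while the remaining frames are denoised, which is
    what enables chunked long-video generation with `overlap_history` overlapping frames between chunks.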
""" - model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + model_cpu_offload_seq = "text_encoder->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, - image_encoder: CLIPVisionModel, - image_processor: CLIPImageProcessor, transformer: SkyReelsV2Transformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMatchUniPCMultistepScheduler, @@ -142,23 +136,20 @@ def __init__( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, - image_encoder=image_encoder, transformer=transformer, scheduler=scheduler, - image_processor=image_processor, ) self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - self.image_processor = image_processor - # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline._get_t5_prompt_embeds + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_videos_per_prompt: int = 1, - max_sequence_length: int = 512, + max_sequence_length: int = 226, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): @@ -195,18 +186,7 @@ def _get_t5_prompt_embeds( return prompt_embeds - # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image - def encode_image( - self, - image: PipelineImageInput, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - image = self.image_processor(images=image, return_tensors="pt").to(device) - image_embeds = self.image_encoder(**image, output_hidden_states=True) - return image_embeds.hidden_states[-2] - - # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_prompt + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], @@ -288,30 +268,17 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.check_inputs + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.check_inputs def check_inputs( self, prompt, negative_prompt, - image, height, width, prompt_embeds=None, negative_prompt_embeds=None, - image_embeds=None, callback_on_step_end_tensor_inputs=None, ): - if image is not None and image_embeds is not None: - raise ValueError( - f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" - " only forward one of the two." - ) - if image is None and image_embeds is None: - raise ValueError( - "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." 
- ) - if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): - raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -343,10 +310,9 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") - # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.prepare_latents + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents def prepare_latents( self, - image: PipelineImageInput, batch_size: int, num_channels_latents: int = 16, height: int = 480, @@ -356,72 +322,103 @@ def prepare_latents( device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, - last_image: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - latent_height = height // self.vae_scale_factor_spatial - latent_width = width // self.vae_scale_factor_spatial + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) - shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device=device, dtype=dtype) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents - image = image.unsqueeze(2) - if last_image is None: - video_condition = torch.cat( - [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 - ) - else: - last_image = last_image.unsqueeze(2) - video_condition = torch.cat( - [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], - dim=2, - ) - video_condition = video_condition.to(device=device, dtype=self.vae.dtype) - - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype - ) - - if isinstance(generator, list): - latent_condition = [ - retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + def generate_timestep_matrix( + self, + num_frames, + step_template, + base_num_frames, + ar_step=5, + num_pre_ready=0, + casual_block_size=1, + shrink_interval_with_mask=False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: + step_matrix, step_index = [], [] + update_mask, valid_interval = [], [] + num_iterations = len(step_template) + 1 + num_frames_block = num_frames // casual_block_size + base_num_frames_block = base_num_frames // casual_block_size + if base_num_frames_block < num_frames_block: + infer_step_num = len(step_template) + gen_block = base_num_frames_block + min_ar_step = infer_step_num / gen_block + if ar_step < min_ar_step: + raise ValueError(f"ar_step should be at least {math.ceil(min_ar_step)} in your setting") + # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block) + step_template = torch.cat( + [ + torch.tensor([999], dtype=torch.int64, device=step_template.device), + step_template.long(), + torch.tensor([0], dtype=torch.int64, device=step_template.device), ] - latent_condition = torch.cat(latent_condition) - else: - latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") - latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) - - latent_condition = latent_condition.to(dtype) - latent_condition = (latent_condition - latents_mean) * latents_std - - mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) - - if last_image is None: - mask_lat_size[:, :, list(range(1, num_frames))] = 0 - else: - mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 - first_frame_mask = mask_lat_size[:, :, 0:1] - first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) - mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) - mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) - mask_lat_size = mask_lat_size.transpose(1, 2) - mask_lat_size = mask_lat_size.to(latent_condition.device) - - return latents, torch.concat([mask_lat_size, latent_condition], dim=1) + ) # to handle the counter in row works starting from 1 + pre_row = torch.zeros(num_frames_block, dtype=torch.long) + if num_pre_ready > 0: + pre_row[: num_pre_ready // 
casual_block_size] = num_iterations + + while torch.all(pre_row >= (num_iterations - 1)) == False: + new_row = torch.zeros(num_frames_block, dtype=torch.long) + for i in range(num_frames_block): + if i == 0 or pre_row[i - 1] >= ( + num_iterations - 1 + ): # the first frame or the last frame is completely denoised + new_row[i] = pre_row[i] + 1 + else: + new_row[i] = new_row[i - 1] - ar_step + new_row = new_row.clamp(0, num_iterations) + + update_mask.append( + (new_row != pre_row) & (new_row != num_iterations) + ) # False: no need to update, True: need to update + step_index.append(new_row) + step_matrix.append(step_template[new_row]) + pre_row = new_row + + # for long video we split into several sequences, base_num_frames is set to the model max length (for training) + terminal_flag = base_num_frames_block + if shrink_interval_with_mask: + idx_sequence = torch.arange(num_frames_block, dtype=torch.int64) + update_mask = update_mask[0] + update_mask_idx = idx_sequence[update_mask] + last_update_idx = update_mask_idx[-1].item() + terminal_flag = last_update_idx + 1 + # for i in range(0, len(update_mask)): + for curr_mask in update_mask: + if terminal_flag < num_frames_block and curr_mask[terminal_flag]: + terminal_flag += 1 + valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag)) + + step_update_mask = torch.stack(update_mask, dim=0) + step_index = torch.stack(step_index, dim=0) + step_matrix = torch.stack(step_matrix, dim=0) + + if casual_block_size > 1: + step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval] + + return step_matrix, step_index, step_update_mask, valid_interval @property def guidance_scale(self): @@ -429,7 +426,7 @@ def guidance_scale(self): @property def do_classifier_free_guidance(self): - return self._guidance_scale > 1 + return self._guidance_scale > 1.0 @property def num_timesteps(self): @@ -451,7 +448,6 @@ def attention_kwargs(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - image: PipelineImageInput, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] = None, height: int = 480, @@ -464,8 +460,6 @@ def __call__( latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - image_embeds: Optional[torch.Tensor] = None, - last_image: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -473,21 +467,19 @@ def __call__( Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], - max_sequence_length: int = 512, # TODO: check - overlap_history: Optional[int] = None, - shift: float = 1.0, - addnoise_condition: float = 0.0, + max_sequence_length: int = 512, + overlap_history: Optional[int] = 17, + shift: float = 1.0, # TODO: check this + addnoise_condition: float = 20.0, base_num_frames: int = 97, ar_step: int = 5, - causal_block_size: Optional[int] = None, + causal_block_size: Optional[int] = 5, fps: int = 24, ): r""" The call function to the pipeline for generation. 
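        When `overlap_history` is set and `num_frames` exceeds `base_num_frames`, the video is generated in
        overlapping chunks: the last `overlap_history` frames of the previous chunk are carried over as a
        clean latent prefix for the next chunk, and the decoded chunks are stitched together. `ar_step`
        controls how asynchronously frames are denoised: `0` advances all frames through the same timestep
        each iteration, while values greater than `0` make each causal block lag the previous one by
        `ar_step` iterations.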
Args: - image (`PipelineImageInput`): - The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. @@ -525,9 +517,6 @@ def __call__( negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `negative_prompt` input argument. - image_embeds (`torch.Tensor`, *optional*): - Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, - image embeddings are generated from the `image` input argument. output_type (`str`, *optional*, defaults to `"np"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -547,12 +536,20 @@ def __call__( `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int`, *optional*, defaults to `512`): The maximum sequence length of the prompt. - shift (`float`, *optional*, defaults to `5.0`): - The shift of the flow. - autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): - The dtype to use for the torch.amp.autocast. - Examples: + shift (`float`, *optional*, defaults to `1.0`): + overlap_history (`int`, *optional*, defaults to `17`): + Number of frames to overlap for smooth transitions in long videos + addnoise_condition (`float`, *optional*, defaults to `20`): + Improves consistency in long video generation + base_num_frames (`int`, *optional*, defaults to `97`): + 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**) + ar_step (`int`, *optional*, defaults to `5`): + Controls asynchronous inference (0 for synchronous mode) + causal_block_size (`int`, *optional*, defaults to `5`): + Recommended when using asynchronous inference (--ar_step > 0) + fps (`int`, *optional*, defaults to `24`): + Examples: Returns: [`~SkyReelsV2PipelineOutput`] or `tuple`: If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned @@ -567,12 +564,10 @@ def __call__( self.check_inputs( prompt, negative_prompt, - image, height, width, prompt_embeds, negative_prompt_embeds, - image_embeds, callback_on_step_end_tensor_inputs, ) @@ -590,15 +585,6 @@ def __call__( device = self._execution_device - latent_height = height // 8 - latent_width = width // 8 - latent_length = (num_frames - 1) // 4 + 1 - - if causal_block_size is None: - causal_block_size = self.transformer.num_frame_per_block - fps_embeds = [fps] * prompt_embeds.shape[0] - fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] - # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -624,104 +610,129 @@ def __call__( if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) - # Encode image embedding - prefix_video = not None + prefix_video = None predix_video_latent_length = 0 - #prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames) - if image_embeds is None: - if last_image is None: - image_embeds = self.encode_image(image, device) - else: - image_embeds = self.encode_image([image, last_image], device) - image_embeds = image_embeds.repeat(batch_size, 1, 1) - image_embeds = image_embeds.to(transformer_dtype) - - # 4. 
Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) - timesteps = self.scheduler.timesteps + + if causal_block_size is None: + causal_block_size = self.transformer.num_frame_per_block + fps_embeds = [fps] * prompt_embeds.shape[0] + fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: # Short video generation - latent_shape = [16, latent_length, latent_height, latent_width] - latents = self.prepare_latents( - latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator - ) - latents = [latents] - latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) + + # 4. Prepare sample schedulers and timestep matrix + self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + timesteps = self.scheduler.timesteps + latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + sample_schedulers = [self.scheduler] + for _ in range(latent_length - 1): + sample_scheduler = deepcopy(self.scheduler) + sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + sample_schedulers.append(sample_scheduler) + sample_schedulers_counter = [0] * latent_length base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( latent_length, timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size ) - sample_schedulers = [] - for _ in range(latent_length): - sample_scheduler = FlowMatchUniPCMultistepScheduler( - num_train_timesteps=1000, shift=1, use_dynamic_shifting=False - ) - sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) - sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * latent_length - self.transformer.to(self.device) - for i, timestep_i in enumerate(tqdm(step_matrix)): - update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: - noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor) - + torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length]) - * noise_factor - ) - timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition - if not self.do_classifier_free_guidance: + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, timestep_i in enumerate(step_matrix): + if self.interrupt: + continue + + self._current_timestep = timestep_i + + update_mask_i = step_update_mask[i] + valid_interval_i = valid_interval[i] + valid_interval_start, valid_interval_end = valid_interval_i + timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + timestep_for_noised_condition = addnoise_condition + latent_model_input[:, valid_interval_start:predix_video_latent_length] = ( + latent_model_input[:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor) + + torch.randn_like(latent_model_input[:, valid_interval_start:predix_video_latent_length]) + * noise_factor + ) + timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition noise_pred = self.transformer( - torch.stack([latent_model_input[0]]), - t=timestep, - context=prompt_embeds, - fps=fps_embeds, - )[0] - else: - noise_pred_cond = self.transformer( - torch.stack([latent_model_input[0]]), - t=timestep, - context=prompt_embeds, + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, )[0] - noise_pred_uncond = self.transformer( - torch.stack([latent_model_input[0]]), - t=timestep, - context=negative_prompt_embeds, - fps=fps_embeds, - )[0] - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) - for idx in range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[0][:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], - timestep_i[idx], - latents[0][:, idx], + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + fps=fps_embeds, + attention_kwargs=attention_kwargs, return_dict=False, - generator=generator, )[0] - sample_schedulers_counter[idx] += 1 - if self.offload: - self.transformer.cpu() - torch.cuda.empty_cache() - x0 = latents[0].unsqueeze(0) - videos = self.vae.decode(x0) - videos = (videos / 2 + 0.5).clamp(0, 1) - videos = [video for video in videos] - videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] - video = [video.cpu().numpy().astype(np.uint8) for video in videos] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[:, idx] = sample_schedulers[idx].step( + noise_pred[:, idx - valid_interval_start], + timestep_i[idx], + latents[:, idx], + return_dict=False, + generator=generator, + )[0] + sample_schedulers_counter[idx] += 1 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, timestep_i, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + 
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(step_matrix) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + x0 = latents.unsqueeze(0) + videos = self.vae.decode(x0) + videos = (videos / 2 + 0.5).clamp(0, 1) + videos = [video for video in videos] + videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] + video = [video.cpu().numpy().astype(np.uint8) for video in videos] else: # Long video generation base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length overlap_history_frames = (overlap_history - 1) // 4 + 1 n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 - print(f"n_iter:{n_iter}") output_video = None for i in range(n_iter): if output_video is not None: # i !=0 @@ -735,17 +746,21 @@ def __call__( finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames left_frame_num = latent_length - finished_frame_num base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) - if ar_step > 0 and self.transformer.enable_teacache: + if ar_step > 0: num_steps = num_inference_steps + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step self.transformer.num_steps = num_steps else: # i == 0 base_num_frames_iter = base_num_frames - latent_shape = [16, base_num_frames_iter, latent_height, latent_width] - latents = self.prepare_latents( - latent_shape, dtype=transformer_dtype, device=prompt_embeds.device, generator=generator - ) - latents = [latents] - latents[0][:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) + + # 4. 
Prepare sample schedulers and timestep matrix + self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + timesteps = self.scheduler.timesteps + sample_schedulers = [self.scheduler] + for _ in range(base_num_frames_iter - 1): + sample_scheduler = deepcopy(self.scheduler) + sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + sample_schedulers.append(sample_scheduler) + sample_schedulers_counter = [0] * base_num_frames_iter step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( base_num_frames_iter, timesteps, @@ -754,79 +769,127 @@ def __call__( predix_video_latent_length, causal_block_size, ) - sample_schedulers = [] - for _ in range(base_num_frames_iter): - sample_scheduler = FlowMatchUniPCMultistepScheduler( - num_train_timesteps=1000, shift=1, use_dynamic_shifting=False - ) - sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) - sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * base_num_frames_iter - self.transformer.to(self.device) - for i, timestep_i in enumerate(tqdm(step_matrix)): - update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: - noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] - * (1.0 - noise_factor) - + torch.randn_like( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + if prefix_video is not None: + latents[:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) + + # 6. 
Denoising loop + num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(step_matrix) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, timestep_i in enumerate(step_matrix): + if self.interrupt: + continue + + self._current_timestep = timestep_i + update_mask_i = step_update_mask[i] + valid_interval_i = valid_interval[i] + valid_interval_start, valid_interval_end = valid_interval_i + timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + timestep_for_noised_condition = addnoise_condition + latent_model_input[:, valid_interval_start:predix_video_latent_length] = ( + latent_model_input[:, valid_interval_start:predix_video_latent_length] + * (1.0 - noise_factor) + + torch.randn_like( + latent_model_input[:, valid_interval_start:predix_video_latent_length] + ) + * noise_factor ) - * noise_factor - ) - timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition - if not self.do_classifier_free_guidance: + timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + noise_pred = self.transformer( - torch.stack([latent_model_input[0]]), - t=timestep, - context=prompt_embeds, - fps=fps_embeds, - )[0] - else: - noise_pred_cond = self.transformer( - torch.stack([latent_model_input[0]]), - t=timestep, - context=prompt_embeds, - fps=fps_embeds, - )[0] - noise_pred_uncond = self.transformer( - torch.stack([latent_model_input[0]]), - t=timestep, - context=negative_prompt_embeds, + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, )[0] - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) - for idx in range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[0][:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], - timestep_i[idx], - latents[0][:, idx], + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + fps=fps_embeds, + attention_kwargs=attention_kwargs, return_dict=False, - generator=generator, )[0] - sample_schedulers_counter[idx] += 1 - if self.offload: - self.transformer.cpu() - torch.cuda.empty_cache() - x0 = latents[0].unsqueeze(0) - videos = [self.vae.decode(x0)[0]] - if output_video is None: - output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w - else: - output_video = torch.cat( - [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1 - ) # c, f, h, w - output_video = [(output_video / 2 + 0.5).clamp(0, 1)] - output_video = [video for video in output_video] - output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video] - video = [video.cpu().numpy().astype(np.uint8) for video in output_video] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[:, idx] = sample_schedulers[idx].step( + noise_pred[:, idx - valid_interval_start], + timestep_i[idx], + latents[:, idx], + return_dict=False, + 
generator=generator, + )[0] + sample_schedulers_counter[idx] += 1 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, timestep_i, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(step_matrix) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + x0 = latents.unsqueeze(0) + videos = [self.vae.decode(x0)[0]] + if output_video is None: + output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w + else: + output_video = torch.cat( + [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1 + ) # c, f, h, w + output_video = [(output_video / 2 + 0.5).clamp(0, 1)] + output_video = [video for video in output_video] + output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video] + video = [video.cpu().numpy().astype(np.uint8) for video in output_video] + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents # Offload all models self.maybe_free_model_hooks() From 81206ce1acf82f13e5894a379419b19451dd91cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 16 May 2025 13:26:17 +0300 Subject: [PATCH 051/264] style --- .../pipeline_skyreels_v2_diffusion_forcing.py | 45 ++++++++++++------- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 41 +++++++++++------ 2 files changed, 56 insertions(+), 30 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index b4b370810f15..c3e5afb1f64f 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -13,14 +13,14 @@ # limitations under the License. 
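# Note on the `generate_timestep_matrix` helper used by this pipeline: each row of the returned
# `step_matrix` assigns one timestep per latent frame. With `ar_step=0` every frame shares the row's
# timestep (synchronous denoising); with `ar_step > 0` each causal block starts `ar_step` iterations
# after the previous one, producing a staircase of step indices, roughly (3 frames, ar_step=2,
# 0 meaning "still pure noise"):
#
#     [[1, 0, 0],
#      [2, 0, 0],
#      [3, 1, 0],
#      [4, 2, 0],
#      ...]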
import html +import math import re +from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Union import ftfy import numpy as np -import math import torch -from copy import deepcopy from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback @@ -376,7 +376,7 @@ def generate_timestep_matrix( if num_pre_ready > 0: pre_row[: num_pre_ready // casual_block_size] = num_iterations - while torch.all(pre_row >= (num_iterations - 1)) == False: + while not torch.all(pre_row >= (num_iterations - 1)): new_row = torch.zeros(num_frames_block, dtype=torch.long) for i in range(num_frames_block): if i == 0 or pre_row[i - 1] >= ( @@ -549,8 +549,7 @@ def __call__( Recommended when using asynchronous inference (--ar_step > 0) fps (`int`, *optional*, defaults to `24`): - Examples: - Returns: + Examples: Returns: [`~SkyReelsV2PipelineOutput`] or `tuple`: If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s @@ -665,12 +664,15 @@ def __call__( valid_interval_i = valid_interval[i] valid_interval_start, valid_interval_end = valid_interval_i timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + latent_model_input = ( + latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + ) if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: noise_factor = 0.001 * addnoise_condition timestep_for_noised_condition = addnoise_condition latent_model_input[:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor) + latent_model_input[:, valid_interval_start:predix_video_latent_length] + * (1.0 - noise_factor) + torch.randn_like(latent_model_input[:, valid_interval_start:predix_video_latent_length]) * noise_factor ) @@ -716,7 +718,9 @@ def __call__( negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided - if i == len(step_matrix) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(step_matrix) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if XLA_AVAILABLE: @@ -724,8 +728,7 @@ def __call__( x0 = latents.unsqueeze(0) videos = self.vae.decode(x0) - videos = (videos / 2 + 0.5).clamp(0, 1) - videos = [video for video in videos] + videos = [(videos / 2 + 0.5).clamp(0, 1)] videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] video = [video.cpu().numpy().astype(np.uint8) for video in videos] else: @@ -747,7 +750,10 @@ def __call__( left_frame_num = latent_length - finished_frame_num base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) if ar_step > 0: - num_steps = num_inference_steps + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step + num_steps = ( + num_inference_steps + + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step + ) self.transformer.num_steps = num_steps else: # i == 0 base_num_frames_iter = base_num_frames @@ -800,7 +806,9 @@ def __call__( valid_interval_i = valid_interval[i] valid_interval_start, valid_interval_end = 
valid_interval_i timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + latent_model_input = ( + latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + ) if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: noise_factor = 0.001 * addnoise_condition timestep_for_noised_condition = addnoise_condition @@ -812,7 +820,9 @@ def __call__( ) * noise_factor ) - timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + timestep[:, valid_interval_start:predix_video_latent_length] = ( + timestep_for_noised_condition + ) noise_pred = self.transformer( hidden_states=latent_model_input, @@ -851,10 +861,14 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) # call the callback, if provided - if i == len(step_matrix) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(step_matrix) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if XLA_AVAILABLE: @@ -869,7 +883,6 @@ def __call__( [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1 ) # c, f, h, w output_video = [(output_video / 2 + 0.5).clamp(0, 1)] - output_video = [video for video in output_video] output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video] video = [video.cpu().numpy().astype(np.uint8) for video in output_video] diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 57c4b782929b..c3e5afb1f64f 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -376,7 +376,7 @@ def generate_timestep_matrix( if num_pre_ready > 0: pre_row[: num_pre_ready // casual_block_size] = num_iterations - while torch.all(pre_row >= (num_iterations - 1)) == False: + while not torch.all(pre_row >= (num_iterations - 1)): new_row = torch.zeros(num_frames_block, dtype=torch.long) for i in range(num_frames_block): if i == 0 or pre_row[i - 1] >= ( @@ -549,8 +549,7 @@ def __call__( Recommended when using asynchronous inference (--ar_step > 0) fps (`int`, *optional*, defaults to `24`): - Examples: - Returns: + Examples: Returns: [`~SkyReelsV2PipelineOutput`] or `tuple`: If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s @@ -665,12 +664,15 @@ def __call__( valid_interval_i = valid_interval[i] valid_interval_start, valid_interval_end = valid_interval_i timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + latent_model_input = ( + latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + ) if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: 
noise_factor = 0.001 * addnoise_condition timestep_for_noised_condition = addnoise_condition latent_model_input[:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor) + latent_model_input[:, valid_interval_start:predix_video_latent_length] + * (1.0 - noise_factor) + torch.randn_like(latent_model_input[:, valid_interval_start:predix_video_latent_length]) * noise_factor ) @@ -716,7 +718,9 @@ def __call__( negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided - if i == len(step_matrix) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(step_matrix) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if XLA_AVAILABLE: @@ -724,8 +728,7 @@ def __call__( x0 = latents.unsqueeze(0) videos = self.vae.decode(x0) - videos = (videos / 2 + 0.5).clamp(0, 1) - videos = [video for video in videos] + videos = [(videos / 2 + 0.5).clamp(0, 1)] videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] video = [video.cpu().numpy().astype(np.uint8) for video in videos] else: @@ -747,7 +750,10 @@ def __call__( left_frame_num = latent_length - finished_frame_num base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) if ar_step > 0: - num_steps = num_inference_steps + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step + num_steps = ( + num_inference_steps + + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step + ) self.transformer.num_steps = num_steps else: # i == 0 base_num_frames_iter = base_num_frames @@ -800,7 +806,9 @@ def __call__( valid_interval_i = valid_interval[i] valid_interval_start, valid_interval_end = valid_interval_i timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - latent_model_input = latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + latent_model_input = ( + latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + ) if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: noise_factor = 0.001 * addnoise_condition timestep_for_noised_condition = addnoise_condition @@ -812,7 +820,9 @@ def __call__( ) * noise_factor ) - timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + timestep[:, valid_interval_start:predix_video_latent_length] = ( + timestep_for_noised_condition + ) noise_pred = self.transformer( hidden_states=latent_model_input, @@ -851,10 +861,14 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) # call the callback, if provided - if i == len(step_matrix) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(step_matrix) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if XLA_AVAILABLE: @@ -869,7 +883,6 @@ def __call__( [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1 ) # c, f, h, w output_video = [(output_video / 2 + 0.5).clamp(0, 1)] - output_video = [video for video in 
output_video] output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video] video = [video.cpu().numpy().astype(np.uint8) for video in output_video] From 906b6f5de2205ede139ea730e4d9c1ddf223e49c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 16 May 2025 14:26:44 +0300 Subject: [PATCH 052/264] Refactor `SkyReelsV2DiffusionForcingPipeline` to improve the handling of latent frame calculations and sample scheduler preparation. Update logic for determining `num_latent_frames` and streamline the setting of timesteps across both short and long video generation paths. These changes enhance code clarity and maintainability while ensuring correct processing of video generation parameters. --- .../pipeline_skyreels_v2_diffusion_forcing.py | 30 +++++++++---------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index c3e5afb1f64f..957b52aeb96d 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -611,6 +611,8 @@ def __call__( prefix_video = None predix_video_latent_length = 0 + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else num_latent_frames if causal_block_size is None: causal_block_size = self.transformer.num_frame_per_block @@ -619,20 +621,17 @@ def __call__( if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: # Short video generation - # 4. Prepare sample schedulers and timestep matrix - self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) timesteps = self.scheduler.timesteps - latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1 sample_schedulers = [self.scheduler] - for _ in range(latent_length - 1): + for _ in range(num_latent_frames - 1): sample_scheduler = deepcopy(self.scheduler) - sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * latent_length - base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length + sample_schedulers_counter = [0] * num_latent_frames step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - latent_length, timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size + num_latent_frames, timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size ) # 5. Prepare latent variables @@ -650,8 +649,8 @@ def __call__( ) # 6. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - self._num_timesteps = len(timesteps) + num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(step_matrix) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, timestep_i in enumerate(step_matrix): @@ -733,9 +732,8 @@ def __call__( video = [video.cpu().numpy().astype(np.uint8) for video in videos] else: # Long video generation - base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length overlap_history_frames = (overlap_history - 1) // 4 + 1 - n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 + n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 output_video = None for i in range(n_iter): if output_video is not None: # i !=0 @@ -747,7 +745,7 @@ def __call__( prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] predix_video_latent_length = prefix_video[0].shape[1] finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames - left_frame_num = latent_length - finished_frame_num + left_frame_num = num_latent_frames - finished_frame_num base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) if ar_step > 0: num_steps = ( @@ -759,12 +757,12 @@ def __call__( base_num_frames_iter = base_num_frames # 4. Prepare sample schedulers and timestep matrix - self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) timesteps = self.scheduler.timesteps - sample_schedulers = [self.scheduler] + sample_schedulers = [deepcopy(self.scheduler)] for _ in range(base_num_frames_iter - 1): sample_scheduler = deepcopy(self.scheduler) - sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) sample_schedulers_counter = [0] * base_num_frames_iter step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( From e2391b6373f9fcd25688d868ae2f5c34ddb9dd37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 16 May 2025 15:46:39 +0300 Subject: [PATCH 053/264] Add newly released `SkyReelsV2DiffusionForcingVideoToVideoPipeline` template --- ...eline_skyreels_v2_diffusion_forcing_v2v.py | 927 ++++++++++++++++++ 1 file changed, 927 insertions(+) create mode 100644 src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py new file mode 100644 index 000000000000..9eea5e24b13b --- /dev/null +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -0,0 +1,927 @@ +# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import math +import re +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Union + +import ftfy +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer, UMT5EncoderModel + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import WanLoraLoaderMixin +from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel +from ...schedulers import FlowMatchUniPCMultistepScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import SkyReelsV2PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """\ + Examples: + ```py + >>> import torch + >>> import PIL.Image + >>> from diffusers import SkyReelsV2DiffusionForcingPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> # Load the pipeline + >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained( + ... "HF_placeholder/SkyReels-V2-DF-1.3B-540P", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> # TODO + ``` +""" + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class SkyReelsV2DiffusionForcingVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): + """ + Pipeline for Video-to-Video (v2v) generation using SkyReels-V2 with diffusion forcing. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a specific device, etc.). + + Args: + tokenizer ([`AutoTokenizer`]): + Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), + specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. 
+ text_encoder ([`UMT5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + transformer ([`SkyReelsV2Transformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchUniPCMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLWan`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + transformer: SkyReelsV2Transformer3DModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchUniPCMultistepScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.check_inputs + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def generate_timestep_matrix( + self, + num_frames, + step_template, + base_num_frames, + ar_step=5, + num_pre_ready=0, + casual_block_size=1, + shrink_interval_with_mask=False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: + step_matrix, step_index = [], [] + update_mask, valid_interval = [], [] + num_iterations = len(step_template) + 1 + num_frames_block = num_frames // casual_block_size + base_num_frames_block = base_num_frames // casual_block_size + if base_num_frames_block < num_frames_block: + infer_step_num = len(step_template) + gen_block = base_num_frames_block + min_ar_step = infer_step_num / gen_block + if ar_step < min_ar_step: + raise ValueError(f"ar_step should be at least {math.ceil(min_ar_step)} in your setting") + # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block) + step_template = torch.cat( + [ + torch.tensor([999], dtype=torch.int64, device=step_template.device), + step_template.long(), + torch.tensor([0], dtype=torch.int64, device=step_template.device), + ] + ) # to handle the counter in row works starting from 1 + pre_row = torch.zeros(num_frames_block, dtype=torch.long) + if num_pre_ready > 0: + pre_row[: num_pre_ready // casual_block_size] = num_iterations + + while not torch.all(pre_row >= (num_iterations - 1)): + new_row = torch.zeros(num_frames_block, dtype=torch.long) + for i in range(num_frames_block): + if i == 0 or pre_row[i - 1] >= ( + num_iterations - 1 + ): # the first frame or the last frame is completely denoised + new_row[i] = pre_row[i] + 1 + else: + new_row[i] = new_row[i - 1] - ar_step + new_row = new_row.clamp(0, num_iterations) + + update_mask.append( + (new_row != pre_row) & (new_row != num_iterations) + ) # False: no need to update, True: need to update + step_index.append(new_row) + step_matrix.append(step_template[new_row]) + pre_row = new_row + + # for long video we split into several sequences, base_num_frames is set to the model max length (for training) + terminal_flag = base_num_frames_block + if shrink_interval_with_mask: + idx_sequence = torch.arange(num_frames_block, dtype=torch.int64) + update_mask = update_mask[0] + update_mask_idx = idx_sequence[update_mask] + last_update_idx = update_mask_idx[-1].item() + terminal_flag = last_update_idx + 1 + # for i in range(0, len(update_mask)): + for curr_mask in update_mask: + if terminal_flag < num_frames_block and curr_mask[terminal_flag]: + terminal_flag += 1 + valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag)) + + step_update_mask = torch.stack(update_mask, dim=0) + step_index = torch.stack(step_index, dim=0) + step_matrix = torch.stack(step_matrix, dim=0) + + if casual_block_size > 1: + step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() + valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval] + + return step_matrix, step_index, step_update_mask, valid_interval + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + 
return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + video: List[Image.Image] = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 97, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + overlap_history: Optional[int] = 17, + shift: float = 1.0, # TODO: check this + addnoise_condition: float = 20.0, + base_num_frames: int = 97, + ar_step: int = 5, + causal_block_size: Optional[int] = 5, + fps: int = 24, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, defaults to `480`): + The height of the generated video. + width (`int`, defaults to `832`): + The width of the generated video. + num_frames (`int`, defaults to `97`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. 
Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to `512`): + The maximum sequence length of the prompt. + shift (`float`, *optional*, defaults to `1.0`): + overlap_history (`int`, *optional*, defaults to `17`): + Number of frames to overlap for smooth transitions in long videos + addnoise_condition (`float`, *optional*, defaults to `20`): + Improves consistency in long video generation + base_num_frames (`int`, *optional*, defaults to `97`): + 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**) + ar_step (`int`, *optional*, defaults to `5`): + Controls asynchronous inference (0 for synchronous mode) + causal_block_size (`int`, *optional*, defaults to `5`): + Recommended when using asynchronous inference (--ar_step > 0) + fps (`int`, *optional*, defaults to `24`): + + Examples: Returns: + [`~SkyReelsV2PipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + video, + latents, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + prefix_video = None + predix_video_latent_length = 0 + + if causal_block_size is None: + causal_block_size = self.transformer.num_frame_per_block + fps_embeds = [fps] * prompt_embeds.shape[0] + fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] + + if latents is None: + video = self.video_processor.preprocess_video(video, height=height, width=width).to( + device, dtype=torch.float32 + ) + + if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: + # Short video generation + + # 4. Prepare sample schedulers and timestep matrix + self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + timesteps = self.scheduler.timesteps + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + sample_schedulers = [self.scheduler] + for _ in range(latent_length - 1): + sample_scheduler = deepcopy(self.scheduler) + sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + sample_schedulers.append(sample_scheduler) + sample_schedulers_counter = [0] * latent_length + base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + latent_length, timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size + ) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + video, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + latent_timestep, + ) + + # 6. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, timestep_i in enumerate(step_matrix): + if self.interrupt: + continue + + self._current_timestep = timestep_i + + update_mask_i = step_update_mask[i] + valid_interval_i = valid_interval[i] + valid_interval_start, valid_interval_end = valid_interval_i + timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + latent_model_input = ( + latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + ) + if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + timestep_for_noised_condition = addnoise_condition + latent_model_input[:, valid_interval_start:predix_video_latent_length] = ( + latent_model_input[:, valid_interval_start:predix_video_latent_length] + * (1.0 - noise_factor) + + torch.randn_like(latent_model_input[:, valid_interval_start:predix_video_latent_length]) + * noise_factor + ) + timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[:, idx] = sample_schedulers[idx].step( + noise_pred[:, idx - valid_interval_start], + timestep_i[idx], + latents[:, idx], + return_dict=False, + generator=generator, + )[0] + sample_schedulers_counter[idx] += 1 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, timestep_i, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(step_matrix) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + x0 = latents.unsqueeze(0) + videos = self.vae.decode(x0) + videos = [(videos / 2 + 0.5).clamp(0, 1)] + videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] + video = [video.cpu().numpy().astype(np.uint8) for video in videos] + else: + # Long video generation + base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length + overlap_history_frames = (overlap_history - 1) // 4 + 1 + n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 + output_video = None + for i in range(n_iter): + if output_video is not None: # i !=0 + prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device) + prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] + if 
prefix_video[0].shape[1] % causal_block_size != 0: + truncate_len = prefix_video[0].shape[1] % causal_block_size + print("the length of prefix video is truncated for the casual block size alignment.") + prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] + predix_video_latent_length = prefix_video[0].shape[1] + finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames + left_frame_num = latent_length - finished_frame_num + base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) + if ar_step > 0: + num_steps = ( + num_inference_steps + + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step + ) + self.transformer.num_steps = num_steps + else: # i == 0 + base_num_frames_iter = base_num_frames + + # 4. Prepare sample schedulers and timestep matrix + self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + timesteps = self.scheduler.timesteps + sample_schedulers = [self.scheduler] + for _ in range(base_num_frames_iter - 1): + sample_scheduler = deepcopy(self.scheduler) + sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + sample_schedulers.append(sample_scheduler) + sample_schedulers_counter = [0] * base_num_frames_iter + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + base_num_frames_iter, + timesteps, + base_num_frames_iter, + ar_step, + predix_video_latent_length, + causal_block_size, + ) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + if prefix_video is not None: + latents[:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) + + # 6. 
Denoising loop + num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(step_matrix) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, timestep_i in enumerate(step_matrix): + if self.interrupt: + continue + + self._current_timestep = timestep_i + update_mask_i = step_update_mask[i] + valid_interval_i = valid_interval[i] + valid_interval_start, valid_interval_end = valid_interval_i + timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + latent_model_input = ( + latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + ) + if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + timestep_for_noised_condition = addnoise_condition + latent_model_input[:, valid_interval_start:predix_video_latent_length] = ( + latent_model_input[:, valid_interval_start:predix_video_latent_length] + * (1.0 - noise_factor) + + torch.randn_like( + latent_model_input[:, valid_interval_start:predix_video_latent_length] + ) + * noise_factor + ) + timestep[:, valid_interval_start:predix_video_latent_length] = ( + timestep_for_noised_condition + ) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[:, idx] = sample_schedulers[idx].step( + noise_pred[:, idx - valid_interval_start], + timestep_i[idx], + latents[:, idx], + return_dict=False, + generator=generator, + )[0] + sample_schedulers_counter[idx] += 1 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, timestep_i, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) + + # call the callback, if provided + if i == len(step_matrix) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + x0 = latents.unsqueeze(0) + videos = [self.vae.decode(x0)[0]] + if output_video is None: + output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w + else: + output_video = torch.cat( + [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1 + ) # c, f, h, w + output_video = [(output_video / 2 + 0.5).clamp(0, 1)] + output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video] + video = [video.cpu().numpy().astype(np.uint8) for video in output_video] + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / 
torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return SkyReelsV2PipelineOutput(frames=video) From 245534f0b4a7e5933588bec34338a97cb3d7b057 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 16 May 2025 15:47:07 +0300 Subject: [PATCH 054/264] up df_i2v --- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 36 ++++++++++++++++--- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index c3e5afb1f64f..5e96e4cb3264 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -24,6 +24,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput from ...loaders import WanLoraLoaderMixin from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel from ...schedulers import FlowMatchUniPCMultistepScheduler @@ -97,9 +98,9 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): +class SkyReelsV2DiffusionForcingImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): """ - Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing. + Pipeline for Image-to-Video (i2v) generation using SkyReels-V2 with diffusion forcing. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a specific device, etc.). 
@@ -448,7 +449,8 @@ def attention_kwargs(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]], + image: PipelineImageInput, + prompt: Union[str, List[str]] = None, negative_prompt: Union[str, List[str]] = None, height: int = 480, width: int = 832, @@ -460,6 +462,8 @@ def __call__( latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + last_image: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -563,10 +567,12 @@ def __call__( self.check_inputs( prompt, negative_prompt, + image, height, width, prompt_embeds, negative_prompt_embeds, + image_embeds, callback_on_step_end_tensor_inputs, ) @@ -604,7 +610,20 @@ def __call__( device=device, ) + # Encode image embedding transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + if image_embeds is None: + if last_image is None: + image_embeds = self.encode_image(image, device) + else: + image_embeds = self.encode_image([image, last_image], device) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + prompt_embeds = prompt_embeds.to(transformer_dtype) if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) @@ -636,8 +655,14 @@ def __call__( ) # 5. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + if last_image is not None: + last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( + device, dtype=torch.float32 + ) + latents, condition = self.prepare_latents( + image, batch_size * num_videos_per_prompt, num_channels_latents, height, @@ -647,6 +672,7 @@ def __call__( device, generator, latents, + last_image, ) # 6. Denoising loop From aaa8a8b9e726f297c9877e0c33241734aeb1c928 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 16 May 2025 16:47:31 +0300 Subject: [PATCH 055/264] Refactor `SkyReelsV2DiffusionForcingPipeline` to improve the handling of latent tensors and video output generation. Update logic for processing latents and ensure correct video formatting, enhancing clarity and maintainability in the video generation workflow. 
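For reference, a minimal sketch of the latent-to-pixel conversion this refactor consolidates (assumptions: `vae` is the Wan-style VAE used by these pipelines and `latents` is a `[batch, channels, frames, height, width]` tensor; the function name and variables are illustrative, not the final API):

    import torch

    def decode_latents_to_frames(vae, latents):
        # Undo the VAE normalization with the statistics stored in the config,
        # mirroring the `latents / latents_std + latents_mean` step in the diff below.
        latents = latents.to(vae.dtype)
        mean = torch.tensor(vae.config.latents_mean).view(1, vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
        inv_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
        latents = latents / inv_std + mean
        video = vae.decode(latents, return_dict=False)[0]  # pixel-space video in [-1, 1]
        video = (video / 2 + 0.5).clamp(0, 1)              # rescale to [0, 1]
        # [batch, channels, frames, height, width] -> [batch, frames, height, width, channels], uint8
        return (video * 255).permute(0, 2, 3, 4, 1).to(torch.uint8).cpu().numpy()

The long-video path in the diff reuses the same statistics before each per-iteration decode and stitches the decoded chunks together after dropping the overlapped frames.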
--- .../pipeline_skyreels_v2_diffusion_forcing.py | 38 +++++++++++++------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 957b52aeb96d..baacfef05fb3 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -725,11 +725,11 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - x0 = latents.unsqueeze(0) - videos = self.vae.decode(x0) - videos = [(videos / 2 + 0.5).clamp(0, 1)] - videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] - video = [video.cpu().numpy().astype(np.uint8) for video in videos] + x0 = latents.unsqueeze(0) + videos = self.vae.decode(x0) + videos = [(videos / 2 + 0.5).clamp(0, 1)] + videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] + video = [video.cpu().numpy().astype(np.uint8) for video in videos] else: # Long video generation overlap_history_frames = (overlap_history - 1) // 4 + 1 @@ -872,17 +872,31 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - x0 = latents.unsqueeze(0) - videos = [self.vae.decode(x0)[0]] + latents = latents.unsqueeze(0) + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + videos = self.vae.decode(latents, return_dict=False)[0] if output_video is None: - output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w + output_video = videos # c, f, h, w else: output_video = torch.cat( - [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1 + [output_video, videos[:, overlap_history:]], 1 ) # c, f, h, w - output_video = [(output_video / 2 + 0.5).clamp(0, 1)] - output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video] - video = [video.cpu().numpy().astype(np.uint8) for video in output_video] + else: + output_video = latents + + output_video = [(output_video / 2 + 0.5).clamp(0, 1)] + output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video] + video = [video.cpu().numpy().astype(np.uint8) for video in output_video] self._current_timestep = None From ca3f7bd5ecc6eaa5ad53bc23ed315a96201584fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 18 May 2025 18:10:59 +0300 Subject: [PATCH 056/264] Integrate video decoding in pipeline --- .../pipeline_skyreels_v2_diffusion_forcing.py | 49 ++++++++----------- 1 file changed, 21 insertions(+), 28 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index baacfef05fb3..c37782e8be8a 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -725,19 +725,15 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - x0 = latents.unsqueeze(0) - videos = self.vae.decode(x0) - videos = [(videos / 2 + 0.5).clamp(0, 1)] - videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] - video = [video.cpu().numpy().astype(np.uint8) for video in 
videos] + latents = latents.unsqueeze(0) else: # Long video generation overlap_history_frames = (overlap_history - 1) // 4 + 1 n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 - output_video = None + video = None for i in range(n_iter): - if output_video is not None: # i !=0 - prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device) + if video is not None: # i !=0 + prefix_video = video[:, -overlap_history:].to(prompt_embeds.device) prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] if prefix_video[0].shape[1] % causal_block_size != 0: truncate_len = prefix_video[0].shape[1] % causal_block_size @@ -885,33 +881,30 @@ def __call__( ) latents = latents / latents_std + latents_mean videos = self.vae.decode(latents, return_dict=False)[0] - if output_video is None: - output_video = videos # c, f, h, w + if video is None: + video = videos # c, f, h, w else: - output_video = torch.cat( - [output_video, videos[:, overlap_history:]], 1 + video = torch.cat( + [video, videos[:, overlap_history:]], 1 ) # c, f, h, w else: - output_video = latents - - output_video = [(output_video / 2 + 0.5).clamp(0, 1)] - output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video] - video = [video.cpu().numpy().astype(np.uint8) for video in output_video] + video = latents self._current_timestep = None if not output_type == "latent": - latents = latents.to(self.vae.dtype) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype - ) - latents = latents / latents_std + latents_mean - video = self.vae.decode(latents, return_dict=False)[0] + if overlap_history is None: + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) else: video = latents From b4e26fd8e8c1f011600ccec3bacc52e52f4d87c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 18 May 2025 18:35:34 +0300 Subject: [PATCH 057/264] up --- .../pipeline_skyreels_v2_diffusion_forcing.py | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index c37782e8be8a..2294ce6a9d5d 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -19,7 +19,6 @@ from typing import Any, Callable, Dict, List, Optional, Union import ftfy -import numpy as np import torch from transformers import AutoTokenizer, UMT5EncoderModel @@ -731,6 +730,14 @@ def __call__( overlap_history_frames = (overlap_history - 1) // 4 + 1 n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 video = None + latents_mean = ( + 
torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) for i in range(n_iter): if video is not None: # i !=0 prefix_video = video[:, -overlap_history:].to(prompt_embeds.device) @@ -871,22 +878,13 @@ def __call__( latents = latents.unsqueeze(0) if not output_type == "latent": latents = latents.to(self.vae.dtype) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype - ) latents = latents / latents_std + latents_mean + # TODO: Or collect all latents and decode at once in the end? videos = self.vae.decode(latents, return_dict=False)[0] if video is None: video = videos # c, f, h, w else: - video = torch.cat( - [video, videos[:, overlap_history:]], 1 - ) # c, f, h, w + video = torch.cat([video, videos[:, overlap_history:]], 1) # c, f, h, w else: video = latents @@ -900,9 +898,9 @@ def __call__( .view(1, self.vae.config.z_dim, 1, 1, 1) .to(latents.device, latents.dtype) ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype - ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) latents = latents / latents_std + latents_mean video = self.vae.decode(latents, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) From c3bcd1dce4d21c55296b82cdcf7f56cae3b78c0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 18 May 2025 19:27:38 +0300 Subject: [PATCH 058/264] Fix variable name typo in `SkyReelsV2DiffusionForcingPipeline` from `predix_video_latent_length` to `prefix_video_latent_length` for consistency. --- .../pipeline_skyreels_v2_diffusion_forcing.py | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 2294ce6a9d5d..0f375e911717 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -609,7 +609,7 @@ def __call__( negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) prefix_video = None - predix_video_latent_length = 0 + prefix_video_latent_length = 0 num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else num_latent_frames @@ -630,7 +630,7 @@ def __call__( sample_schedulers.append(sample_scheduler) sample_schedulers_counter = [0] * num_latent_frames step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - num_latent_frames, timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size + num_latent_frames, timesteps, base_num_frames, ar_step, prefix_video_latent_length, causal_block_size ) # 5. 
Prepare latent variables @@ -665,16 +665,16 @@ def __call__( latent_model_input = ( latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() ) - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length: noise_factor = 0.001 * addnoise_condition timestep_for_noised_condition = addnoise_condition - latent_model_input[:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[:, valid_interval_start:predix_video_latent_length] + latent_model_input[:, valid_interval_start:prefix_video_latent_length] = ( + latent_model_input[:, valid_interval_start:prefix_video_latent_length] * (1.0 - noise_factor) - + torch.randn_like(latent_model_input[:, valid_interval_start:predix_video_latent_length]) + + torch.randn_like(latent_model_input[:, valid_interval_start:prefix_video_latent_length]) * noise_factor ) - timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + timestep[:, valid_interval_start:prefix_video_latent_length] = timestep_for_noised_condition noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, @@ -739,14 +739,14 @@ def __call__( latents.device, latents.dtype ) for i in range(n_iter): - if video is not None: # i !=0 + if video is not None: prefix_video = video[:, -overlap_history:].to(prompt_embeds.device) prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] if prefix_video[0].shape[1] % causal_block_size != 0: truncate_len = prefix_video[0].shape[1] % causal_block_size print("the length of prefix video is truncated for the casual block size alignment.") prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] - predix_video_latent_length = prefix_video[0].shape[1] + prefix_video_latent_length = prefix_video[0].shape[1] finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames left_frame_num = num_latent_frames - finished_frame_num base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) @@ -756,7 +756,7 @@ def __call__( + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step ) self.transformer.num_steps = num_steps - else: # i == 0 + else: base_num_frames_iter = base_num_frames # 4. Prepare sample schedulers and timestep matrix @@ -773,7 +773,7 @@ def __call__( timesteps, base_num_frames_iter, ar_step, - predix_video_latent_length, + prefix_video_latent_length, causal_block_size, ) @@ -791,7 +791,7 @@ def __call__( latents, ) if prefix_video is not None: - latents[:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) + latents[:, :prefix_video_latent_length] = prefix_video[0].to(transformer_dtype) # 6. 
Denoising loop num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order @@ -810,18 +810,18 @@ def __call__( latent_model_input = ( latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() ) - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length: noise_factor = 0.001 * addnoise_condition timestep_for_noised_condition = addnoise_condition - latent_model_input[:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[:, valid_interval_start:predix_video_latent_length] + latent_model_input[:, valid_interval_start:prefix_video_latent_length] = ( + latent_model_input[:, valid_interval_start:prefix_video_latent_length] * (1.0 - noise_factor) + torch.randn_like( - latent_model_input[:, valid_interval_start:predix_video_latent_length] + latent_model_input[:, valid_interval_start:prefix_video_latent_length] ) * noise_factor ) - timestep[:, valid_interval_start:predix_video_latent_length] = ( + timestep[:, valid_interval_start:prefix_video_latent_length] = ( timestep_for_noised_condition ) @@ -879,7 +879,6 @@ def __call__( if not output_type == "latent": latents = latents.to(self.vae.dtype) latents = latents / latents_std + latents_mean - # TODO: Or collect all latents and decode at once in the end? videos = self.vae.decode(latents, return_dict=False)[0] if video is None: video = videos # c, f, h, w From c9bea144eef82728909d64f1db43aeafb284c328 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 18 May 2025 19:44:51 +0300 Subject: [PATCH 059/264] Fix variable name from `casual_block_size` to `causal_block_size` for consistency and update related logic in `SkyReelsV2DiffusionForcingPipeline`. 
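For readers following the rename: below is a minimal, self-contained sketch (not part of this patch) of the block-to-frame expansion that `causal_block_size` controls inside `generate_timestep_matrix`; the toy shapes are illustrative only.

```py
import torch

# Toy setup: 4 frame blocks, each covering 3 latent frames.
causal_block_size = 3
step_update_mask = torch.tensor([[True, True, False, False]])  # (num_iterations, num_frames_block)

# Every frame inside a block shares its block's update flag; this mirrors the
# unsqueeze/repeat/flatten pattern applied when causal_block_size > 1.
per_frame_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
print(per_frame_mask.shape)  # torch.Size([1, 12])

# Block-level valid intervals are rescaled to frame indices the same way.
valid_interval = [(0, 2)]
print([(s * causal_block_size, e * causal_block_size) for s, e in valid_interval])  # [(0, 6)]
```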
--- .../pipeline_skyreels_v2_diffusion_forcing.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 0f375e911717..7451aeadeaf7 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -349,21 +349,21 @@ def generate_timestep_matrix( base_num_frames, ar_step=5, num_pre_ready=0, - casual_block_size=1, + causal_block_size=1, shrink_interval_with_mask=False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: step_matrix, step_index = [], [] update_mask, valid_interval = [], [] num_iterations = len(step_template) + 1 - num_frames_block = num_frames // casual_block_size - base_num_frames_block = base_num_frames // casual_block_size + num_frames_block = num_frames // causal_block_size + base_num_frames_block = base_num_frames // causal_block_size if base_num_frames_block < num_frames_block: infer_step_num = len(step_template) gen_block = base_num_frames_block min_ar_step = infer_step_num / gen_block if ar_step < min_ar_step: raise ValueError(f"ar_step should be at least {math.ceil(min_ar_step)} in your setting") - # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block) + # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, causal_block_size, num_frames_block, base_num_frames_block) step_template = torch.cat( [ torch.tensor([999], dtype=torch.int64, device=step_template.device), @@ -373,7 +373,7 @@ def generate_timestep_matrix( ) # to handle the counter in row works starting from 1 pre_row = torch.zeros(num_frames_block, dtype=torch.long) if num_pre_ready > 0: - pre_row[: num_pre_ready // casual_block_size] = num_iterations + pre_row[: num_pre_ready // causal_block_size] = num_iterations while not torch.all(pre_row >= (num_iterations - 1)): new_row = torch.zeros(num_frames_block, dtype=torch.long) @@ -411,11 +411,11 @@ def generate_timestep_matrix( step_index = torch.stack(step_index, dim=0) step_matrix = torch.stack(step_matrix, dim=0) - if casual_block_size > 1: - step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval] + if causal_block_size > 1: + step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + step_index = step_index.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + valid_interval = [(s * causal_block_size, e * causal_block_size) for s, e in valid_interval] return step_matrix, step_index, step_update_mask, valid_interval @@ -744,7 +744,7 @@ def __call__( prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] if prefix_video[0].shape[1] % causal_block_size != 0: truncate_len = prefix_video[0].shape[1] % causal_block_size - print("the length of prefix video is truncated for the casual block size alignment.") + 
logger.warning("The length of prefix video is truncated for the causal block size alignment.") prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] prefix_video_latent_length = prefix_video[0].shape[1] finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames From 00fdeb0a74dec081edc50ffd1c9c03963219e0b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 18 May 2025 19:57:19 +0300 Subject: [PATCH 060/264] Update `_no_split_modules` in `SkyReelsV2Transformer3DModel` and adjust `causal_block_size` and `num_steps` references in `SkyReelsV2DiffusionForcingPipeline` to use the correct configuration properties for improved consistency and clarity. --- src/diffusers/models/transformers/transformer_skyreels_v2.py | 2 +- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index d38ac6376dc6..e339e033dabf 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -374,7 +374,7 @@ class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] - _no_split_modules = ["WanTransformerBlock"] + _no_split_modules = ["SkyReelsV2TransformerBlock"] _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 7451aeadeaf7..32400c7dee93 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -614,7 +614,7 @@ def __call__( base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else num_latent_frames if causal_block_size is None: - causal_block_size = self.transformer.num_frame_per_block + causal_block_size = self.transformer.config.num_frame_per_block fps_embeds = [fps] * prompt_embeds.shape[0] fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] @@ -755,7 +755,7 @@ def __call__( num_inference_steps + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step ) - self.transformer.num_steps = num_steps + self.transformer.config.num_steps = num_steps else: base_num_frames_iter = base_num_frames From cf91fb44f0c8d6a7951740fdd9244d052a48e01a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 18 May 2025 20:11:54 +0300 Subject: [PATCH 061/264] Refactor type hint for `device` parameter in `_prepare_blockwise_causal_attn_mask` method to use `Union` for <=python 3.9. 
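For context: the PEP 604 `torch.device | str` spelling is only supported natively on Python 3.10+, which is presumably why `typing.Union` is used here. A minimal sketch of the compatible spelling, assuming only `torch` is installed:

```py
from typing import Union

import torch


# Python 3.9-compatible annotation: typing.Union instead of the 3.10+ `|` operator.
def _prepare_mask(device: Union[torch.device, str], num_frames: int = 21) -> None:
    """Placeholder body; only the annotation matters for this example."""


# The equivalent `device: torch.device | str` annotation would raise a TypeError at
# import time on Python 3.9 unless `from __future__ import annotations` is enabled.
```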
--- src/diffusers/models/transformers/transformer_skyreels_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index e339e033dabf..1e80f3dadcf7 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -575,7 +575,7 @@ def set_ar_attention(self, causal_block_size): @staticmethod def _prepare_blockwise_causal_attn_mask( - device: torch.device | str, num_frames: int = 21, frame_seqlen: int = 1560, num_frame_per_block=1 + device: Union[torch.device, str], num_frames: int = 21, frame_seqlen: int = 1560, num_frame_per_block=1 ) -> BlockMask: """ we will divide the token sequence into the following format [1 latent frame] [1 latent frame] ... [1 latent From 256fa6dd9501e5d79a0f4ced6433e52b4333e7a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 19 May 2025 13:41:07 +0300 Subject: [PATCH 062/264] Refactor `SkyReelsV2DiffusionForcingPipeline` to streamline the setting of timesteps for sample schedulers in both short and long video generation paths, enhancing code clarity and maintainability. --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 32400c7dee93..0746457a318f 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -621,13 +621,13 @@ def __call__( if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: # Short video generation # 4. Prepare sample schedulers and timestep matrix - self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) - timesteps = self.scheduler.timesteps sample_schedulers = [self.scheduler] for _ in range(num_latent_frames - 1): sample_scheduler = deepcopy(self.scheduler) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps sample_schedulers_counter = [0] * num_latent_frames step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( num_latent_frames, timesteps, base_num_frames, ar_step, prefix_video_latent_length, causal_block_size @@ -760,13 +760,13 @@ def __call__( base_num_frames_iter = base_num_frames # 4. 
Prepare sample schedulers and timestep matrix - self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) - timesteps = self.scheduler.timesteps sample_schedulers = [deepcopy(self.scheduler)] for _ in range(base_num_frames_iter - 1): sample_scheduler = deepcopy(self.scheduler) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps sample_schedulers_counter = [0] * base_num_frames_iter step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( base_num_frames_iter, From a74252c7addfed4ca2071c6e7358f8c3bfe0733f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 19 May 2025 17:17:48 +0300 Subject: [PATCH 063/264] Add `flag_df` parameter to `SkyReelsV2Transformer3DModel` for improved timestep handling and refactor related logic to enhance clarity and maintainability. --- .../models/transformers/transformer_skyreels_v2.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 1e80f3dadcf7..30fc309cc52c 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -450,6 +450,7 @@ def forward( timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, + flag_df: bool = False, fps: Optional[torch.Tensor] = None, return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -493,13 +494,6 @@ def forward( hidden_states = hidden_states.flatten(2).transpose(1, 2) - # TODO: check here - if timestep.dim() == 2: - b, f = timestep.shape - _flag_df = True - else: - _flag_df = False - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( timestep, encoder_hidden_states, encoder_hidden_states_image ) @@ -518,14 +512,15 @@ def forward( fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device) fps_emb = self.fps_embedding(fps).float() - if _flag_df: + if flag_df: timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat( timestep.shape[1], 1, 1 ) else: timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)) - if _flag_df: + if flag_df: + b, f = timestep.shape temb = temb.view(b, f, 1, 1, self.dim) timestep_proj = timestep_proj.view(b, f, 1, 1, 6, self.dim) temb = temb.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3) From 771fb05ad05752cdd3259eeeabc6d16c8628e887 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 19 May 2025 17:20:01 +0300 Subject: [PATCH 064/264] Refactor `SkyReelsV2DiffusionForcingPipeline` to enhance clarity and maintainability by updating type hints, adjusting timestep handling, and correcting parameter defaults for `shift` and `addnoise_condition`. 
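A hedged usage sketch of the adjusted defaults; the checkpoint id is a placeholder, not a published repo, and the arguments shown mirror the call signature as of this patch.

```py
import torch
from diffusers import SkyReelsV2DiffusionForcingPipeline

# "<converted-checkpoint>" is a placeholder path/repo id for a converted SkyReels-V2 DF checkpoint.
pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
    "<converted-checkpoint>", torch_dtype=torch.float16
).to("cuda")

video = pipe(
    prompt="A sailboat gliding across a calm bay at sunset",
    num_frames=97,
    base_num_frames=97,       # 97 for 540P, 121 for 720P
    shift=8.0,                # new default after this patch
    addnoise_condition=20.0,  # new default; softens noise on the conditioning frames
    ar_step=5,
).frames[0]
```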
--- .../pipeline_skyreels_v2_diffusion_forcing.py | 67 ++++++++++--------- 1 file changed, 35 insertions(+), 32 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 0746457a318f..7a82745eaf4a 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -344,13 +344,13 @@ def prepare_latents( def generate_timestep_matrix( self, - num_frames, - step_template, - base_num_frames, - ar_step=5, - num_pre_ready=0, - causal_block_size=1, - shrink_interval_with_mask=False, + num_frames: int, + step_template: torch.Tensor, + base_num_frames: int, + ar_step: int = 5, + num_pre_ready: int = 0, + causal_block_size: int = 1, + shrink_interval_with_mask: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: step_matrix, step_index = [], [] update_mask, valid_interval = [], [] @@ -401,7 +401,7 @@ def generate_timestep_matrix( update_mask_idx = idx_sequence[update_mask] last_update_idx = update_mask_idx[-1].item() terminal_flag = last_update_idx + 1 - # for i in range(0, len(update_mask)): + for curr_mask in update_mask: if terminal_flag < num_frames_block and curr_mask[terminal_flag]: terminal_flag += 1 @@ -468,7 +468,7 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, overlap_history: Optional[int] = 17, - shift: float = 1.0, # TODO: check this + shift: float = 8.0, addnoise_condition: float = 20.0, base_num_frames: int = 97, ar_step: int = 5, @@ -535,10 +535,10 @@ def __call__( `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int`, *optional*, defaults to `512`): The maximum sequence length of the prompt. 
- shift (`float`, *optional*, defaults to `1.0`): + shift (`float`, *optional*, defaults to `8.0`): overlap_history (`int`, *optional*, defaults to `17`): Number of frames to overlap for smooth transitions in long videos - addnoise_condition (`float`, *optional*, defaults to `20`): + addnoise_condition (`float`, *optional*, defaults to `20.0`): Improves consistency in long video generation base_num_frames (`int`, *optional*, defaults to `97`): 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**) @@ -611,7 +611,11 @@ def __call__( prefix_video = None prefix_video_latent_length = 0 num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else num_latent_frames + base_num_frames = ( + (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 + if base_num_frames is not None + else num_latent_frames + ) if causal_block_size is None: causal_block_size = self.transformer.config.num_frame_per_block @@ -652,33 +656,33 @@ def __call__( self._num_timesteps = len(step_matrix) with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, timestep_i in enumerate(step_matrix): + for i, t in enumerate(step_matrix): if self.interrupt: continue - self._current_timestep = timestep_i + self._current_timestep = t update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + valid_interval_start, valid_interval_end = valid_interval[i] + timestep = t.expand(latents.shape[0])[:, valid_interval_start:valid_interval_end].clone() latent_model_input = ( latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() ) if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length: noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition latent_model_input[:, valid_interval_start:prefix_video_latent_length] = ( latent_model_input[:, valid_interval_start:prefix_video_latent_length] * (1.0 - noise_factor) + torch.randn_like(latent_model_input[:, valid_interval_start:prefix_video_latent_length]) * noise_factor ) - timestep[:, valid_interval_start:prefix_video_latent_length] = timestep_for_noised_condition + timestep[:, valid_interval_start:prefix_video_latent_length] = addnoise_condition + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, + flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, @@ -688,6 +692,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, + flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, @@ -698,7 +703,7 @@ def __call__( if update_mask_i[idx].item(): latents[:, idx] = sample_schedulers[idx].step( noise_pred[:, idx - valid_interval_start], - timestep_i[idx], + t[idx], latents[:, idx], return_dict=False, generator=generator, @@ -709,7 +714,7 @@ def __call__( callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, timestep_i, callback_kwargs) + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) @@ -798,21 +803,19 @@ def 
__call__( self._num_timesteps = len(step_matrix) with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, timestep_i in enumerate(step_matrix): + for i, t in enumerate(step_matrix): if self.interrupt: continue - self._current_timestep = timestep_i + self._current_timestep = t update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + valid_interval_start, valid_interval_end = valid_interval[i] + timestep = t.expand(latents.shape[0])[:, valid_interval_start:valid_interval_end].clone() latent_model_input = ( latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() ) if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length: noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition latent_model_input[:, valid_interval_start:prefix_video_latent_length] = ( latent_model_input[:, valid_interval_start:prefix_video_latent_length] * (1.0 - noise_factor) @@ -821,14 +824,13 @@ def __call__( ) * noise_factor ) - timestep[:, valid_interval_start:prefix_video_latent_length] = ( - timestep_for_noised_condition - ) + timestep[:, valid_interval_start:prefix_video_latent_length] = addnoise_condition noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, + flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, @@ -838,6 +840,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, + flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, @@ -847,7 +850,7 @@ def __call__( if update_mask_i[idx].item(): latents[:, idx] = sample_schedulers[idx].step( noise_pred[:, idx - valid_interval_start], - timestep_i[idx], + t[idx], latents[:, idx], return_dict=False, generator=generator, @@ -858,7 +861,7 @@ def __call__( callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, timestep_i, callback_kwargs) + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) From bccad5504df08e30f474c35ae028a6e1756d67d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 20 May 2025 14:43:53 +0300 Subject: [PATCH 065/264] Add script for converting SkyReelsV2 models to Diffusers format --- scripts/convert_skyreelsv2_to_diffusers.py | 469 +++++++++++++++++++++ 1 file changed, 469 insertions(+) create mode 100644 scripts/convert_skyreelsv2_to_diffusers.py diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py new file mode 100644 index 000000000000..d5410b7f7b2e --- /dev/null +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -0,0 +1,469 @@ +import argparse +import pathlib +from typing import Any, Dict + +import torch +from accelerate import init_empty_weights +from huggingface_hub import hf_hub_download, snapshot_download +from safetensors.torch import load_file +from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel + +from diffusers import ( + AutoencoderKLWan, + FlowMatchUniPCMultistepScheduler, + SkyReelsV2ImageToVideoPipeline, + SkyReelsV2Pipeline, + 
SkyReelsV2DiffusionForcingPipeline, + SkyReelsV2DiffusionForcingImageToImagePipeline, + SkyReelsV2DiffusionForcingVideoToVideoPipeline, + SkyReelsV2Transformer3DModel, +) + + +TRANSFORMER_KEYS_RENAME_DICT = { + "time_embedding.0": "condition_embedder.time_embedder.linear_1", + "time_embedding.2": "condition_embedder.time_embedder.linear_2", + "text_embedding.0": "condition_embedder.text_embedder.linear_1", + "text_embedding.2": "condition_embedder.text_embedder.linear_2", + "time_projection.1": "condition_embedder.time_proj", + "head.modulation": "scale_shift_table", + "head.head": "proj_out", + "modulation": "scale_shift_table", + "ffn.0": "ffn.net.0.proj", + "ffn.2": "ffn.net.2", + # Hack to swap the layer names + # The original model calls the norms in following order: norm1, norm3, norm2 + # We convert it to: norm1, norm2, norm3 + "norm2": "norm__placeholder", + "norm3": "norm2", + "norm__placeholder": "norm3", + # For the I2V model + "img_emb.proj.0": "condition_embedder.image_embedder.norm1", + "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj", + "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2", + "img_emb.proj.4": "condition_embedder.image_embedder.norm2", + # for the FLF2V model + "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed", + # Add attention component mappings + "self_attn.q": "attn1.to_q", + "self_attn.k": "attn1.to_k", + "self_attn.v": "attn1.to_v", + "self_attn.o": "attn1.to_out.0", + "self_attn.norm_q": "attn1.norm_q", + "self_attn.norm_k": "attn1.norm_k", + "cross_attn.q": "attn2.to_q", + "cross_attn.k": "attn2.to_k", + "cross_attn.v": "attn2.to_v", + "cross_attn.o": "attn2.to_out.0", + "cross_attn.norm_q": "attn2.norm_q", + "cross_attn.norm_k": "attn2.norm_k", + "attn2.to_k_img": "attn2.add_k_proj", + "attn2.to_v_img": "attn2.add_v_proj", + "attn2.norm_k_img": "attn2.norm_added_k", +} + +TRANSFORMER_SPECIAL_KEYS_REMAP = {} + + +def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: + state_dict[new_key] = state_dict.pop(old_key) + + +def load_sharded_safetensors(dir: pathlib.Path): + file_paths = list(dir.glob("model*.safetensors")) + state_dict = {} + for path in file_paths: + state_dict.update(load_file(path)) + return state_dict + + +def get_transformer_config(model_type: str) -> Dict[str, Any]: + if model_type == "SkyReels-V2-DF-1.3B-540P": + config = { + "model_id": "Skywork/SkyReels-V2-DF-1.3B-540P", + "diffusers_config": { + "added_kv_proj_dim": None, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 8960, + "freq_dim": 256, + "in_channels": 16, + "num_attention_heads": 12, + "num_layers": 30, + "out_channels": 16, + "patch_size": [1, 2, 2], + "qk_norm": "rms_norm", + "text_dim": 4096, + }, + } + elif model_type == "SkyReelsV2-T2V-14B": + config = { + "model_id": "StevenZhang/Wan2.1-T2V-14B-Diff", + "diffusers_config": { + "added_kv_proj_dim": None, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_channels": 16, + "num_attention_heads": 40, + "num_layers": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + }, + } + elif model_type == "SkyReelsV2-I2V-14B-480p": + config = { + "model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff", + "diffusers_config": { + "image_dim": 1280, + "added_kv_proj_dim": 5120, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + 
"in_channels": 36, + "num_attention_heads": 40, + "num_layers": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + }, + } + elif model_type == "SkyReelsV2-I2V-14B-720p": + config = { + "model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff", + "diffusers_config": { + "image_dim": 1280, + "added_kv_proj_dim": 5120, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_channels": 36, + "num_attention_heads": 40, + "num_layers": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + }, + } + elif model_type == "Wan-FLF2V-14B-720P": + config = { + "model_id": "ypyp/Wan2.1-FLF2V-14B-720P", # This is just a placeholder + "diffusers_config": { + "image_dim": 1280, + "added_kv_proj_dim": 5120, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_channels": 36, + "num_attention_heads": 40, + "num_layers": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + "rope_max_seq_len": 1024, + "pos_embed_seq_len": 257 * 2, + }, + } + return config + + +def convert_transformer(model_type: str): + config = get_transformer_config(model_type) + diffusers_config = config["diffusers_config"] + model_id = config["model_id"] + model_dir = hf_hub_download(model_id, "model.safetensors") + + original_state_dict = load_sharded_safetensors(model_dir) + + with init_empty_weights(): + transformer = SkyReelsV2Transformer3DModel.from_config(diffusers_config) + + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + transformer.load_state_dict(original_state_dict, strict=True, assign=True) + return transformer + + +def convert_vae(): + vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.1-T2V-14B", "Wan2.1_VAE.pth") + old_state_dict = torch.load(vae_ckpt_path, weights_only=True) + new_state_dict = {} + + # Create mappings for specific components + middle_key_mapping = { + # Encoder middle block + "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma", + "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias", + "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight", + "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma", + "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias", + "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight", + "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma", + "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias", + "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight", + "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma", + "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias", + "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight", + # Decoder middle block + "decoder.middle.0.residual.0.gamma": 
"decoder.mid_block.resnets.0.norm1.gamma", + "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias", + "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight", + "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma", + "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias", + "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight", + "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma", + "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias", + "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight", + "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma", + "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias", + "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight", + } + + # Create a mapping for attention blocks + attention_mapping = { + # Encoder middle attention + "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma", + "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight", + "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias", + "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight", + "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias", + # Decoder middle attention + "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma", + "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight", + "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias", + "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight", + "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias", + } + + # Create a mapping for the head components + head_mapping = { + # Encoder head + "encoder.head.0.gamma": "encoder.norm_out.gamma", + "encoder.head.2.bias": "encoder.conv_out.bias", + "encoder.head.2.weight": "encoder.conv_out.weight", + # Decoder head + "decoder.head.0.gamma": "decoder.norm_out.gamma", + "decoder.head.2.bias": "decoder.conv_out.bias", + "decoder.head.2.weight": "decoder.conv_out.weight", + } + + # Create a mapping for the quant components + quant_mapping = { + "conv1.weight": "quant_conv.weight", + "conv1.bias": "quant_conv.bias", + "conv2.weight": "post_quant_conv.weight", + "conv2.bias": "post_quant_conv.bias", + } + + # Process each key in the state dict + for key, value in old_state_dict.items(): + # Handle middle block keys using the mapping + if key in middle_key_mapping: + new_key = middle_key_mapping[key] + new_state_dict[new_key] = value + # Handle attention blocks using the mapping + elif key in attention_mapping: + new_key = attention_mapping[key] + new_state_dict[new_key] = value + # Handle head keys using the mapping + elif key in head_mapping: + new_key = head_mapping[key] + new_state_dict[new_key] = value + # Handle quant keys using the mapping + elif key in quant_mapping: + new_key = quant_mapping[key] + new_state_dict[new_key] = value + # Handle encoder conv1 + elif key == "encoder.conv1.weight": + new_state_dict["encoder.conv_in.weight"] = value + elif key == "encoder.conv1.bias": + new_state_dict["encoder.conv_in.bias"] = value + # Handle decoder conv1 + elif key == "decoder.conv1.weight": + new_state_dict["decoder.conv_in.weight"] = value + elif key == "decoder.conv1.bias": + 
new_state_dict["decoder.conv_in.bias"] = value + # Handle encoder downsamples + elif key.startswith("encoder.downsamples."): + # Convert to down_blocks + new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.") + + # Convert residual block naming but keep the original structure + if ".residual.0.gamma" in new_key: + new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") + elif ".residual.2.bias" in new_key: + new_key = new_key.replace(".residual.2.bias", ".conv1.bias") + elif ".residual.2.weight" in new_key: + new_key = new_key.replace(".residual.2.weight", ".conv1.weight") + elif ".residual.3.gamma" in new_key: + new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma") + elif ".residual.6.bias" in new_key: + new_key = new_key.replace(".residual.6.bias", ".conv2.bias") + elif ".residual.6.weight" in new_key: + new_key = new_key.replace(".residual.6.weight", ".conv2.weight") + elif ".shortcut.bias" in new_key: + new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") + elif ".shortcut.weight" in new_key: + new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") + + new_state_dict[new_key] = value + + # Handle decoder upsamples + elif key.startswith("decoder.upsamples."): + # Convert to up_blocks + parts = key.split(".") + block_idx = int(parts[2]) + + # Group residual blocks + if "residual" in key: + if block_idx in [0, 1, 2]: + new_block_idx = 0 + resnet_idx = block_idx + elif block_idx in [4, 5, 6]: + new_block_idx = 1 + resnet_idx = block_idx - 4 + elif block_idx in [8, 9, 10]: + new_block_idx = 2 + resnet_idx = block_idx - 8 + elif block_idx in [12, 13, 14]: + new_block_idx = 3 + resnet_idx = block_idx - 12 + else: + # Keep as is for other blocks + new_state_dict[key] = value + continue + + # Convert residual block naming + if ".residual.0.gamma" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma" + elif ".residual.2.bias" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias" + elif ".residual.2.weight" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight" + elif ".residual.3.gamma" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma" + elif ".residual.6.bias" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias" + elif ".residual.6.weight" in key: + new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight" + else: + new_key = key + + new_state_dict[new_key] = value + + # Handle shortcut connections + elif ".shortcut." in key: + if block_idx == 4: + new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.") + new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1") + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + new_key = new_key.replace(".shortcut.", ".conv_shortcut.") + + new_state_dict[new_key] = value + + # Handle upsamplers + elif ".resample." in key or ".time_conv." 
in key: + if block_idx == 3: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0") + elif block_idx == 7: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0") + elif block_idx == 11: + new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0") + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + + new_state_dict[new_key] = value + else: + new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") + new_state_dict[new_key] = value + else: + # Keep other keys unchanged + new_state_dict[key] = value + + with init_empty_weights(): + vae = AutoencoderKLWan() + vae.load_state_dict(new_state_dict, strict=True, assign=True) + return vae + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_type", type=str, default=None) + parser.add_argument("--output_path", type=str, required=True) + parser.add_argument("--dtype", default="fp32") + return parser.parse_args() + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +if __name__ == "__main__": + args = get_args() + + transformer = None + dtype = DTYPE_MAPPING[args.dtype] + + transformer = convert_transformer(args.model_type).to(dtype=dtype) + vae = convert_vae() + text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl") + tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") + scheduler = FlowMatchUniPCMultistepScheduler( + prediction_type="flow_prediction", num_train_timesteps=1000, + ) + + if "I2V" in args.model_type: + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16 + ) + image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + pipe = SkyReelsV2ImageToVideoPipeline( + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + image_encoder=image_encoder, + image_processor=image_processor, + ) + else: + pipe = SkyReelsV2DiffusionForcingPipeline( + transformer=transformer, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + ) + + pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", + push_to_hub=True, + repo_id="tolgacangoz/SkyReels-V2-DF-1.3B-540P-Diffusers-2", + ) From 59c1e88c6aacfde0ce6903ab43aa9e7cd4b01f99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 20 May 2025 14:50:08 +0300 Subject: [PATCH 066/264] down --- scripts/convert_skyreelsv2_to_diffusers.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index d5410b7f7b2e..2c94268ff26b 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -11,11 +11,7 @@ from diffusers import ( AutoencoderKLWan, FlowMatchUniPCMultistepScheduler, - SkyReelsV2ImageToVideoPipeline, - SkyReelsV2Pipeline, SkyReelsV2DiffusionForcingPipeline, - SkyReelsV2DiffusionForcingImageToImagePipeline, - SkyReelsV2DiffusionForcingVideoToVideoPipeline, SkyReelsV2Transformer3DModel, ) From 02f038d18c86d29928e37cbb73430c3a1796ec59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 20 May 2025 14:58:21 +0300 Subject: [PATCH 067/264] Update documentation in `SkyReelsV2DiffusionForcingPipeline` to clarify examples and return values. 
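A short sketch of the two return modes the clarified docstring describes; `pipe` is assumed to be an already-loaded `SkyReelsV2DiffusionForcingPipeline` and the prompt is a placeholder.

```py
# Default return_dict=True -> SkyReelsV2PipelineOutput with a .frames attribute.
output = pipe("A red kite rising over a windy beach", num_frames=97)
frames = output.frames[0]

# return_dict=False -> plain tuple; the first element holds the generated frames.
frames = pipe("A red kite rising over a windy beach", num_frames=97, return_dict=False)[0]
```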
--- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 7a82745eaf4a..2eaee6d073bd 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -548,7 +548,9 @@ def __call__( Recommended when using asynchronous inference (--ar_step > 0) fps (`int`, *optional*, defaults to `24`): - Examples: Returns: + Examples: + + Returns: [`~SkyReelsV2PipelineOutput`] or `tuple`: If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s From 32ca01a530fff31666dce244408c3638ff6d5d6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 20 May 2025 15:01:05 +0300 Subject: [PATCH 068/264] up --- scripts/convert_skyreelsv2_to_diffusers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 2c94268ff26b..123486216b08 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -11,10 +11,11 @@ from diffusers import ( AutoencoderKLWan, FlowMatchUniPCMultistepScheduler, - SkyReelsV2DiffusionForcingPipeline, SkyReelsV2Transformer3DModel, ) +from diffusers.pipelines import SkyReelsV2DiffusionForcingPipeline + TRANSFORMER_KEYS_RENAME_DICT = { "time_embedding.0": "condition_embedder.time_embedder.linear_1", From 02ffe0c519ab7ea2450ca6d7b94d9e24719b77d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 20 May 2025 15:03:42 +0300 Subject: [PATCH 069/264] Refactor model directory path handling in `convert_transformer` function to use `pathlib.Path` for improved compatibility. 
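A rough sketch, under the assumption that the goal is shard-globbing over the local checkpoint directory, of how the `pathlib.Path` wrapper combines with `hf_hub_download` (which returns a plain `str`); later "temp fix" commits in this series handle the single-file case differently, so treat this purely as an illustration.

```py
import pathlib

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


def load_sharded_safetensors(shard_dir: pathlib.Path) -> dict:
    # Merge every shard named model*.safetensors found in the directory.
    state_dict = {}
    for path in sorted(shard_dir.glob("model*.safetensors")):
        state_dict.update(load_file(path))
    return state_dict


# hf_hub_download returns a str; wrapping it in pathlib.Path exposes .glob()/.parent.
single_file = pathlib.Path(hf_hub_download("Skywork/SkyReels-V2-DF-1.3B-540P", "model.safetensors"))
state_dict = load_sharded_safetensors(single_file.parent)  # glob over the containing snapshot dir
print(len(state_dict), "tensors loaded")
```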
--- scripts/convert_skyreelsv2_to_diffusers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 123486216b08..08b3654baac9 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -182,7 +182,7 @@ def convert_transformer(model_type: str): config = get_transformer_config(model_type) diffusers_config = config["diffusers_config"] model_id = config["model_id"] - model_dir = hf_hub_download(model_id, "model.safetensors") + model_dir = pathlib.Path(hf_hub_download(model_id, "model.safetensors")) original_state_dict = load_sharded_safetensors(model_dir) From a2156775d25442b19ca613d545e179cd1e389028 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 20 May 2025 15:21:00 +0300 Subject: [PATCH 070/264] fix "inject_sample_info": true, --- scripts/convert_skyreelsv2_to_diffusers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 08b3654baac9..1f40a10426db 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -87,6 +87,7 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]: "freq_dim": 256, "in_channels": 16, "num_attention_heads": 12, + "inject_sample_info": True, "num_layers": 30, "out_channels": 16, "patch_size": [1, 2, 2], From 1e4c50172d27089a1fd07c56d3e42a19a6af7309 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 20 May 2025 16:24:46 +0300 Subject: [PATCH 071/264] temp fix --- scripts/convert_skyreelsv2_to_diffusers.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 1f40a10426db..73dc2e1e81be 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -67,10 +67,9 @@ def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) - def load_sharded_safetensors(dir: pathlib.Path): - file_paths = list(dir.glob("model*.safetensors")) + #file_paths = list(dir.glob("model*.safetensors")) state_dict = {} - for path in file_paths: - state_dict.update(load_file(path)) + state_dict.update(load_file(dir)) return state_dict @@ -183,7 +182,7 @@ def convert_transformer(model_type: str): config = get_transformer_config(model_type) diffusers_config = config["diffusers_config"] model_id = config["model_id"] - model_dir = pathlib.Path(hf_hub_download(model_id, "model.safetensors")) + model_dir = hf_hub_download(model_id, "model.safetensors") original_state_dict = load_sharded_safetensors(model_dir) From 322ce0cc5c676bd82919a8ff1b21ac9cc505b66c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 20 May 2025 16:29:17 +0300 Subject: [PATCH 072/264] up --- scripts/convert_skyreelsv2_to_diffusers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 73dc2e1e81be..c3eec00ef7bd 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -90,7 +90,7 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]: "num_layers": 30, "out_channels": 16, "patch_size": [1, 2, 2], - "qk_norm": "rms_norm", + "qk_norm": "rms_norm_across_heads", "text_dim": 4096, }, } From 
b7d54d6414eeec39c9209370a5431a2fabf99b86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 20 May 2025 16:31:55 +0300 Subject: [PATCH 073/264] fix `qk_norm` --- .../models/transformers/transformer_skyreels_v2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 30fc309cc52c..a3807fdbd176 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -243,7 +243,7 @@ def __init__( dim: int, ffn_dim: int, num_heads: int, - qk_norm: str = "rms_norm", + qk_norm: str = "rms_norm_across_heads", cross_attn_norm: bool = False, eps: float = 1e-6, added_kv_proj_dim: Optional[int] = None, @@ -356,7 +356,7 @@ class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr Window size for local attention (-1 indicates global attention). cross_attn_norm (`bool`, defaults to `True`): Enable cross-attention normalization. - qk_norm (`str`, *optional*, defaults to `"rms_norm"`): + qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`): Enable query/key normalization. eps (`float`, defaults to `1e-6`): Epsilon value for normalization layers. @@ -391,7 +391,7 @@ def __init__( ffn_dim: int = 8192, num_layers: int = 32, cross_attn_norm: bool = True, - qk_norm: Optional[str] = "rms_norm", + qk_norm: Optional[str] = "rms_norm_across_heads", eps: float = 1e-6, image_dim: Optional[int] = None, added_kv_proj_dim: Optional[int] = None, From be77ad8d15bb348a3b87f86b0f674fde4c0bc1ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 20 May 2025 16:56:01 +0300 Subject: [PATCH 074/264] Refactor `convert_skyreelsv2_to_diffusers.py` to use `SkyreelsV2ImageToVideoPipeline` and set text encoder and VAE to None for improved memory. 
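A sketch of the piecewise-conversion pattern these commits converge on: instantiate the pipeline with `None` for every component not being converted in the current run, so peak memory stays close to the size of the single component in hand. The output directory name is a placeholder, and the text encoder is used here only as an example component.

```py
from diffusers.pipelines import SkyReelsV2DiffusionForcingPipeline
from transformers import AutoTokenizer, UMT5EncoderModel

# Convert/save one heavy component per run; everything else stays None.
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")

pipe = SkyReelsV2DiffusionForcingPipeline(
    transformer=None,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    vae=None,
    scheduler=None,
)
# "skyreels-v2-df-partial" is a placeholder output directory.
pipe.save_pretrained("skyreels-v2-df-partial", safe_serialization=True, max_shard_size="5GB")
```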
--- scripts/convert_skyreelsv2_to_diffusers.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index c3eec00ef7bd..326f572d47a2 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -15,6 +15,7 @@ ) from diffusers.pipelines import SkyReelsV2DiffusionForcingPipeline +from diffusers.utils.dummy_torch_and_transformers_objects import SkyreelsV2ImageToVideoPipeline TRANSFORMER_KEYS_RENAME_DICT = { @@ -430,8 +431,8 @@ def get_args(): dtype = DTYPE_MAPPING[args.dtype] transformer = convert_transformer(args.model_type).to(dtype=dtype) - vae = convert_vae() - text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl") + #vae = convert_vae() + #text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl") tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") scheduler = FlowMatchUniPCMultistepScheduler( prediction_type="flow_prediction", num_train_timesteps=1000, @@ -442,11 +443,11 @@ def get_args(): "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16 ) image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - pipe = SkyReelsV2ImageToVideoPipeline( + pipe = SkyreelsV2ImageToVideoPipeline( transformer=transformer, - text_encoder=text_encoder, + text_encoder=None, tokenizer=tokenizer, - vae=vae, + vae=None, scheduler=scheduler, image_encoder=image_encoder, image_processor=image_processor, @@ -454,9 +455,9 @@ def get_args(): else: pipe = SkyReelsV2DiffusionForcingPipeline( transformer=transformer, - text_encoder=text_encoder, + text_encoder=None, tokenizer=tokenizer, - vae=vae, + vae=None, scheduler=scheduler, ) From 6f8ffb2bd60b6a4853702c547cbe182c651be90c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 20 May 2025 17:03:44 +0300 Subject: [PATCH 075/264] for vae --- scripts/convert_skyreelsv2_to_diffusers.py | 24 +++++++++++----------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 326f572d47a2..71f3c01ca076 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -430,13 +430,13 @@ def get_args(): transformer = None dtype = DTYPE_MAPPING[args.dtype] - transformer = convert_transformer(args.model_type).to(dtype=dtype) - #vae = convert_vae() + #transformer = convert_transformer(args.model_type).to(dtype=dtype) + vae = convert_vae() #text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl") - tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") - scheduler = FlowMatchUniPCMultistepScheduler( - prediction_type="flow_prediction", num_train_timesteps=1000, - ) + #tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") + #scheduler = FlowMatchUniPCMultistepScheduler( + # prediction_type="flow_prediction", num_train_timesteps=1000, + #) if "I2V" in args.model_type: image_encoder = CLIPVisionModelWithProjection.from_pretrained( @@ -446,19 +446,19 @@ def get_args(): pipe = SkyreelsV2ImageToVideoPipeline( transformer=transformer, text_encoder=None, - tokenizer=tokenizer, + tokenizer=None, vae=None, - scheduler=scheduler, + scheduler=None, image_encoder=image_encoder, image_processor=image_processor, ) else: pipe = SkyReelsV2DiffusionForcingPipeline( - transformer=transformer, + transformer=None, text_encoder=None, - tokenizer=tokenizer, - vae=None, - 
scheduler=scheduler, + tokenizer=None, + vae=vae, + scheduler=None, ) pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", From 4576f6e4a96a308d4f66ed711994ce1d44d73153 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 20 May 2025 17:06:31 +0300 Subject: [PATCH 076/264] for t5 --- scripts/convert_skyreelsv2_to_diffusers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 71f3c01ca076..f4e6363c6f5d 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -431,8 +431,8 @@ def get_args(): dtype = DTYPE_MAPPING[args.dtype] #transformer = convert_transformer(args.model_type).to(dtype=dtype) - vae = convert_vae() - #text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl") + #vae = convert_vae() + text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl") #tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") #scheduler = FlowMatchUniPCMultistepScheduler( # prediction_type="flow_prediction", num_train_timesteps=1000, @@ -455,9 +455,9 @@ def get_args(): else: pipe = SkyReelsV2DiffusionForcingPipeline( transformer=None, - text_encoder=None, + text_encoder=text_encoder, tokenizer=None, - vae=vae, + vae=None, scheduler=None, ) From 10174cac51043175ae8d19ba579e909b5a87db05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 20 May 2025 17:59:40 +0300 Subject: [PATCH 077/264] up --- scripts/convert_skyreelsv2_to_diffusers.py | 2 +- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index f4e6363c6f5d..b4f939aabed6 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -460,7 +460,7 @@ def get_args(): vae=None, scheduler=None, ) - + # pipe.push_to_hub pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=True, repo_id="tolgacangoz/SkyReels-V2-DF-1.3B-540P-Diffusers-2", diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 2eaee6d073bd..2a4553d07e14 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -740,10 +740,10 @@ def __call__( latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) + .to(device, torch.float32) ) latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype + device, torch.float32 ) for i in range(n_iter): if video is not None: From 9223f2d07a660f2b1afa112f55f74a2d12ad7ee5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 20 May 2025 18:06:40 +0300 Subject: [PATCH 078/264] temp fix --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 2a4553d07e14..4a14db8cf4c3 100644 --- 
a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -812,7 +812,7 @@ def __call__( self._current_timestep = t update_mask_i = step_update_mask[i] valid_interval_start, valid_interval_end = valid_interval[i] - timestep = t.expand(latents.shape[0])[:, valid_interval_start:valid_interval_end].clone() + timestep = t[None, :][:, valid_interval_start:valid_interval_end].clone() latent_model_input = ( latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() ) From a1aadd3a61bba7cc47c458113ce056c56ea1c413 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 20 May 2025 18:10:07 +0300 Subject: [PATCH 079/264] up --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 4a14db8cf4c3..65eac6d3bbcb 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -812,7 +812,7 @@ def __call__( self._current_timestep = t update_mask_i = step_update_mask[i] valid_interval_start, valid_interval_end = valid_interval[i] - timestep = t[None, :][:, valid_interval_start:valid_interval_end].clone() + timestep = t[valid_interval_start:valid_interval_end].clone() latent_model_input = ( latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() ) @@ -826,7 +826,7 @@ def __call__( ) * noise_factor ) - timestep[:, valid_interval_start:prefix_video_latent_length] = addnoise_condition + timestep[valid_interval_start:prefix_video_latent_length] = addnoise_condition noise_pred = self.transformer( hidden_states=latent_model_input, From f369cc4f544632bda4d138282a748b47bf53b3b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 21 May 2025 10:27:04 +0300 Subject: [PATCH 080/264] Remove assertion for 1D timesteps in `get_timestep_embedding` function temporarily --- src/diffusers/models/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index c25e9997e3fb..b928ca626765 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -51,7 +51,7 @@ def get_timestep_embedding( Returns torch.Tensor: an [N x dim] Tensor of positional embeddings. """ - assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + #assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 exponent = -math.log(max_period) * torch.arange( From eb3237609fabef895a960a07b3840b97fea3736d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 21 May 2025 10:27:22 +0300 Subject: [PATCH 081/264] Refactor timestep handling in `SkyReelsV2DiffusionForcingPipeline` to ensure correct scheduler initialization and update mask usage for improved video generation stability. 
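A stand-alone sketch of the per-frame scheduler setup this patch settles on, using the stock `UniPCMultistepScheduler` as a stand-in (the pipeline's flow-matching UniPC variant additionally accepts a `shift` argument in `set_timesteps`); counts and shapes are illustrative.

```py
from copy import deepcopy

from diffusers import UniPCMultistepScheduler  # stand-in for the flow-matching UniPC scheduler

num_latent_frames = 4
num_inference_steps = 30

# One scheduler per latent frame: set timesteps on the first scheduler, then give
# every deep copy its own timestep state, mirroring the ordering introduced here.
sample_schedulers = [UniPCMultistepScheduler()]
sample_schedulers[0].set_timesteps(num_inference_steps, device="cpu")
for _ in range(num_latent_frames - 1):
    scheduler = deepcopy(sample_schedulers[0])
    scheduler.set_timesteps(num_inference_steps, device="cpu")
    sample_schedulers.append(scheduler)

timesteps = sample_schedulers[0].timesteps
print(len(sample_schedulers), timesteps[:3])
```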
--- .../pipeline_skyreels_v2_diffusion_forcing.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 65eac6d3bbcb..fb57cfe36c22 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -628,11 +628,11 @@ def __call__( # Short video generation # 4. Prepare sample schedulers and timestep matrix sample_schedulers = [self.scheduler] + sample_schedulers[0].set_timesteps(num_inference_steps, device=device, shift=shift) for _ in range(num_latent_frames - 1): sample_scheduler = deepcopy(self.scheduler) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) - self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) timesteps = self.scheduler.timesteps sample_schedulers_counter = [0] * num_latent_frames step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( @@ -663,13 +663,12 @@ def __call__( continue self._current_timestep = t - - update_mask_i = step_update_mask[i] valid_interval_start, valid_interval_end = valid_interval[i] - timestep = t.expand(latents.shape[0])[:, valid_interval_start:valid_interval_end].clone() latent_model_input = ( latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() ) + timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() + if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length: noise_factor = 0.001 * addnoise_condition latent_model_input[:, valid_interval_start:prefix_video_latent_length] = ( @@ -701,6 +700,7 @@ def __call__( )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + update_mask_i = step_update_mask[i] for idx in range(valid_interval_start, valid_interval_end): if update_mask_i[idx].item(): latents[:, idx] = sample_schedulers[idx].step( @@ -768,12 +768,12 @@ def __call__( # 4. 
Prepare sample schedulers and timestep matrix sample_schedulers = [deepcopy(self.scheduler)] + sample_schedulers[0].set_timesteps(num_inference_steps, device=device, shift=shift) for _ in range(base_num_frames_iter - 1): sample_scheduler = deepcopy(self.scheduler) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) - self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) - timesteps = self.scheduler.timesteps + timesteps = sample_schedulers[0].timesteps sample_schedulers_counter = [0] * base_num_frames_iter step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( base_num_frames_iter, @@ -810,12 +810,12 @@ def __call__( continue self._current_timestep = t - update_mask_i = step_update_mask[i] valid_interval_start, valid_interval_end = valid_interval[i] - timestep = t[valid_interval_start:valid_interval_end].clone() latent_model_input = ( latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() ) + timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() + if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length: noise_factor = 0.001 * addnoise_condition latent_model_input[:, valid_interval_start:prefix_video_latent_length] = ( @@ -848,6 +848,8 @@ def __call__( return_dict=False, )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + update_mask_i = step_update_mask[i] for idx in range(valid_interval_start, valid_interval_end): if update_mask_i[idx].item(): latents[:, idx] = sample_schedulers[idx].step( From 671b37e9c474de6926f151f8ac443840d20e0d0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 21 May 2025 11:05:00 +0300 Subject: [PATCH 082/264] Enhance `get_timestep_embedding` to support 2D tensor inputs, allowing for batch processing of multiple frames. Update documentation to reflect new input shape and ensure correct reshaping of output embeddings. --- src/diffusers/models/embeddings.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index b928ca626765..559a9493f9e7 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -38,6 +38,7 @@ def get_timestep_embedding( Args timesteps (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. + Can also be a 2-D Tensor of shape (batch_size, num_frames). embedding_dim (int): the dimension of the output. flip_sin_to_cos (bool): @@ -49,18 +50,25 @@ def get_timestep_embedding( max_period (int): Controls the maximum frequency of the embeddings Returns - torch.Tensor: an [N x dim] Tensor of positional embeddings. + torch.Tensor: an [N x dim] Tensor of positional embeddings. If input was 2D, shape is [B x F x dim]. 
""" - #assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + original_shape = timesteps.shape + if len(original_shape) == 2: + # timesteps is (B, F_v), flatten to (B * F_v) + timesteps_flat = timesteps.reshape(-1) + elif len(original_shape) == 1: + timesteps_flat = timesteps + else: + raise ValueError(f"Timesteps should be 1D or 2D, but got shape {original_shape}") half_dim = embedding_dim // 2 exponent = -math.log(max_period) * torch.arange( - start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + start=0, end=half_dim, dtype=torch.float32, device=timesteps_flat.device ) exponent = exponent / (half_dim - downscale_freq_shift) emb = torch.exp(exponent) - emb = timesteps[:, None].float() * emb[None, :] + emb = timesteps_flat[:, None].float() * emb[None, :] # scale embeddings emb = scale * emb @@ -75,6 +83,11 @@ def get_timestep_embedding( # zero pad if embedding_dim % 2 == 1: emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + + if len(original_shape) == 2: + # Reshape back to (B, F_v, embedding_dim) + emb = emb.reshape(original_shape[0], original_shape[1], embedding_dim) + return emb From 6f8bf305e07b8509e9a5e53b2039539456131224 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 21 May 2025 11:12:53 +0300 Subject: [PATCH 083/264] Fix unflattening of timestep projection in `SkyReelsV2Transformer3DModel` to use last dimension for improved tensor handling. --- src/diffusers/models/transformers/transformer_skyreels_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index a3807fdbd176..69efb18bd034 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -497,7 +497,7 @@ def forward( temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( timestep, encoder_hidden_states, encoder_hidden_states_image ) - timestep_proj = timestep_proj.unflatten(1, (6, -1)) + timestep_proj = timestep_proj.unflatten(-1, (6, -1)) if encoder_hidden_states_image is not None: encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) From c71d3aa86d7768027aef81d0c76fc1764c519180 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 21 May 2025 11:27:36 +0300 Subject: [PATCH 084/264] Update dtype handling in `SkyReelsV2Transformer3DModel` to ensure consistent tensor types for `timestep_proj` and `fps_projection`, enhancing compatibility with hidden states. 
--- src/diffusers/models/transformers/transformer_skyreels_v2.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 69efb18bd034..41e1613b7a93 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -512,12 +512,15 @@ def forward( fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device) fps_emb = self.fps_embedding(fps).float() + timestep_proj.to(fps_emb.dtype) + self.fps_projection.to(fps_emb.dtype) if flag_df: timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat( timestep.shape[1], 1, 1 ) else: timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)) + timestep_proj.to(hidden_states.dtype) if flag_df: b, f = timestep.shape From 1afa337d011963eeec613d9dc778c1bbac80d63f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 21 May 2025 11:33:07 +0300 Subject: [PATCH 085/264] Refactor tensor reshaping in `SkyReelsV2Transformer3DModel` to utilize dynamic dimensions for `temb` and `timestep_proj`, improving flexibility in handling varying input sizes. --- .../models/transformers/transformer_skyreels_v2.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 41e1613b7a93..e26785af4d0f 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -515,17 +515,17 @@ def forward( timestep_proj.to(fps_emb.dtype) self.fps_projection.to(fps_emb.dtype) if flag_df: - timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat( + timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1)).repeat( timestep.shape[1], 1, 1 ) else: - timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)) + timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1)) timestep_proj.to(hidden_states.dtype) if flag_df: b, f = timestep.shape - temb = temb.view(b, f, 1, 1, self.dim) - timestep_proj = timestep_proj.view(b, f, 1, 1, 6, self.dim) + temb = temb.view(b, f, 1, 1, -1) + timestep_proj = timestep_proj.view(b, f, 1, 1, 6, -1) temb = temb.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3) timestep_proj = timestep_proj.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3) timestep_proj = timestep_proj.transpose(1, 2).contiguous() From c74675cc27a29155950db64d09cdc9caf321828e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 21 May 2025 12:49:40 +0300 Subject: [PATCH 086/264] Refactor timestep preparation in `SkyReelsV2DiffusionForcingPipeline` to streamline scheduler initialization and enhance video generation by ensuring correct handling of timesteps across multiple latent frames. 
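For reference, a rough sketch of how the per-frame schedulers and the timestep matrix are meant to be consumed. The tensors below are illustrative stand-ins for the outputs of `generate_timestep_matrix`, not its real implementation: each row of the matrix carries one timestep per latent frame, and the boolean update mask decides which frames advance on that row.

```py
import torch

num_frames, num_rows = 4, 6

# Illustrative stand-ins for the outputs of generate_timestep_matrix(): one
# timestep per frame on every row, plus a mask saying which frames advance.
step_matrix = torch.randint(0, 1000, (num_rows, num_frames))
step_update_mask = torch.rand(num_rows, num_frames) > 0.5

latents = torch.randn(1, 16, num_frames, 8, 8)  # (B, C, F, H, W)

for i, t in enumerate(step_matrix):
    update_mask_i = step_update_mask[i]
    for idx in range(num_frames):
        if update_mask_i[idx].item():
            # The pipeline calls sample_schedulers[idx].step(noise_pred[:, :, idx - start],
            # t[idx], latents[:, :, idx], ...) here; this sketch just touches the frame.
            latents[:, :, idx] = latents[:, :, idx] * 0.99
```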
--- .../pipeline_skyreels_v2_diffusion_forcing.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index fb57cfe36c22..a4e5adaa8660 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -610,6 +610,10 @@ def __call__( if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + prefix_video = None prefix_video_latent_length = 0 num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 @@ -627,13 +631,11 @@ def __call__( if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: # Short video generation # 4. Prepare sample schedulers and timestep matrix - sample_schedulers = [self.scheduler] - sample_schedulers[0].set_timesteps(num_inference_steps, device=device, shift=shift) - for _ in range(num_latent_frames - 1): + sample_schedulers = [] + for _ in range(num_latent_frames): sample_scheduler = deepcopy(self.scheduler) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) - timesteps = self.scheduler.timesteps sample_schedulers_counter = [0] * num_latent_frames step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( num_latent_frames, timesteps, base_num_frames, ar_step, prefix_video_latent_length, causal_block_size @@ -767,13 +769,11 @@ def __call__( base_num_frames_iter = base_num_frames # 4. 
Prepare sample schedulers and timestep matrix - sample_schedulers = [deepcopy(self.scheduler)] - sample_schedulers[0].set_timesteps(num_inference_steps, device=device, shift=shift) - for _ in range(base_num_frames_iter - 1): + sample_schedulers = [] + for _ in range(base_num_frames_iter): sample_scheduler = deepcopy(self.scheduler) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) - timesteps = sample_schedulers[0].timesteps sample_schedulers_counter = [0] * base_num_frames_iter step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( base_num_frames_iter, From 602cff7d8d3c175b864201c34e68926f13a912be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 21 May 2025 14:06:11 +0300 Subject: [PATCH 087/264] fix: multi-dimentional indexing --- .../pipeline_skyreels_v2_diffusion_forcing.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index a4e5adaa8660..db9e46ec788b 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -705,10 +705,10 @@ def __call__( update_mask_i = step_update_mask[i] for idx in range(valid_interval_start, valid_interval_end): if update_mask_i[idx].item(): - latents[:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], + latents[:, :, idx, :, :] = sample_schedulers[idx].step( + noise_pred[:, :, idx - valid_interval_start, :, :], t[idx], - latents[:, idx], + latents[:, :, idx, :, :], return_dict=False, generator=generator, )[0] @@ -852,10 +852,10 @@ def __call__( update_mask_i = step_update_mask[i] for idx in range(valid_interval_start, valid_interval_end): if update_mask_i[idx].item(): - latents[:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], + latents[:, :, idx, :, :] = sample_schedulers[idx].step( + noise_pred[:, :, idx - valid_interval_start, :, :], t[idx], - latents[:, idx], + latents[:, :, idx, :, :], return_dict=False, generator=generator, )[0] From 237e468dff0da8ac1d48fa9f71e1ecbf723bbd11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 21 May 2025 15:26:50 +0300 Subject: [PATCH 088/264] Comment out tensor unsqueezing in `SkyReelsV2DiffusionForcingPipeline` to prevent unintended dimensionality changes during video generation, ensuring stability in processing latents. 
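The surrounding fixes all hinge on the latent layout, so a tiny sketch of the indexing difference (sizes are illustrative): the video latents are laid out as (batch, channels, frames, height, width), which means frames live on dim 2, `latents[:, idx]` would select a channel rather than a frame, and an extra `unsqueeze(0)` would add a spurious leading dimension.

```py
import torch

latents = torch.randn(1, 16, 5, 60, 104)   # (batch, channels, frames, height, width)

frame_2 = latents[:, :, 2]                 # one latent frame: (1, 16, 60, 104)
channel_2 = latents[:, 2]                  # one channel over all frames: (1, 5, 60, 104)

assert frame_2.shape == (1, 16, 60, 104)
assert channel_2.shape == (1, 5, 60, 104)

# An extra unsqueeze(0) only prepends a singleton dimension:
assert latents.unsqueeze(0).shape == (1, 1, 16, 5, 60, 104)
```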
--- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index db9e46ec788b..08a310a4b60c 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -733,7 +733,7 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - latents = latents.unsqueeze(0) + #latents = latents.unsqueeze(0) else: # Long video generation overlap_history_frames = (overlap_history - 1) // 4 + 1 @@ -882,7 +882,7 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - latents = latents.unsqueeze(0) + #latents = latents.unsqueeze(0) if not output_type == "latent": latents = latents.to(self.vae.dtype) latents = latents / latents_std + latents_mean From 40c456d6a34233a81722cfa4ae184f1b2be9a8a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 21 May 2025 15:39:33 +0300 Subject: [PATCH 089/264] Update dtype handling in `SkyReelsV2DiffusionForcingPipeline` to use dynamic tensor types for latents, ensuring compatibility and stability during video generation. --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 08a310a4b60c..8ab0a54ced76 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -742,10 +742,10 @@ def __call__( latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(device, torch.float32) + .to(device, latents.dtype) ) latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - device, torch.float32 + device, latents.dtype ) for i in range(n_iter): if video is not None: From 9ed88da55d5d5870793d0085e2c3ca0a439eed6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 21 May 2025 15:41:50 +0300 Subject: [PATCH 090/264] fix dype --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 8ab0a54ced76..1a370134648f 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -742,10 +742,10 @@ def __call__( latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(device, latents.dtype) + .to(device, transformer_dtype.dtype) ) latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - device, latents.dtype + device, transformer_dtype.dtype ) for i in range(n_iter): if video is not None: From 6a3c7bf9a7373ad25a2fb2064c5e746e8358dc80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 21 May 2025 15:43:02 +0300 Subject: [PATCH 091/264] fix --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 ++-- 1 file 
changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 1a370134648f..fd2ba52ac43e 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -742,10 +742,10 @@ def __call__( latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(device, transformer_dtype.dtype) + .to(device, transformer_dtype) ) latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - device, transformer_dtype.dtype + device, transformer_dtype ) for i in range(n_iter): if video is not None: From 5652aa01cc9c8511f2bda813cd0a5dbbd9434f03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 21 May 2025 16:08:01 +0300 Subject: [PATCH 092/264] Refactor sample scheduler initialization in `SkyReelsV2DiffusionForcingPipeline` to use `FlowMatchUniPCMultistepScheduler` from configuration, enhancing flexibility and maintainability. --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index fd2ba52ac43e..6df3882610df 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -633,7 +633,7 @@ def __call__( # 4. Prepare sample schedulers and timestep matrix sample_schedulers = [] for _ in range(num_latent_frames): - sample_scheduler = deepcopy(self.scheduler) + sample_scheduler = FlowMatchUniPCMultistepScheduler.from_config(self.scheduler.config) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) sample_schedulers_counter = [0] * num_latent_frames @@ -771,7 +771,7 @@ def __call__( # 4. Prepare sample schedulers and timestep matrix sample_schedulers = [] for _ in range(base_num_frames_iter): - sample_scheduler = deepcopy(self.scheduler) + sample_scheduler = FlowMatchUniPCMultistepScheduler.from_config(self.scheduler.config) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) sample_schedulers_counter = [0] * base_num_frames_iter From e529fea3c36ee324da38437a9a2fa477b7cb32f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 21 May 2025 16:22:58 +0300 Subject: [PATCH 093/264] Adds shift parameter to scheduler timestep setting Passes the `shift` parameter to `self.scheduler.set_timesteps`. This allows for adjusting the timestep generation. 
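A small sketch of what the `shift` value does, assuming the FlowMatch UniPC variant here follows the same time-shifting rule as the other flow-matching schedulers in diffusers; the step count is illustrative. Larger shifts push the sigma schedule toward the high-noise end, so more of the budget is spent early.

```py
import torch

def shift_sigmas(sigmas: torch.Tensor, shift: float) -> torch.Tensor:
    # Time shifting used by the flow-matching schedulers in diffusers:
    # sigma' = shift * sigma / (1 + (shift - 1) * sigma)
    return shift * sigmas / (1 + (shift - 1) * sigmas)

sigmas = torch.linspace(1.0, 1.0 / 30, 30)       # plain 30-step schedule in (0, 1]
print(shift_sigmas(sigmas, shift=1.0)[:3])       # identical to the input
print(shift_sigmas(sigmas, shift=8.0)[:3])       # early sigmas pulled toward 1.0
```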
--- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 6df3882610df..b5d80f3f945b 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -611,7 +611,7 @@ def __call__( negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) timesteps = self.scheduler.timesteps prefix_video = None From b3ffecacd1739b54dd14febb2a9ba883b5fced5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 21 May 2025 17:24:34 +0300 Subject: [PATCH 094/264] Fix slicing of latents in `SkyReelsV2DiffusionForcingPipeline` to ensure correct tensor dimensions during processing. --- .../pipeline_skyreels_v2_diffusion_forcing.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index b5d80f3f945b..a5f6e11cc5e9 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -667,7 +667,7 @@ def __call__( self._current_timestep = t valid_interval_start, valid_interval_end = valid_interval[i] latent_model_input = ( - latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() ) timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() @@ -759,12 +759,6 @@ def __call__( finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames left_frame_num = num_latent_frames - finished_frame_num base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) - if ar_step > 0: - num_steps = ( - num_inference_steps - + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step - ) - self.transformer.config.num_steps = num_steps else: base_num_frames_iter = base_num_frames @@ -812,7 +806,7 @@ def __call__( self._current_timestep = t valid_interval_start, valid_interval_end = valid_interval[i] latent_model_input = ( - latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() ) timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() From 4479afc19143ff6e4b0390e1cbe15af61deeb2ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 21 May 2025 18:02:34 +0300 Subject: [PATCH 095/264] Fix tensor slicing in `SkyReelsV2DiffusionForcingPipeline` to ensure correct dimensions for latent model input and noise application. 
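A minimal sketch of what the corrected block computes, with illustrative shapes and `valid_interval_start` fixed at 0: the prefix (already known) latent frames are only lightly re-noised with a factor of `0.001 * addnoise_condition`, and their timestep entries are overwritten with `addnoise_condition` so the model treats them as near-clean context.

```py
import torch

addnoise_condition = 20
prefix_len = 2                                         # prefix_video_latents_length, illustrative

latent_model_input = torch.randn(1, 16, 5, 60, 104)    # (B, C, F, H, W)
timestep = torch.full((1, 5), 1000.0)                  # one timestep per latent frame

noise_factor = 0.001 * addnoise_condition              # 0.02: barely perturb the prefix
prefix = latent_model_input[:, :, :prefix_len]
latent_model_input[:, :, :prefix_len] = (
    prefix * (1.0 - noise_factor) + torch.randn_like(prefix) * noise_factor
)
timestep[:, :prefix_len] = addnoise_condition          # mark those frames as near-clean
```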
--- .../pipeline_skyreels_v2_diffusion_forcing.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index a5f6e11cc5e9..4317d62ae535 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -673,10 +673,10 @@ def __call__( if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length: noise_factor = 0.001 * addnoise_condition - latent_model_input[:, valid_interval_start:prefix_video_latent_length] = ( - latent_model_input[:, valid_interval_start:prefix_video_latent_length] + latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] = ( + latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] * (1.0 - noise_factor) - + torch.randn_like(latent_model_input[:, valid_interval_start:prefix_video_latent_length]) + + torch.randn_like(latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :]) * noise_factor ) timestep[:, valid_interval_start:prefix_video_latent_length] = addnoise_condition @@ -812,15 +812,15 @@ def __call__( if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length: noise_factor = 0.001 * addnoise_condition - latent_model_input[:, valid_interval_start:prefix_video_latent_length] = ( - latent_model_input[:, valid_interval_start:prefix_video_latent_length] + latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] = ( + latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] * (1.0 - noise_factor) + torch.randn_like( - latent_model_input[:, valid_interval_start:prefix_video_latent_length] + latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] ) * noise_factor ) - timestep[valid_interval_start:prefix_video_latent_length] = addnoise_condition + timestep[:, valid_interval_start:prefix_video_latent_length] = addnoise_condition noise_pred = self.transformer( hidden_states=latent_model_input, From e4f6743449c63f3839d6432157166f78c6c7f6ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 23 May 2025 14:57:11 +0300 Subject: [PATCH 096/264] Update progress bar total in `SkyReelsV2DiffusionForcingPipeline` to reflect the actual length of the step matrix, ensuring accurate progress tracking during inference. 
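A back-of-the-envelope count of why `len(step_matrix)` is the right total, following the row-count formula used elsewhere in this series; the numbers are illustrative. With asynchronous generation (`ar_step > 0`) the matrix has more rows than `num_inference_steps`, because each causal block is offset by `ar_step` extra iterations.

```py
num_inference_steps = 30
ar_step = 5                    # asynchronous offset between causal blocks
num_latent_frames = 25
overlap_frames = 0             # prefix/overlap latent frames already fixed
causal_block_size = 5

num_blocks = (num_latent_frames - overlap_frames) // causal_block_size
num_rows = num_inference_steps + (num_blocks - 1) * ar_step
print(num_rows)                # 50 loop iterations rather than 30
```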
--- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 4317d62ae535..a5a8f65fa075 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -659,7 +659,7 @@ def __call__( num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order self._num_timesteps = len(step_matrix) - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=len(step_matrix)) as progress_bar: for i, t in enumerate(step_matrix): if self.interrupt: continue @@ -798,7 +798,7 @@ def __call__( num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order self._num_timesteps = len(step_matrix) - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=len(step_matrix)) as progress_bar: for i, t in enumerate(step_matrix): if self.interrupt: continue From 74204461e9f4470c5e7b7be5207e66fd146f5a04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 23 May 2025 15:49:23 +0300 Subject: [PATCH 097/264] Refactor error handling and tensor processing in `SkyReelsV2DiffusionForcingPipeline` to improve robustness during inference. Added try-except blocks for better error reporting and streamlined tensor operations for noise application and latent updates. --- .../pipeline_skyreels_v2_diffusion_forcing.py | 222 +++++++++--------- 1 file changed, 117 insertions(+), 105 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index a5a8f65fa075..2434b53d3721 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -664,65 +664,72 @@ def __call__( if self.interrupt: continue - self._current_timestep = t - valid_interval_start, valid_interval_end = valid_interval[i] - latent_model_input = ( - latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() - ) - timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() - - if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length: - noise_factor = 0.001 * addnoise_condition - latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] = ( - latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] - * (1.0 - noise_factor) - + torch.randn_like(latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :]) - * noise_factor + try: + self._current_timestep = t + valid_interval_start, valid_interval_end = valid_interval[i] + latent_model_input = ( + latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() ) - timestep[:, valid_interval_start:prefix_video_latent_length] = addnoise_condition - - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - flag_df=True, - fps=fps_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - if self.do_classifier_free_guidance: - noise_uncond = 
self.transformer( + timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() + + if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] = ( + latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] + * (1.0 - noise_factor) + + torch.randn_like(latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :]) + * noise_factor + ) + timestep[:, valid_interval_start:prefix_video_latent_length] = addnoise_condition + + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states=prompt_embeds, flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - - update_mask_i = step_update_mask[i] - for idx in range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[:, :, idx, :, :] = sample_schedulers[idx].step( - noise_pred[:, :, idx - valid_interval_start, :, :], - t[idx], - latents[:, :, idx, :, :], + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + flag_df=True, + fps=fps_embeds, + attention_kwargs=attention_kwargs, return_dict=False, - generator=generator, )[0] - sample_schedulers_counter[idx] += 1 + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + update_mask_i = step_update_mask[i] + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[:, :, idx, :, :] = sample_schedulers[idx].step( + noise_pred[:, :, idx - valid_interval_start, :, :], + t[idx], + latents[:, :, idx, :, :], + return_dict=False, + generator=generator, + )[0] + sample_schedulers_counter[idx] += 1 - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + except Exception as e: + print(f"Error at iteration {i}, timestep value (can be a tensor for multiple frames): {self._current_timestep}") + print(f"Exception: {e}") + import traceback + traceback.print_exc() + raise # call the callback, if provided if i == len(step_matrix) - 1 or ( @@ -803,78 +810,83 @@ def __call__( if self.interrupt: continue - self._current_timestep = t - valid_interval_start, valid_interval_end = valid_interval[i] - latent_model_input = ( - latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() - ) - timestep = t.expand(latents.shape[0], -1)[:, 
valid_interval_start:valid_interval_end].clone() + try: + self._current_timestep = t + valid_interval_start, valid_interval_end = valid_interval[i] + latent_model_input = ( + latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + ) + timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() - if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length: - noise_factor = 0.001 * addnoise_condition - latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] = ( - latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] - * (1.0 - noise_factor) - + torch.randn_like( + if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] = ( latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] + * (1.0 - noise_factor) + + torch.randn_like(latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :]) + * noise_factor ) - * noise_factor - ) - timestep[:, valid_interval_start:prefix_video_latent_length] = addnoise_condition + timestep[:, valid_interval_start:prefix_video_latent_length] = addnoise_condition - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - flag_df=True, - fps=fps_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - if self.do_classifier_free_guidance: - noise_uncond = self.transformer( + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states=prompt_embeds, flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - - update_mask_i = step_update_mask[i] - for idx in range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[:, :, idx, :, :] = sample_schedulers[idx].step( - noise_pred[:, :, idx - valid_interval_start, :, :], - t[idx], - latents[:, :, idx, :, :], + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + flag_df=True, + fps=fps_embeds, + attention_kwargs=attention_kwargs, return_dict=False, - generator=generator, )[0] - sample_schedulers_counter[idx] += 1 - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop( - "negative_prompt_embeds", negative_prompt_embeds - ) - - # call the callback, if provided - if i == len(step_matrix) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): - progress_bar.update() + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + update_mask_i = step_update_mask[i] + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[:, :, idx, :, :] = sample_schedulers[idx].step( + noise_pred[:, :, idx - valid_interval_start, :, :], + t[idx], + 
latents[:, :, idx, :, :], + return_dict=False, + generator=generator, + )[0] + sample_schedulers_counter[idx] += 1 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) - if XLA_AVAILABLE: - xm.mark_step() + # call the callback, if provided + if i == len(step_matrix) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + except Exception as e: + print(f"Error at iteration {i}, timestep value (can be a tensor for multiple frames): {self._current_timestep}") + print(f"Exception: {e}") + import traceback + traceback.print_exc() + raise #latents = latents.unsqueeze(0) if not output_type == "latent": From 2d59ebd6d8fd073e52e014dd849d9e8210a4ab24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 23 May 2025 16:16:19 +0300 Subject: [PATCH 098/264] Refactor tensor processing and noise application in `SkyReelsV2DiffusionForcingPipeline` to enhance clarity and efficiency. Updated progress bar total to match the number of inference steps, ensuring accurate tracking. Streamlined the handling of latent model inputs and noise predictions for improved performance during inference. --- .../pipeline_skyreels_v2_diffusion_forcing.py | 233 +++++++++--------- 1 file changed, 110 insertions(+), 123 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 2434b53d3721..fbcfd6dc3963 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -659,77 +659,70 @@ def __call__( num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order self._num_timesteps = len(step_matrix) - with self.progress_bar(total=len(step_matrix)) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(step_matrix): if self.interrupt: continue - try: - self._current_timestep = t - valid_interval_start, valid_interval_end = valid_interval[i] - latent_model_input = ( - latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + self._current_timestep = t + valid_interval_start, valid_interval_end = valid_interval[i] + latent_model_input = ( + latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + ) + timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() + + if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] = ( + latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] + * (1.0 - noise_factor) + + torch.randn_like(latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :]) + * noise_factor ) - timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() - - if addnoise_condition > 0 and 
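For readability, a minimal sketch of the guidance combination that both denoising loops perform, with a dummy callable standing in for the SkyReels-V2 transformer and made-up tensor shapes: the unconditional and conditional noise predictions are blended with classifier-free guidance exactly as in the loop body.

```py
import torch

guidance_scale = 6.0

def fake_transformer(latents, encoder_hidden_states):
    # Dummy stand-in for the SkyReels-V2 transformer; returns a "noise prediction".
    return latents * 0.1 + encoder_hidden_states.mean()

latents = torch.randn(1, 16, 5, 8, 8)
prompt_embeds = torch.randn(1, 226, 4096)
negative_prompt_embeds = torch.zeros_like(prompt_embeds)

noise_pred = fake_transformer(latents, prompt_embeds)
noise_uncond = fake_transformer(latents, negative_prompt_embeds)
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
```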
valid_interval_start < prefix_video_latent_length: - noise_factor = 0.001 * addnoise_condition - latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] = ( - latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] - * (1.0 - noise_factor) - + torch.randn_like(latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :]) - * noise_factor - ) - timestep[:, valid_interval_start:prefix_video_latent_length] = addnoise_condition - - noise_pred = self.transformer( + timestep[:, valid_interval_start:prefix_video_latent_length] = addnoise_condition + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + flag_df=True, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] - if self.do_classifier_free_guidance: - noise_uncond = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - flag_df=True, - fps=fps_embeds, - attention_kwargs=attention_kwargs, + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + update_mask_i = step_update_mask[i] + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[:, :, idx, :, :] = sample_schedulers[idx].step( + noise_pred[:, :, idx - valid_interval_start, :, :], + t[idx], + latents[:, :, idx, :, :], return_dict=False, + generator=generator, )[0] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - - update_mask_i = step_update_mask[i] - for idx in range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[:, :, idx, :, :] = sample_schedulers[idx].step( - noise_pred[:, :, idx - valid_interval_start, :, :], - t[idx], - latents[:, :, idx, :, :], - return_dict=False, - generator=generator, - )[0] - sample_schedulers_counter[idx] += 1 + sample_schedulers_counter[idx] += 1 - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - except Exception as e: - print(f"Error at iteration {i}, timestep value (can be a tensor for multiple frames): {self._current_timestep}") - print(f"Exception: {e}") - import traceback - traceback.print_exc() - raise + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided if i == len(step_matrix) - 1 or ( @@ -805,88 +798,82 @@ def __call__( num_warmup_steps = len(step_matrix) - num_inference_steps * 
self.scheduler.order self._num_timesteps = len(step_matrix) - with self.progress_bar(total=len(step_matrix)) as progress_bar: + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(step_matrix): if self.interrupt: continue - try: - self._current_timestep = t - valid_interval_start, valid_interval_end = valid_interval[i] - latent_model_input = ( - latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + self._current_timestep = t + valid_interval_start, valid_interval_end = valid_interval[i] + latent_model_input = ( + latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + ) + timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() + + if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length: + noise_factor = 0.001 * addnoise_condition + latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] = ( + latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] + * (1.0 - noise_factor) + + torch.randn_like(latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :]) + * noise_factor ) - timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() - - if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length: - noise_factor = 0.001 * addnoise_condition - latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] = ( - latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] - * (1.0 - noise_factor) - + torch.randn_like(latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :]) - * noise_factor - ) - timestep[:, valid_interval_start:prefix_video_latent_length] = addnoise_condition - - noise_pred = self.transformer( + timestep[:, valid_interval_start:prefix_video_latent_length] = addnoise_condition + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + flag_df=True, + fps=fps_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] - if self.do_classifier_free_guidance: - noise_uncond = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - flag_df=True, - fps=fps_embeds, - attention_kwargs=attention_kwargs, + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + update_mask_i = step_update_mask[i] + for idx in range(valid_interval_start, valid_interval_end): + if update_mask_i[idx].item(): + latents[:, :, idx, :, :] = sample_schedulers[idx].step( + noise_pred[:, :, idx - valid_interval_start, :, :], + t[idx], + latents[:, :, idx, :, :], return_dict=False, + generator=generator, )[0] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - - update_mask_i = step_update_mask[i] - for idx in range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[:, :, idx, :, :] = sample_schedulers[idx].step( - noise_pred[:, :, idx - valid_interval_start, :, :], - t[idx], - latents[:, :, idx, :, :], - return_dict=False, - 
generator=generator, - )[0] - sample_schedulers_counter[idx] += 1 - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop( - "negative_prompt_embeds", negative_prompt_embeds - ) - - # call the callback, if provided - if i == len(step_matrix) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - except Exception as e: - print(f"Error at iteration {i}, timestep value (can be a tensor for multiple frames): {self._current_timestep}") - print(f"Exception: {e}") - import traceback - traceback.print_exc() - raise + sample_schedulers_counter[idx] += 1 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) + + # call the callback, if provided + if i == len(step_matrix) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + #latents = latents.unsqueeze(0) if not output_type == "latent": From 8af4a9f7975261538354fc383ee089ed648bf469 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 23 May 2025 22:28:45 +0300 Subject: [PATCH 099/264] Refactor variable naming and tensor handling in `SkyReelsV2DiffusionForcingPipeline` to improve clarity and maintainability. Updated prefix video latent length variables for consistency and corrected tensor slicing to ensure proper dimensions during processing. --- .../pipeline_skyreels_v2_diffusion_forcing.py | 60 +++++++++---------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index fbcfd6dc3963..e4d1e745aaba 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -614,8 +614,7 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) timesteps = self.scheduler.timesteps - prefix_video = None - prefix_video_latent_length = 0 + prefix_video_latents_length = 0 num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 base_num_frames = ( (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 @@ -638,7 +637,7 @@ def __call__( sample_schedulers.append(sample_scheduler) sample_schedulers_counter = [0] * num_latent_frames step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - num_latent_frames, timesteps, base_num_frames, ar_step, prefix_video_latent_length, causal_block_size + num_latent_frames, timesteps, base_num_frames, ar_step, prefix_video_latents_length, causal_block_size ) # 5. 
Prepare latent variables @@ -671,15 +670,15 @@ def __call__( ) timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() - if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length: + if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_length: noise_factor = 0.001 * addnoise_condition - latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] = ( - latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] + latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] = ( + latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] * (1.0 - noise_factor) - + torch.randn_like(latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :]) + + torch.randn_like(latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :]) * noise_factor ) - timestep[:, valid_interval_start:prefix_video_latent_length] = addnoise_condition + timestep[:, valid_interval_start:prefix_video_latents_length] = addnoise_condition noise_pred = self.transformer( hidden_states=latent_model_input, @@ -736,26 +735,29 @@ def __call__( #latents = latents.unsqueeze(0) else: # Long video generation - overlap_history_frames = (overlap_history - 1) // 4 + 1 + overlap_history_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1 n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 video = None + prefix_video_latents = None latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(device, transformer_dtype) + .to(device, self.vae.dtype) ) latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - device, transformer_dtype + device, self.vae.dtype ) for i in range(n_iter): if video is not None: - prefix_video = video[:, -overlap_history:].to(prompt_embeds.device) - prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] - if prefix_video[0].shape[1] % causal_block_size != 0: - truncate_len = prefix_video[0].shape[1] % causal_block_size - logger.warning("The length of prefix video is truncated for the causal block size alignment.") - prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] - prefix_video_latent_length = prefix_video[0].shape[1] + prefix_latents_dist = self.vae.encode(video[:, :, -overlap_history:]).latent_dist + prefix_video_latents = (prefix_latents_dist.mode() - latents_mean) * latents_std + + if prefix_video_latents.shape[2] % causal_block_size != 0: + truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size + logger.warning("The length of prefix video latents is truncated for the causal block size alignment.") + prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] + prefix_video_latents_length = prefix_video_latents.shape[2] + finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames left_frame_num = num_latent_frames - finished_frame_num base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) @@ -774,7 +776,7 @@ def __call__( timesteps, base_num_frames_iter, ar_step, - prefix_video_latent_length, + prefix_video_latents_length, causal_block_size, ) @@ -791,8 +793,8 @@ def __call__( generator, latents, ) - if prefix_video is not None: - latents[:, :prefix_video_latent_length] = 
prefix_video[0].to(transformer_dtype) + if prefix_video_latents is not None: + latents[:, :, :prefix_video_latents_length, :, :] = prefix_video_latents.to(transformer_dtype) # 6. Denoising loop num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order @@ -810,15 +812,15 @@ def __call__( ) timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() - if addnoise_condition > 0 and valid_interval_start < prefix_video_latent_length: + if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_length: noise_factor = 0.001 * addnoise_condition - latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] = ( - latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :] + latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] = ( + latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] * (1.0 - noise_factor) - + torch.randn_like(latent_model_input[:, :, valid_interval_start:prefix_video_latent_length, :, :]) + + torch.randn_like(latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :]) * noise_factor ) - timestep[:, valid_interval_start:prefix_video_latent_length] = addnoise_condition + timestep[:, valid_interval_start:prefix_video_latents_length] = addnoise_condition noise_pred = self.transformer( hidden_states=latent_model_input, @@ -874,16 +876,14 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - - #latents = latents.unsqueeze(0) if not output_type == "latent": latents = latents.to(self.vae.dtype) latents = latents / latents_std + latents_mean videos = self.vae.decode(latents, return_dict=False)[0] if video is None: - video = videos # c, f, h, w + video = videos else: - video = torch.cat([video, videos[:, overlap_history:]], 1) # c, f, h, w + video = torch.cat([video, videos[:, :, overlap_history:]], dim=2) else: video = latents From 57a2bf9d53b348e4f50b2f5ef479bc6f8b6b78a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 23 May 2025 22:31:52 +0300 Subject: [PATCH 100/264] style --- scripts/convert_skyreelsv2_to_diffusers.py | 18 ++++++++++-------- src/diffusers/models/embeddings.py | 4 ++-- .../pipeline_skyreels_v2_diffusion_forcing.py | 14 +++++++++----- .../scheduling_flow_match_unipc_multistep.py | 2 +- 4 files changed, 22 insertions(+), 16 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index b4f939aabed6..aefdf485fee4 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -4,16 +4,15 @@ import torch from accelerate import init_empty_weights -from huggingface_hub import hf_hub_download, snapshot_download +from huggingface_hub import hf_hub_download from safetensors.torch import load_file -from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel +from transformers import AutoProcessor, CLIPVisionModelWithProjection, UMT5EncoderModel from diffusers import ( AutoencoderKLWan, FlowMatchUniPCMultistepScheduler, SkyReelsV2Transformer3DModel, ) - from diffusers.pipelines import SkyReelsV2DiffusionForcingPipeline from diffusers.utils.dummy_torch_and_transformers_objects import SkyreelsV2ImageToVideoPipeline @@ -68,7 +67,7 @@ def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) - def load_sharded_safetensors(dir: pathlib.Path): - #file_paths = list(dir.glob("model*.safetensors")) + # file_paths = 
list(dir.glob("model*.safetensors")) state_dict = {} state_dict.update(load_file(dir)) return state_dict @@ -461,7 +460,10 @@ def get_args(): scheduler=None, ) # pipe.push_to_hub - pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", - push_to_hub=True, - repo_id="tolgacangoz/SkyReels-V2-DF-1.3B-540P-Diffusers-2", - ) + pipe.save_pretrained( + args.output_path, + safe_serialization=True, + max_shard_size="5GB", + push_to_hub=True, + repo_id="tolgacangoz/SkyReels-V2-DF-1.3B-540P-Diffusers-2", + ) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 559a9493f9e7..39ae3ebe5d37 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -37,8 +37,8 @@ def get_timestep_embedding( Args timesteps (torch.Tensor): - a 1-D Tensor of N indices, one per batch element. These may be fractional. - Can also be a 2-D Tensor of shape (batch_size, num_frames). + a 1-D Tensor of N indices, one per batch element. These may be fractional. Can also be a 2-D Tensor of + shape (batch_size, num_frames). embedding_dim (int): the dimension of the output. flip_sin_to_cos (bool): diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index e4d1e745aaba..ebdd753e6393 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -15,7 +15,6 @@ import html import math import re -from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Union import ftfy @@ -675,7 +674,9 @@ def __call__( latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] = ( latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] * (1.0 - noise_factor) - + torch.randn_like(latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :]) + + torch.randn_like( + latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] + ) * noise_factor ) timestep[:, valid_interval_start:prefix_video_latents_length] = addnoise_condition @@ -732,7 +733,6 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - #latents = latents.unsqueeze(0) else: # Long video generation overlap_history_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1 @@ -754,7 +754,9 @@ def __call__( if prefix_video_latents.shape[2] % causal_block_size != 0: truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size - logger.warning("The length of prefix video latents is truncated for the causal block size alignment.") + logger.warning( + "The length of prefix video latents is truncated for the causal block size alignment." 
+ ) prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] prefix_video_latents_length = prefix_video_latents.shape[2] @@ -817,7 +819,9 @@ def __call__( latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] = ( latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] * (1.0 - noise_factor) - + torch.randn_like(latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :]) + + torch.randn_like( + latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] + ) * noise_factor ) timestep[:, valid_interval_start:prefix_video_latents_length] = addnoise_condition diff --git a/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py b/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py index ff0f941c8acf..4d9198dc3456 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py @@ -232,7 +232,7 @@ def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: pixels from saturation at each step. We find that dynamic thresholding results in significantly better photorealism as well as better image-text alignment, especially when using very large guidance weights." - https://arxiv.org/abs/2205.11487 + https://huggingface.co/papers/2205.11487 """ dtype = sample.dtype batch_size, channels, *remaining_dims = sample.shape From ae6adbe177528fb8927b7c0bf7fe0f8298244670 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 24 May 2025 16:07:36 +0300 Subject: [PATCH 101/264] fix number of frames for long video generation --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index ebdd753e6393..c643cd28e3e3 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -789,7 +789,7 @@ def __call__( num_channels_latents, height, width, - num_frames, + (base_num_frames_iter - 1) * self.vae_scale_factor_temporal + 1, torch.float32, device, generator, From 9afb214ab950b675d3e55b93aa57437a46cff02a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 24 May 2025 17:03:10 +0300 Subject: [PATCH 102/264] up --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index c643cd28e3e3..1f7948a40089 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -755,7 +755,7 @@ def __call__( if prefix_video_latents.shape[2] % causal_block_size != 0: truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size logger.warning( - "The length of prefix video latents is truncated for the causal block size alignment." + f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment." 
) prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] prefix_video_latents_length = prefix_video_latents.shape[2] From f1483ad88df9b857c6a497ecee6034be6af825f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 24 May 2025 17:04:38 +0300 Subject: [PATCH 103/264] fix: `latents` initialization for long video generation in processing consecutive frames --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 1f7948a40089..8569197ff9d8 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -793,7 +793,7 @@ def __call__( torch.float32, device, generator, - latents, + None if i > 0 else latents, ) if prefix_video_latents is not None: latents[:, :, :prefix_video_latents_length, :, :] = prefix_video_latents.to(transformer_dtype) From a16c31b867da51c2f481205d11005e969af94e57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 24 May 2025 17:33:53 +0300 Subject: [PATCH 104/264] update templates --- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 286 +++++++++--------- ...eline_skyreels_v2_diffusion_forcing_v2v.py | 279 ++++++++--------- 2 files changed, 283 insertions(+), 282 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 5e96e4cb3264..603c9b10eda6 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -15,11 +15,9 @@ import html import math import re -from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Union import ftfy -import numpy as np import torch from transformers import AutoTokenizer, UMT5EncoderModel @@ -346,26 +344,26 @@ def prepare_latents( def generate_timestep_matrix( self, - num_frames, - step_template, - base_num_frames, - ar_step=5, - num_pre_ready=0, - casual_block_size=1, - shrink_interval_with_mask=False, + num_frames: int, + step_template: torch.Tensor, + base_num_frames: int, + ar_step: int = 5, + num_pre_ready: int = 0, + causal_block_size: int = 1, + shrink_interval_with_mask: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: step_matrix, step_index = [], [] update_mask, valid_interval = [], [] num_iterations = len(step_template) + 1 - num_frames_block = num_frames // casual_block_size - base_num_frames_block = base_num_frames // casual_block_size + num_frames_block = num_frames // causal_block_size + base_num_frames_block = base_num_frames // causal_block_size if base_num_frames_block < num_frames_block: infer_step_num = len(step_template) gen_block = base_num_frames_block min_ar_step = infer_step_num / gen_block if ar_step < min_ar_step: raise ValueError(f"ar_step should be at least {math.ceil(min_ar_step)} in your setting") - # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block) + # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, causal_block_size, num_frames_block, base_num_frames_block) step_template = torch.cat( 
[ torch.tensor([999], dtype=torch.int64, device=step_template.device), @@ -375,7 +373,7 @@ def generate_timestep_matrix( ) # to handle the counter in row works starting from 1 pre_row = torch.zeros(num_frames_block, dtype=torch.long) if num_pre_ready > 0: - pre_row[: num_pre_ready // casual_block_size] = num_iterations + pre_row[: num_pre_ready // causal_block_size] = num_iterations while not torch.all(pre_row >= (num_iterations - 1)): new_row = torch.zeros(num_frames_block, dtype=torch.long) @@ -403,7 +401,7 @@ def generate_timestep_matrix( update_mask_idx = idx_sequence[update_mask] last_update_idx = update_mask_idx[-1].item() terminal_flag = last_update_idx + 1 - # for i in range(0, len(update_mask)): + for curr_mask in update_mask: if terminal_flag < num_frames_block and curr_mask[terminal_flag]: terminal_flag += 1 @@ -413,11 +411,11 @@ def generate_timestep_matrix( step_index = torch.stack(step_index, dim=0) step_matrix = torch.stack(step_matrix, dim=0) - if casual_block_size > 1: - step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval] + if causal_block_size > 1: + step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + step_index = step_index.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + valid_interval = [(s * causal_block_size, e * causal_block_size) for s, e in valid_interval] return step_matrix, step_index, step_update_mask, valid_interval @@ -473,7 +471,7 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, overlap_history: Optional[int] = 17, - shift: float = 1.0, # TODO: check this + shift: float = 8.0, addnoise_condition: float = 20.0, base_num_frames: int = 97, ar_step: int = 5, @@ -540,10 +538,10 @@ def __call__( `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int`, *optional*, defaults to `512`): The maximum sequence length of the prompt. 
- shift (`float`, *optional*, defaults to `1.0`): + shift (`float`, *optional*, defaults to `8.0`): overlap_history (`int`, *optional*, defaults to `17`): Number of frames to overlap for smooth transitions in long videos - addnoise_condition (`float`, *optional*, defaults to `20`): + addnoise_condition (`float`, *optional*, defaults to `20.0`): Improves consistency in long video generation base_num_frames (`int`, *optional*, defaults to `97`): 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**) @@ -553,7 +551,9 @@ def __call__( Recommended when using asynchronous inference (--ar_step > 0) fps (`int`, *optional*, defaults to `24`): - Examples: Returns: + Examples: + + Returns: [`~SkyReelsV2PipelineOutput`] or `tuple`: If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s @@ -616,42 +616,34 @@ def __call__( if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) - if image_embeds is None: - if last_image is None: - image_embeds = self.encode_image(image, device) - else: - image_embeds = self.encode_image([image, last_image], device) - image_embeds = image_embeds.repeat(batch_size, 1, 1) - image_embeds = image_embeds.to(transformer_dtype) - - prompt_embeds = prompt_embeds.to(transformer_dtype) - if negative_prompt_embeds is not None: - negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps - prefix_video = None - predix_video_latent_length = 0 + prefix_video_latents_length = 0 + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + base_num_frames = ( + (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 + if base_num_frames is not None + else num_latent_frames + ) if causal_block_size is None: - causal_block_size = self.transformer.num_frame_per_block + causal_block_size = self.transformer.config.num_frame_per_block fps_embeds = [fps] * prompt_embeds.shape[0] fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: # Short video generation - # 4. 
Prepare sample schedulers and timestep matrix - self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) - timesteps = self.scheduler.timesteps - latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - sample_schedulers = [self.scheduler] - for _ in range(latent_length - 1): - sample_scheduler = deepcopy(self.scheduler) - sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + sample_schedulers = [] + for _ in range(num_latent_frames): + sample_scheduler = FlowMatchUniPCMultistepScheduler.from_config(self.scheduler.config) + sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * latent_length - base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length + sample_schedulers_counter = [0] * num_latent_frames step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - latent_length, timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size + num_latent_frames, timesteps, base_num_frames, ar_step, prefix_video_latents_length, causal_block_size ) # 5. Prepare latent variables @@ -676,37 +668,38 @@ def __call__( ) # 6. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - self._num_timesteps = len(timesteps) + num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(step_matrix) with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, timestep_i in enumerate(step_matrix): + for i, t in enumerate(step_matrix): if self.interrupt: continue - self._current_timestep = timestep_i - - update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + self._current_timestep = t + valid_interval_start, valid_interval_end = valid_interval[i] latent_model_input = ( - latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() ) - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() + + if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_length: noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition - latent_model_input[:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[:, valid_interval_start:predix_video_latent_length] + latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] = ( + latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] * (1.0 - noise_factor) - + torch.randn_like(latent_model_input[:, valid_interval_start:predix_video_latent_length]) + + torch.randn_like( + latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] + ) * noise_factor ) - timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + timestep[:, valid_interval_start:prefix_video_latents_length] = addnoise_condition + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, + flag_df=True, 
fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, @@ -716,18 +709,20 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, + flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + update_mask_i = step_update_mask[i] for idx in range(valid_interval_start, valid_interval_end): if update_mask_i[idx].item(): - latents[:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], - timestep_i[idx], - latents[:, idx], + latents[:, :, idx, :, :] = sample_schedulers[idx].step( + noise_pred[:, :, idx - valid_interval_start, :, :], + t[idx], + latents[:, :, idx, :, :], return_dict=False, generator=generator, )[0] @@ -737,7 +732,7 @@ def __call__( callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, timestep_i, callback_kwargs) + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) @@ -752,45 +747,44 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - x0 = latents.unsqueeze(0) - videos = self.vae.decode(x0) - videos = [(videos / 2 + 0.5).clamp(0, 1)] - videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] - video = [video.cpu().numpy().astype(np.uint8) for video in videos] else: # Long video generation - base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length - overlap_history_frames = (overlap_history - 1) // 4 + 1 - n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 - output_video = None + overlap_history_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1 + n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 + video = None + prefix_video_latents = None + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(device, self.vae.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, self.vae.dtype + ) for i in range(n_iter): - if output_video is not None: # i !=0 - prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device) - prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] - if prefix_video[0].shape[1] % causal_block_size != 0: - truncate_len = prefix_video[0].shape[1] % causal_block_size - print("the length of prefix video is truncated for the casual block size alignment.") - prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] - predix_video_latent_length = prefix_video[0].shape[1] + if video is not None: + prefix_latents_dist = self.vae.encode(video[:, :, -overlap_history:]).latent_dist + prefix_video_latents = (prefix_latents_dist.mode() - latents_mean) * latents_std + + if prefix_video_latents.shape[2] % causal_block_size != 0: + truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size + logger.warning( + f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment." 
+ ) + prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] + prefix_video_latents_length = prefix_video_latents.shape[2] + finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames - left_frame_num = latent_length - finished_frame_num + left_frame_num = num_latent_frames - finished_frame_num base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) - if ar_step > 0: - num_steps = ( - num_inference_steps - + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step - ) - self.transformer.num_steps = num_steps - else: # i == 0 + else: base_num_frames_iter = base_num_frames # 4. Prepare sample schedulers and timestep matrix - self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) - timesteps = self.scheduler.timesteps - sample_schedulers = [self.scheduler] - for _ in range(base_num_frames_iter - 1): - sample_scheduler = deepcopy(self.scheduler) - sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + sample_schedulers = [] + for _ in range(base_num_frames_iter): + sample_scheduler = FlowMatchUniPCMultistepScheduler.from_config(self.scheduler.config) + sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) sample_schedulers_counter = [0] * base_num_frames_iter step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( @@ -798,7 +792,7 @@ def __call__( timesteps, base_num_frames_iter, ar_step, - predix_video_latent_length, + prefix_video_latents_length, causal_block_size, ) @@ -809,51 +803,48 @@ def __call__( num_channels_latents, height, width, - num_frames, + (base_num_frames_iter - 1) * self.vae_scale_factor_temporal + 1, torch.float32, device, generator, - latents, + None if i > 0 else latents, ) - if prefix_video is not None: - latents[:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) + if prefix_video_latents is not None: + latents[:, :, :prefix_video_latents_length, :, :] = prefix_video_latents.to(transformer_dtype) # 6. 
Denoising loop num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order self._num_timesteps = len(step_matrix) with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, timestep_i in enumerate(step_matrix): + for i, t in enumerate(step_matrix): if self.interrupt: continue - self._current_timestep = timestep_i - update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + self._current_timestep = t + valid_interval_start, valid_interval_end = valid_interval[i] latent_model_input = ( - latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() ) - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() + + if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_length: noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition - latent_model_input[:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[:, valid_interval_start:predix_video_latent_length] + latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] = ( + latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] * (1.0 - noise_factor) + torch.randn_like( - latent_model_input[:, valid_interval_start:predix_video_latent_length] + latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] ) * noise_factor ) - timestep[:, valid_interval_start:predix_video_latent_length] = ( - timestep_for_noised_condition - ) + timestep[:, valid_interval_start:prefix_video_latents_length] = addnoise_condition noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, + flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, @@ -863,17 +854,20 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, + flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + update_mask_i = step_update_mask[i] for idx in range(valid_interval_start, valid_interval_end): if update_mask_i[idx].item(): - latents[:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], - timestep_i[idx], - latents[:, idx], + latents[:, :, idx, :, :] = sample_schedulers[idx].step( + noise_pred[:, :, idx - valid_interval_start, :, :], + t[idx], + latents[:, :, idx, :, :], return_dict=False, generator=generator, )[0] @@ -883,7 +877,7 @@ def __call__( callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, timestep_i, callback_kwargs) + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) @@ -900,32 +894,32 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - x0 = latents.unsqueeze(0) - videos = [self.vae.decode(x0)[0]] - if output_video is None: - output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w + if not 
output_type == "latent": + latents = latents.to(self.vae.dtype) + latents = latents / latents_std + latents_mean + videos = self.vae.decode(latents, return_dict=False)[0] + if video is None: + video = videos else: - output_video = torch.cat( - [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1 - ) # c, f, h, w - output_video = [(output_video / 2 + 0.5).clamp(0, 1)] - output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video] - video = [video.cpu().numpy().astype(np.uint8) for video in output_video] + video = torch.cat([video, videos[:, :, overlap_history:]], dim=2) + else: + video = latents self._current_timestep = None if not output_type == "latent": - latents = latents.to(self.vae.dtype) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype - ) - latents = latents / latents_std + latents_mean - video = self.vae.decode(latents, return_dict=False)[0] + if overlap_history is None: + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) else: video = latents diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index 9eea5e24b13b..2b3aef021b62 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -15,11 +15,9 @@ import html import math import re -from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Union import ftfy -import numpy as np import torch from PIL import Image from transformers import AutoTokenizer, UMT5EncoderModel @@ -346,26 +344,26 @@ def prepare_latents( def generate_timestep_matrix( self, - num_frames, - step_template, - base_num_frames, - ar_step=5, - num_pre_ready=0, - casual_block_size=1, - shrink_interval_with_mask=False, + num_frames: int, + step_template: torch.Tensor, + base_num_frames: int, + ar_step: int = 5, + num_pre_ready: int = 0, + causal_block_size: int = 1, + shrink_interval_with_mask: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: step_matrix, step_index = [], [] update_mask, valid_interval = [], [] num_iterations = len(step_template) + 1 - num_frames_block = num_frames // casual_block_size - base_num_frames_block = base_num_frames // casual_block_size + num_frames_block = num_frames // causal_block_size + base_num_frames_block = base_num_frames // causal_block_size if base_num_frames_block < num_frames_block: infer_step_num = len(step_template) gen_block = base_num_frames_block min_ar_step = infer_step_num / gen_block if ar_step < min_ar_step: raise ValueError(f"ar_step should be at least {math.ceil(min_ar_step)} in your setting") - # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, 
num_frames_block, base_num_frames_block) + # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, causal_block_size, num_frames_block, base_num_frames_block) step_template = torch.cat( [ torch.tensor([999], dtype=torch.int64, device=step_template.device), @@ -375,7 +373,7 @@ def generate_timestep_matrix( ) # to handle the counter in row works starting from 1 pre_row = torch.zeros(num_frames_block, dtype=torch.long) if num_pre_ready > 0: - pre_row[: num_pre_ready // casual_block_size] = num_iterations + pre_row[: num_pre_ready // causal_block_size] = num_iterations while not torch.all(pre_row >= (num_iterations - 1)): new_row = torch.zeros(num_frames_block, dtype=torch.long) @@ -403,7 +401,7 @@ def generate_timestep_matrix( update_mask_idx = idx_sequence[update_mask] last_update_idx = update_mask_idx[-1].item() terminal_flag = last_update_idx + 1 - # for i in range(0, len(update_mask)): + for curr_mask in update_mask: if terminal_flag < num_frames_block and curr_mask[terminal_flag]: terminal_flag += 1 @@ -413,11 +411,11 @@ def generate_timestep_matrix( step_index = torch.stack(step_index, dim=0) step_matrix = torch.stack(step_matrix, dim=0) - if casual_block_size > 1: - step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval] + if causal_block_size > 1: + step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + step_index = step_index.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous() + valid_interval = [(s * causal_block_size, e * causal_block_size) for s, e in valid_interval] return step_matrix, step_index, step_update_mask, valid_interval @@ -471,7 +469,7 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, overlap_history: Optional[int] = 17, - shift: float = 1.0, # TODO: check this + shift: float = 8.0, addnoise_condition: float = 20.0, base_num_frames: int = 97, ar_step: int = 5, @@ -538,10 +536,10 @@ def __call__( `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int`, *optional*, defaults to `512`): The maximum sequence length of the prompt. 
- shift (`float`, *optional*, defaults to `1.0`): + shift (`float`, *optional*, defaults to `8.0`): overlap_history (`int`, *optional*, defaults to `17`): Number of frames to overlap for smooth transitions in long videos - addnoise_condition (`float`, *optional*, defaults to `20`): + addnoise_condition (`float`, *optional*, defaults to `20.0`): Improves consistency in long video generation base_num_frames (`int`, *optional*, defaults to `97`): 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**) @@ -551,12 +549,15 @@ def __call__( Recommended when using asynchronous inference (--ar_step > 0) fps (`int`, *optional*, defaults to `24`): - Examples: Returns: + Examples: + + Returns: [`~SkyReelsV2PipelineOutput`] or `tuple`: If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -616,11 +617,20 @@ def __call__( if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) - prefix_video = None - predix_video_latent_length = 0 + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + + prefix_video_latents_length = 0 + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + base_num_frames = ( + (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 + if base_num_frames is not None + else num_latent_frames + ) if causal_block_size is None: - causal_block_size = self.transformer.num_frame_per_block + causal_block_size = self.transformer.config.num_frame_per_block fps_embeds = [fps] * prompt_embeds.shape[0] fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] @@ -631,21 +641,15 @@ def __call__( if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: # Short video generation - # 4. Prepare sample schedulers and timestep matrix - self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) - timesteps = self.scheduler.timesteps - latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) - latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - sample_schedulers = [self.scheduler] - for _ in range(latent_length - 1): - sample_scheduler = deepcopy(self.scheduler) - sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + sample_schedulers = [] + for _ in range(num_latent_frames): + sample_scheduler = FlowMatchUniPCMultistepScheduler.from_config(self.scheduler.config) + sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * latent_length - base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length + sample_schedulers_counter = [0] * num_latent_frames step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - latent_length, timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size + num_latent_frames, timesteps, base_num_frames, ar_step, prefix_video_latents_length, causal_block_size ) # 5. 
Prepare latent variables @@ -656,6 +660,7 @@ def __call__( num_channels_latents, height, width, + num_frames, torch.float32, device, generator, @@ -664,37 +669,38 @@ def __call__( ) # 6. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - self._num_timesteps = len(timesteps) + num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(step_matrix) with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, timestep_i in enumerate(step_matrix): + for i, t in enumerate(step_matrix): if self.interrupt: continue - self._current_timestep = timestep_i - - update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + self._current_timestep = t + valid_interval_start, valid_interval_end = valid_interval[i] latent_model_input = ( - latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() ) - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() + + if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_length: noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition - latent_model_input[:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[:, valid_interval_start:predix_video_latent_length] + latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] = ( + latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] * (1.0 - noise_factor) - + torch.randn_like(latent_model_input[:, valid_interval_start:predix_video_latent_length]) + + torch.randn_like( + latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] + ) * noise_factor ) - timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition + timestep[:, valid_interval_start:prefix_video_latents_length] = addnoise_condition + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, + flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, @@ -704,18 +710,20 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, + flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + update_mask_i = step_update_mask[i] for idx in range(valid_interval_start, valid_interval_end): if update_mask_i[idx].item(): - latents[:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], - timestep_i[idx], - latents[:, idx], + latents[:, :, idx, :, :] = sample_schedulers[idx].step( + noise_pred[:, :, idx - valid_interval_start, :, :], + t[idx], + latents[:, :, idx, :, :], return_dict=False, generator=generator, )[0] @@ -725,7 +733,7 @@ def __call__( callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, timestep_i, callback_kwargs) + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = 
callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) @@ -740,45 +748,44 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - x0 = latents.unsqueeze(0) - videos = self.vae.decode(x0) - videos = [(videos / 2 + 0.5).clamp(0, 1)] - videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] - video = [video.cpu().numpy().astype(np.uint8) for video in videos] else: # Long video generation - base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length - overlap_history_frames = (overlap_history - 1) // 4 + 1 - n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 - output_video = None + overlap_history_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1 + n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 + video = None + prefix_video_latents = None + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(device, self.vae.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, self.vae.dtype + ) for i in range(n_iter): - if output_video is not None: # i !=0 - prefix_video = output_video[:, -overlap_history:].to(prompt_embeds.device) - prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] - if prefix_video[0].shape[1] % causal_block_size != 0: - truncate_len = prefix_video[0].shape[1] % causal_block_size - print("the length of prefix video is truncated for the casual block size alignment.") - prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] - predix_video_latent_length = prefix_video[0].shape[1] + if video is not None: + prefix_latents_dist = self.vae.encode(video[:, :, -overlap_history:]).latent_dist + prefix_video_latents = (prefix_latents_dist.mode() - latents_mean) * latents_std + + if prefix_video_latents.shape[2] % causal_block_size != 0: + truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size + logger.warning( + f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment." + ) + prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] + prefix_video_latents_length = prefix_video_latents.shape[2] + finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames - left_frame_num = latent_length - finished_frame_num + left_frame_num = num_latent_frames - finished_frame_num base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) - if ar_step > 0: - num_steps = ( - num_inference_steps - + ((base_num_frames_iter - overlap_history_frames) // causal_block_size - 1) * ar_step - ) - self.transformer.num_steps = num_steps - else: # i == 0 + else: base_num_frames_iter = base_num_frames # 4. 
Prepare sample schedulers and timestep matrix - self.scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) - timesteps = self.scheduler.timesteps - sample_schedulers = [self.scheduler] - for _ in range(base_num_frames_iter - 1): - sample_scheduler = deepcopy(self.scheduler) - sample_scheduler.set_timesteps(num_inference_steps, device=prompt_embeds.device, shift=shift) + sample_schedulers = [] + for _ in range(base_num_frames_iter): + sample_scheduler = FlowMatchUniPCMultistepScheduler.from_config(self.scheduler.config) + sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) sample_schedulers_counter = [0] * base_num_frames_iter step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( @@ -786,7 +793,7 @@ def __call__( timesteps, base_num_frames_iter, ar_step, - predix_video_latent_length, + prefix_video_latents_length, causal_block_size, ) @@ -797,51 +804,48 @@ def __call__( num_channels_latents, height, width, - num_frames, + (base_num_frames_iter - 1) * self.vae_scale_factor_temporal + 1, torch.float32, device, generator, - latents, + None if i > 0 else latents, ) - if prefix_video is not None: - latents[:, :predix_video_latent_length] = prefix_video[0].to(transformer_dtype) + if prefix_video_latents is not None: + latents[:, :, :prefix_video_latents_length, :, :] = prefix_video_latents.to(transformer_dtype) # 6. Denoising loop num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order self._num_timesteps = len(step_matrix) with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, timestep_i in enumerate(step_matrix): + for i, t in enumerate(step_matrix): if self.interrupt: continue - self._current_timestep = timestep_i - update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() + self._current_timestep = t + valid_interval_start, valid_interval_end = valid_interval[i] latent_model_input = ( - latents[:, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() + latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() ) - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: + timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() + + if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_length: noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition - latent_model_input[:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[:, valid_interval_start:predix_video_latent_length] + latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] = ( + latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] * (1.0 - noise_factor) + torch.randn_like( - latent_model_input[:, valid_interval_start:predix_video_latent_length] + latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] ) * noise_factor ) - timestep[:, valid_interval_start:predix_video_latent_length] = ( - timestep_for_noised_condition - ) + timestep[:, valid_interval_start:prefix_video_latents_length] = addnoise_condition noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, + 
flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, @@ -851,17 +855,20 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, + flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + update_mask_i = step_update_mask[i] for idx in range(valid_interval_start, valid_interval_end): if update_mask_i[idx].item(): - latents[:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], - timestep_i[idx], - latents[:, idx], + latents[:, :, idx, :, :] = sample_schedulers[idx].step( + noise_pred[:, :, idx - valid_interval_start, :, :], + t[idx], + latents[:, :, idx, :, :], return_dict=False, generator=generator, )[0] @@ -871,7 +878,7 @@ def __call__( callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, timestep_i, callback_kwargs) + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) @@ -888,32 +895,32 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - x0 = latents.unsqueeze(0) - videos = [self.vae.decode(x0)[0]] - if output_video is None: - output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents = latents / latents_std + latents_mean + videos = self.vae.decode(latents, return_dict=False)[0] + if video is None: + video = videos else: - output_video = torch.cat( - [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1 - ) # c, f, h, w - output_video = [(output_video / 2 + 0.5).clamp(0, 1)] - output_video = [video.permute(1, 2, 3, 0) * 255 for video in output_video] - video = [video.cpu().numpy().astype(np.uint8) for video in output_video] + video = torch.cat([video, videos[:, :, overlap_history:]], dim=2) + else: + video = latents self._current_timestep = None if not output_type == "latent": - latents = latents.to(self.vae.dtype) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype - ) - latents = latents / latents_std + latents_mean - video = self.vae.decode(latents, return_dict=False)[0] + if overlap_history is None: + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) else: video = latents From 3b7b63b11f052a646abe2fc767e8aac529871440 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 24 May 2025 18:19:51 +0300 Subject: [PATCH 105/264] Enhance `convert_skyreelsv2_to_diffusers.py` by adding support for loading sharded model files and updating model configuration. 
Refactor model loading logic to accommodate new model types and ensure proper initialization of components such as the tokenizer and scheduler. --- scripts/convert_skyreelsv2_to_diffusers.py | 54 +++++++++++++--------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index aefdf485fee4..a018e7dbc158 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -1,4 +1,5 @@ import argparse +import os import pathlib from typing import Any, Dict @@ -6,7 +7,7 @@ from accelerate import init_empty_weights from huggingface_hub import hf_hub_download from safetensors.torch import load_file -from transformers import AutoProcessor, CLIPVisionModelWithProjection, UMT5EncoderModel +from transformers import AutoProcessor, CLIPVisionModelWithProjection, UMT5EncoderModel, AutoTokenizer from diffusers import ( AutoencoderKLWan, @@ -67,9 +68,10 @@ def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) - def load_sharded_safetensors(dir: pathlib.Path): - # file_paths = list(dir.glob("model*.safetensors")) + file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors")) state_dict = {} - state_dict.update(load_file(dir)) + for path in file_paths: + state_dict.update(load_file(path)) return state_dict @@ -94,9 +96,9 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]: "text_dim": 4096, }, } - elif model_type == "SkyReelsV2-T2V-14B": + elif model_type == "SkyReels-V2-DF-14B-720P": config = { - "model_id": "StevenZhang/Wan2.1-T2V-14B-Diff", + "model_id": "Skywork/SkyReels-V2-DF-14B-720P", "diffusers_config": { "added_kv_proj_dim": None, "attention_head_dim": 128, @@ -106,6 +108,7 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]: "freq_dim": 256, "in_channels": 16, "num_attention_heads": 40, + "inject_sample_info": False, "num_layers": 40, "out_channels": 16, "patch_size": [1, 2, 2], @@ -182,9 +185,16 @@ def convert_transformer(model_type: str): config = get_transformer_config(model_type) diffusers_config = config["diffusers_config"] model_id = config["model_id"] - model_dir = hf_hub_download(model_id, "model.safetensors") - original_state_dict = load_sharded_safetensors(model_dir) + if model_type == "SkyReels-V2-DF-1.3B-540P": + original_state_dict = load_file(hf_hub_download(model_id, "model.safetensors")) + elif model_type == "SkyReels-V2-DF-14B-720P": + os.makedirs(model_type, exist_ok=True) + model_dir = pathlib.Path(model_type) + for i in range(1, 7): + shard_path = f"diffusion_pytorch_model-{i:05d}-of-00006.safetensors" + hf_hub_download(model_id, shard_path, local_dir=model_dir) + original_state_dict = load_sharded_safetensors(model_dir) with init_empty_weights(): transformer = SkyReelsV2Transformer3DModel.from_config(diffusers_config) @@ -429,13 +439,13 @@ def get_args(): transformer = None dtype = DTYPE_MAPPING[args.dtype] - #transformer = convert_transformer(args.model_type).to(dtype=dtype) - #vae = convert_vae() + transformer = convert_transformer(args.model_type).to(dtype=dtype) + vae = convert_vae() text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl") - #tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") - #scheduler = FlowMatchUniPCMultistepScheduler( - # prediction_type="flow_prediction", num_train_timesteps=1000, - #) + tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") + scheduler = FlowMatchUniPCMultistepScheduler( + prediction_type="flow_prediction", 
num_train_timesteps=1000, + ) if "I2V" in args.model_type: image_encoder = CLIPVisionModelWithProjection.from_pretrained( @@ -444,20 +454,20 @@ def get_args(): image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") pipe = SkyreelsV2ImageToVideoPipeline( transformer=transformer, - text_encoder=None, - tokenizer=None, - vae=None, - scheduler=None, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, image_encoder=image_encoder, image_processor=image_processor, ) else: pipe = SkyReelsV2DiffusionForcingPipeline( - transformer=None, + transformer=transformer, text_encoder=text_encoder, - tokenizer=None, - vae=None, - scheduler=None, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, ) # pipe.push_to_hub pipe.save_pretrained( @@ -465,5 +475,5 @@ def get_args(): safe_serialization=True, max_shard_size="5GB", push_to_hub=True, - repo_id="tolgacangoz/SkyReels-V2-DF-1.3B-540P-Diffusers-2", + repo_id=f"tolgacangoz/{args.model_type}-Diffusers", ) From 5e1126d0ea3ea5b89460db68c96ffaa37378dfce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 25 May 2025 01:48:00 +0300 Subject: [PATCH 106/264] Update model configuration in `convert_skyreelsv2_to_diffusers.py` to support new model type `SkyReelsV2-DF-14B-540P`. Adjusted parameters including `in_channels`, `added_kv_proj_dim`, and `inject_sample_info`. Refactored sharded model loading logic to accommodate varying shard counts based on model type. --- scripts/convert_skyreelsv2_to_diffusers.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index a018e7dbc158..619782178949 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -116,19 +116,19 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]: "text_dim": 4096, }, } - elif model_type == "SkyReelsV2-I2V-14B-480p": + elif model_type == "SkyReelsV2-DF-14B-540P": config = { - "model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff", + "model_id": "Skywork/SkyReelsV2-DF-14B-540P", "diffusers_config": { - "image_dim": 1280, - "added_kv_proj_dim": 5120, + "added_kv_proj_dim": None, "attention_head_dim": 128, "cross_attn_norm": True, "eps": 1e-06, "ffn_dim": 13824, "freq_dim": 256, - "in_channels": 36, + "in_channels": 16, "num_attention_heads": 40, + "inject_sample_info": False, "num_layers": 40, "out_channels": 16, "patch_size": [1, 2, 2], @@ -188,11 +188,12 @@ def convert_transformer(model_type: str): if model_type == "SkyReels-V2-DF-1.3B-540P": original_state_dict = load_file(hf_hub_download(model_id, "model.safetensors")) - elif model_type == "SkyReels-V2-DF-14B-720P": + elif model_type in ["SkyReels-V2-DF-14B-720P", "SkyReelsV2-DF-14B-540P"]: os.makedirs(model_type, exist_ok=True) model_dir = pathlib.Path(model_type) - for i in range(1, 7): - shard_path = f"diffusion_pytorch_model-{i:05d}-of-00006.safetensors" + top_shard = 6 if model_type == "SkyReels-V2-DF-14B-720P" else 12 + for i in range(1, top_shard + 1): + shard_path = f"diffusion_pytorch_model-{i:05d}-of-000{top_shard}.safetensors" hf_hub_download(model_id, shard_path, local_dir=model_dir) original_state_dict = load_sharded_safetensors(model_dir) @@ -469,7 +470,7 @@ def get_args(): vae=vae, scheduler=scheduler, ) - # pipe.push_to_hub + pipe.save_pretrained( args.output_path, safe_serialization=True, From 820d415173255148f05225fd6d01eb396eae2491 Mon Sep 17 00:00:00 2001 
From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 25 May 2025 01:48:26 +0300 Subject: [PATCH 107/264] Refactor `set_ar_attention` method in `SkyReelsV2Transformer3DModel` to use `register_to_config` for setting configuration parameters. This change improves clarity and maintains consistency in model configuration handling. --- src/diffusers/models/transformers/transformer_skyreels_v2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index e26785af4d0f..e908dffcd281 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -566,8 +566,7 @@ def forward( return Transformer2DModelOutput(sample=output) def set_ar_attention(self, causal_block_size): - self.config.num_frame_per_block = causal_block_size - self.config.flag_causal_attention = True + self.register_to_config(num_frame_per_block=causal_block_size, flag_causal_attention=True) for block in self.blocks: block.set_ar_attention() From 528e0d79ba0b4352d9b2482303b75cb392fd4a92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 25 May 2025 10:35:53 +0300 Subject: [PATCH 108/264] up --- scripts/convert_skyreelsv2_to_diffusers.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 619782178949..0b957d787480 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -191,9 +191,15 @@ def convert_transformer(model_type: str): elif model_type in ["SkyReels-V2-DF-14B-720P", "SkyReelsV2-DF-14B-540P"]: os.makedirs(model_type, exist_ok=True) model_dir = pathlib.Path(model_type) - top_shard = 6 if model_type == "SkyReels-V2-DF-14B-720P" else 12 + if model_type == "SkyReels-V2-DF-14B-720P": + top_shard = 6 + model_name = "diffusion_pytorch_model" + elif model_type == "SkyReelsV2-DF-14B-540P": + top_shard = 12 + model_name = "model" + for i in range(1, top_shard + 1): - shard_path = f"diffusion_pytorch_model-{i:05d}-of-000{top_shard}.safetensors" + shard_path = f"{model_name}-{i:05d}-of-000{top_shard}.safetensors" hf_hub_download(model_id, shard_path, local_dir=model_dir) original_state_dict = load_sharded_safetensors(model_dir) From 6c4301c113da5c6e56f377af188886b09ef846be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 25 May 2025 10:39:50 +0300 Subject: [PATCH 109/264] up --- scripts/convert_skyreelsv2_to_diffusers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 0b957d787480..0047b79e4c31 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -116,9 +116,9 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]: "text_dim": 4096, }, } - elif model_type == "SkyReelsV2-DF-14B-540P": + elif model_type == "SkyReels-V2-DF-14B-540P": config = { - "model_id": "Skywork/SkyReelsV2-DF-14B-540P", + "model_id": "Skywork/SkyReels-V2-DF-14B-540P", "diffusers_config": { "added_kv_proj_dim": None, "attention_head_dim": 128, From 7d5328fe7fb33680288a9054595864b449c35bfa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 25 May 2025 10:41:35 +0300 Subject: [PATCH 110/264] upp --- scripts/convert_skyreelsv2_to_diffusers.py | 6 
+++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 0047b79e4c31..900577b9c121 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -136,7 +136,7 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]: "text_dim": 4096, }, } - elif model_type == "SkyReelsV2-I2V-14B-720p": + elif model_type == "SkyReels-V2-I2V-14B-720p": config = { "model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff", "diffusers_config": { @@ -188,13 +188,13 @@ def convert_transformer(model_type: str): if model_type == "SkyReels-V2-DF-1.3B-540P": original_state_dict = load_file(hf_hub_download(model_id, "model.safetensors")) - elif model_type in ["SkyReels-V2-DF-14B-720P", "SkyReelsV2-DF-14B-540P"]: + elif model_type in ["SkyReels-V2-DF-14B-720P", "SkyReels-V2-DF-14B-540P"]: os.makedirs(model_type, exist_ok=True) model_dir = pathlib.Path(model_type) if model_type == "SkyReels-V2-DF-14B-720P": top_shard = 6 model_name = "diffusion_pytorch_model" - elif model_type == "SkyReelsV2-DF-14B-540P": + elif model_type == "SkyReels-V2-DF-14B-540P": top_shard = 12 model_name = "model" From 00849fd4218efdde8e1b4d2445694e4941547ff2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 25 May 2025 10:52:47 +0300 Subject: [PATCH 111/264] fix file name --- scripts/convert_skyreelsv2_to_diffusers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 900577b9c121..f1f8ee1ab472 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -68,7 +68,10 @@ def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) - def load_sharded_safetensors(dir: pathlib.Path): - file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors")) + if "SkyReels-V2-DF-14B-720P" in str(dir): + file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors")) + else: + file_paths = list(dir.glob("model*.safetensors")) state_dict = {} for path in file_paths: state_dict.update(load_file(path)) From 8e34d895bc9ae8aeb885c3f5022f0b5c25d97362 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 25 May 2025 14:13:21 +0300 Subject: [PATCH 112/264] Update `SkyReelsV2Transformer3DModel` to conditionally apply `causal_mask` based on configuration flag. This change enhances flexibility in model behavior during training and inference. 
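Note on the change above: the block-causal mask is now only forwarded to the attention blocks when `flag_causal_attention` is enabled; otherwise the blocks receive `None` and fall back to full bidirectional attention. A minimal, self-contained sketch of that gating pattern (tensor shapes and the helper name are illustrative only, not the model's actual API):

```python
import torch
import torch.nn.functional as F


def attend(query, key, value, causal_mask=None, flag_causal_attention=False):
    # Pass the mask only when causal attention is enabled; `None` means full attention.
    attn_mask = causal_mask if flag_causal_attention else None
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)


q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, tokens, head_dim)
mask = torch.ones(16, 16, dtype=torch.bool).tril()  # stand-in causal mask
out_full = attend(q, k, v)  # flag off: mask ignored
out_causal = attend(q, k, v, causal_mask=mask, flag_causal_attention=True)
print(out_full.shape, out_causal.shape)
```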
--- src/diffusers/models/transformers/transformer_skyreels_v2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index e908dffcd281..7748a1aa0e02 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -506,7 +506,7 @@ def forward( if torch.is_grad_enabled() and self.gradient_checkpointing: for block in self.blocks: hidden_states = self._gradient_checkpointing_func( - block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask if self.config.flag_causal_attention else None ) if self.config.inject_sample_info: fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device) @@ -531,7 +531,7 @@ def forward( timestep_proj = timestep_proj.transpose(1, 2).contiguous() for block in self.blocks: - hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask) + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask if self.config.flag_causal_attention else None) # 5. Output norm, projection & unpatchify if temb.dim() == 2: From a6f0d119dd5f361a9619115f67c30f6fbffd8285 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 25 May 2025 14:16:48 +0300 Subject: [PATCH 113/264] style --- scripts/convert_skyreelsv2_to_diffusers.py | 5 ++-- .../transformers/transformer_skyreels_v2.py | 15 ++++++++++-- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 24 ++++--------------- ...eline_skyreels_v2_diffusion_forcing_v2v.py | 21 +++------------- 4 files changed, 24 insertions(+), 41 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index f1f8ee1ab472..356d7489609d 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -7,7 +7,7 @@ from accelerate import init_empty_weights from huggingface_hub import hf_hub_download from safetensors.torch import load_file -from transformers import AutoProcessor, CLIPVisionModelWithProjection, UMT5EncoderModel, AutoTokenizer +from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel from diffusers import ( AutoencoderKLWan, @@ -454,7 +454,8 @@ def get_args(): text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl") tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") scheduler = FlowMatchUniPCMultistepScheduler( - prediction_type="flow_prediction", num_train_timesteps=1000, + prediction_type="flow_prediction", + num_train_timesteps=1000, ) if "I2V" in args.model_type: diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 7748a1aa0e02..847526bef92b 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -506,7 +506,12 @@ def forward( if torch.is_grad_enabled() and self.gradient_checkpointing: for block in self.blocks: hidden_states = self._gradient_checkpointing_func( - block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask if self.config.flag_causal_attention else None + block, + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + 
causal_mask if self.config.flag_causal_attention else None, ) if self.config.inject_sample_info: fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device) @@ -531,7 +536,13 @@ def forward( timestep_proj = timestep_proj.transpose(1, 2).contiguous() for block in self.blocks: - hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask if self.config.flag_causal_attention else None) + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + causal_mask if self.config.flag_causal_attention else None, + ) # 5. Output norm, projection & unpatchify if temb.dim() == 2: diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 603c9b10eda6..8569197ff9d8 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -22,7 +22,6 @@ from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...image_processor import PipelineImageInput from ...loaders import WanLoraLoaderMixin from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel from ...schedulers import FlowMatchUniPCMultistepScheduler @@ -96,9 +95,9 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class SkyReelsV2DiffusionForcingImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): +class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): """ - Pipeline for Image-to-Video (i2v) generation using SkyReels-V2 with diffusion forcing. + Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a specific device, etc.). @@ -447,8 +446,7 @@ def attention_kwargs(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - image: PipelineImageInput, - prompt: Union[str, List[str]] = None, + prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] = None, height: int = 480, width: int = 832, @@ -460,8 +458,6 @@ def __call__( latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, - image_embeds: Optional[torch.Tensor] = None, - last_image: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -567,12 +563,10 @@ def __call__( self.check_inputs( prompt, negative_prompt, - image, height, width, prompt_embeds, negative_prompt_embeds, - image_embeds, callback_on_step_end_tensor_inputs, ) @@ -610,7 +604,6 @@ def __call__( device=device, ) - # Encode image embedding transformer_dtype = self.transformer.dtype prompt_embeds = prompt_embeds.to(transformer_dtype) if negative_prompt_embeds is not None: @@ -647,14 +640,8 @@ def __call__( ) # 5. 
Prepare latent variables - num_channels_latents = self.vae.config.z_dim - image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) - if last_image is not None: - last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( - device, dtype=torch.float32 - ) - latents, condition = self.prepare_latents( - image, + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, height, @@ -664,7 +651,6 @@ def __call__( device, generator, latents, - last_image, ) # 6. Denoising loop diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index 2b3aef021b62..8569197ff9d8 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -19,7 +19,6 @@ import ftfy import torch -from PIL import Image from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback @@ -96,9 +95,9 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class SkyReelsV2DiffusionForcingVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): +class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): """ - Pipeline for Video-to-Video (v2v) generation using SkyReels-V2 with diffusion forcing. + Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a specific device, etc.). @@ -447,8 +446,7 @@ def attention_kwargs(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - video: List[Image.Image] = None, - prompt: Union[str, List[str]] = None, + prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] = None, height: int = 480, width: int = 832, @@ -561,18 +559,12 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial - width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial - num_videos_per_prompt = 1 - # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, negative_prompt, height, width, - video, - latents, prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs, @@ -634,11 +626,6 @@ def __call__( fps_embeds = [fps] * prompt_embeds.shape[0] fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] - if latents is None: - video = self.video_processor.preprocess_video(video, height=height, width=width).to( - device, dtype=torch.float32 - ) - if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: # Short video generation # 4. Prepare sample schedulers and timestep matrix @@ -655,7 +642,6 @@ def __call__( # 5. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( - video, batch_size * num_videos_per_prompt, num_channels_latents, height, @@ -665,7 +651,6 @@ def __call__( device, generator, latents, - latent_timestep, ) # 6. Denoising loop From cc0660cdb62ae1231bd375ce15282461ab7b1304 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 25 May 2025 18:36:59 +0300 Subject: [PATCH 114/264] Fix class name casing for SkyReelsV2 components in multiple files to ensure consistency and correct functionality. --- scripts/convert_skyreelsv2_to_diffusers.py | 4 ++-- src/diffusers/__init__.py | 16 ++++++++-------- .../dummy_torch_and_transformers_objects.py | 8 ++++---- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 356d7489609d..60fe1b0b0112 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -15,7 +15,7 @@ SkyReelsV2Transformer3DModel, ) from diffusers.pipelines import SkyReelsV2DiffusionForcingPipeline -from diffusers.utils.dummy_torch_and_transformers_objects import SkyreelsV2ImageToVideoPipeline +from diffusers.utils.dummy_torch_and_transformers_objects import SkyReelsV2ImageToVideoPipeline TRANSFORMER_KEYS_RENAME_DICT = { @@ -463,7 +463,7 @@ def get_args(): "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16 ) image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - pipe = SkyreelsV2ImageToVideoPipeline( + pipe = SkyReelsV2ImageToVideoPipeline( transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 5eebde9bdb05..813d9c9046ac 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -447,10 +447,10 @@ "SemanticStableDiffusionPipeline", "ShapEImg2ImgPipeline", "ShapEPipeline", - "SkyreelsV2DiffusionForcingImageToVideoPipeline", - "SkyreelsV2DiffusionForcingPipeline", - "SkyreelsV2ImageToVideoPipeline", - "SkyreelsV2Pipeline", + "SkyReelsV2DiffusionForcingImageToVideoPipeline", + "SkyReelsV2DiffusionForcingPipeline", + "SkyReelsV2ImageToVideoPipeline", + "SkyReelsV2Pipeline", "StableAudioPipeline", "StableAudioProjectionModel", "StableCascadeCombinedPipeline", @@ -1037,10 +1037,10 @@ SemanticStableDiffusionPipeline, ShapEImg2ImgPipeline, ShapEPipeline, - SkyreelsV2DiffusionForcingImageToVideoPipeline, - SkyreelsV2DiffusionForcingPipeline, - SkyreelsV2ImageToVideoPipeline, - SkyreelsV2Pipeline, + SkyReelsV2DiffusionForcingImageToVideoPipeline, + SkyReelsV2DiffusionForcingPipeline, + SkyReelsV2ImageToVideoPipeline, + SkyReelsV2Pipeline, StableAudioPipeline, StableAudioProjectionModel, StableCascadeCombinedPipeline, diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 49612750e4ff..f5ce5c5ae515 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1682,7 +1682,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class SkyreelsV2DiffusionForcingImageToVideoPipeline(metaclass=DummyObject): +class SkyReelsV2DiffusionForcingImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -1697,7 +1697,7 @@ def from_pretrained(cls, *args, 
**kwargs): requires_backends(cls, ["torch", "transformers"]) -class SkyreelsV2DiffusionForcingPipeline(metaclass=DummyObject): +class SkyReelsV2DiffusionForcingPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -1712,7 +1712,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class SkyreelsV2ImageToVideoPipeline(metaclass=DummyObject): +class SkyReelsV2ImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -1727,7 +1727,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class SkyreelsV2Pipeline(metaclass=DummyObject): +class SkyReelsV2Pipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): From 14d8d7a777fdc6b2d913d960d0c0c757c6e8d4e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 25 May 2025 18:41:14 +0300 Subject: [PATCH 115/264] cleaning --- scripts/convert_skyreelsv2_to_diffusers.py | 46 +--------------------- 1 file changed, 2 insertions(+), 44 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 60fe1b0b0112..29591d2da2fc 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -12,10 +12,10 @@ from diffusers import ( AutoencoderKLWan, FlowMatchUniPCMultistepScheduler, + SkyReelsV2DiffusionForcingPipeline, + SkyReelsV2ImageToVideoPipeline, SkyReelsV2Transformer3DModel, ) -from diffusers.pipelines import SkyReelsV2DiffusionForcingPipeline -from diffusers.utils.dummy_torch_and_transformers_objects import SkyReelsV2ImageToVideoPipeline TRANSFORMER_KEYS_RENAME_DICT = { @@ -139,48 +139,6 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]: "text_dim": 4096, }, } - elif model_type == "SkyReels-V2-I2V-14B-720p": - config = { - "model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff", - "diffusers_config": { - "image_dim": 1280, - "added_kv_proj_dim": 5120, - "attention_head_dim": 128, - "cross_attn_norm": True, - "eps": 1e-06, - "ffn_dim": 13824, - "freq_dim": 256, - "in_channels": 36, - "num_attention_heads": 40, - "num_layers": 40, - "out_channels": 16, - "patch_size": [1, 2, 2], - "qk_norm": "rms_norm_across_heads", - "text_dim": 4096, - }, - } - elif model_type == "Wan-FLF2V-14B-720P": - config = { - "model_id": "ypyp/Wan2.1-FLF2V-14B-720P", # This is just a placeholder - "diffusers_config": { - "image_dim": 1280, - "added_kv_proj_dim": 5120, - "attention_head_dim": 128, - "cross_attn_norm": True, - "eps": 1e-06, - "ffn_dim": 13824, - "freq_dim": 256, - "in_channels": 36, - "num_attention_heads": 40, - "num_layers": 40, - "out_channels": 16, - "patch_size": [1, 2, 2], - "qk_norm": "rms_norm_across_heads", - "text_dim": 4096, - "rope_max_seq_len": 1024, - "pos_embed_seq_len": 257 * 2, - }, - } return config From 85a1f908efae8a5d660c688e68f3b552d3a292be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 25 May 2025 18:48:24 +0300 Subject: [PATCH 116/264] cleansing --- .../transformers/transformer_skyreels_v2.py | 40 +------------------ 1 file changed, 1 insertion(+), 39 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 847526bef92b..324d343238bf 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ 
b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -18,7 +18,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.nn.attention.flex_attention import BlockMask, create_block_mask from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin @@ -323,7 +322,7 @@ def forward( ff_output = self.ffn(norm_hidden_states) hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) - return hidden_states # TODO: check .to(torch.bfloat16) + return hidden_states def set_ar_attention(self): self.attn1.processor.set_ar_attention() @@ -580,40 +579,3 @@ def set_ar_attention(self, causal_block_size): self.register_to_config(num_frame_per_block=causal_block_size, flag_causal_attention=True) for block in self.blocks: block.set_ar_attention() - - @staticmethod - def _prepare_blockwise_causal_attn_mask( - device: Union[torch.device, str], num_frames: int = 21, frame_seqlen: int = 1560, num_frame_per_block=1 - ) -> BlockMask: - """ - we will divide the token sequence into the following format [1 latent frame] [1 latent frame] ... [1 latent - frame] We use flexattention to construct the attention mask - """ - total_length = num_frames * frame_seqlen - - # we do right padding to get to a multiple of 128 - padded_length = math.ceil(total_length / 128) * 128 - total_length - - ends = torch.zeros(total_length + padded_length, device=device, dtype=torch.long) - - # Block-wise causal mask will attend to all elements that are before the end of the current chunk - frame_indices = torch.arange(start=0, end=total_length, step=frame_seqlen * num_frame_per_block, device=device) - - for tmp in frame_indices: - ends[tmp : tmp + frame_seqlen * num_frame_per_block] = tmp + frame_seqlen * num_frame_per_block - - def attention_mask(b, h, q_idx, kv_idx): - return (kv_idx < ends[q_idx]) | (q_idx == kv_idx) - # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask - - block_mask = create_block_mask( - attention_mask, - B=None, - H=None, - Q_LEN=total_length + padded_length, - KV_LEN=total_length + padded_length, - _compile=False, - device=device, - ) - - return block_mask From 5264ac9882221284cacb0c05f8b87b439aeeec85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 26 May 2025 14:04:06 +0300 Subject: [PATCH 117/264] Refactor `get_timestep_embedding` to move modifications into `SkyReelsV2TimeTextImageEmbedding`. --- src/diffusers/models/embeddings.py | 22 +++++-------------- .../transformers/transformer_skyreels_v2.py | 4 +++- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 39ae3ebe5d37..4dc586d39017 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -37,8 +37,7 @@ def get_timestep_embedding( Args timesteps (torch.Tensor): - a 1-D Tensor of N indices, one per batch element. These may be fractional. Can also be a 2-D Tensor of - shape (batch_size, num_frames). + a 1-D Tensor of N indices, one per batch element. These may be fractional. embedding_dim (int): the dimension of the output. flip_sin_to_cos (bool): @@ -50,25 +49,18 @@ def get_timestep_embedding( max_period (int): Controls the maximum frequency of the embeddings Returns - torch.Tensor: an [N x dim] Tensor of positional embeddings. If input was 2D, shape is [B x F x dim]. 
+ torch.Tensor: an [N x dim] Tensor of positional embeddings. """ - original_shape = timesteps.shape - if len(original_shape) == 2: - # timesteps is (B, F_v), flatten to (B * F_v) - timesteps_flat = timesteps.reshape(-1) - elif len(original_shape) == 1: - timesteps_flat = timesteps - else: - raise ValueError(f"Timesteps should be 1D or 2D, but got shape {original_shape}") + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 exponent = -math.log(max_period) * torch.arange( - start=0, end=half_dim, dtype=torch.float32, device=timesteps_flat.device + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device ) exponent = exponent / (half_dim - downscale_freq_shift) emb = torch.exp(exponent) - emb = timesteps_flat[:, None].float() * emb[None, :] + emb = timesteps[:, None].float() * emb[None, :] # scale embeddings emb = scale * emb @@ -84,10 +76,6 @@ def get_timestep_embedding( if embedding_dim % 2 == 1: emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - if len(original_shape) == 2: - # Reshape back to (B, F_v, embedding_dim) - emb = emb.reshape(original_shape[0], original_shape[1], embedding_dim) - return emb diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 324d343238bf..3fdc78c947c4 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -178,7 +178,9 @@ def forward( encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, ): - timestep = self.timesteps_proj(timestep) + original_timestep_shape = timestep.shape + timestep = self.timesteps_proj(timestep.reshape(-1)) + timestep = timestep.reshape(*original_timestep_shape, -1) time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: From 81acfae8775c230da094cc7591d8ee63bce8b173 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 26 May 2025 14:05:16 +0300 Subject: [PATCH 118/264] Remove unnecessary line break in `get_timestep_embedding` function for cleaner code. --- src/diffusers/models/embeddings.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 4dc586d39017..c25e9997e3fb 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -75,7 +75,6 @@ def get_timestep_embedding( # zero pad if embedding_dim % 2 == 1: emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb From 11baa0005d571fe6ba220076673b1cb01cbaecde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 26 May 2025 14:26:42 +0300 Subject: [PATCH 119/264] Remove `skyreels_v2` entry from `_import_structure` and update its initialization to directly assign the list of SkyReelsV2 components. 
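Context for the `get_timestep_embedding` refactor in [PATCH 117/264] above: diffusion forcing feeds a per-frame timestep of shape `(batch, frames)`, which is now flattened to 1-D before the sinusoidal projection and reshaped back inside `SkyReelsV2TimeTextImageEmbedding`, leaving the shared helper untouched. A small sketch of that reshape round-trip (sizes are arbitrary; this is not the module's actual code):

```python
import torch
from diffusers.models.embeddings import get_timestep_embedding

batch, frames, dim = 2, 5, 256
timestep = torch.randint(0, 1000, (batch, frames))  # one timestep per latent frame

emb = get_timestep_embedding(timestep.reshape(-1), dim)  # helper expects a 1-D tensor
emb = emb.reshape(batch, frames, dim)                    # restore the per-frame layout
print(emb.shape)  # torch.Size([2, 5, 256])
```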
--- src/diffusers/pipelines/__init__.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7f6730cd9254..cd75edabf586 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -30,7 +30,6 @@ "ledits_pp": [], "marigold": [], "pag": [], - "skyreels_v2": [], "stable_diffusion": [], "stable_diffusion_xl": [], } @@ -368,14 +367,12 @@ "WuerstchenPriorPipeline", ] _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline"] - _import_structure["skyreels_v2"].extend( - [ + _import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", "SkyReelsV2DiffusionForcingImageToVideoPipeline", "SkyReelsV2ImageToVideoPipeline", "SkyReelsV2Pipeline", - ] - ) + ] try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() From 2906c37841046f4d5caccdcb7486b677e42fa397 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 26 May 2025 16:28:00 +0300 Subject: [PATCH 120/264] cleansing --- .../pipeline_skyreels_v2_diffusion_forcing.py | 35 ++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 8569197ff9d8..88bfeb34ad63 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -276,6 +276,9 @@ def check_inputs( prompt_embeds=None, negative_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, + overlap_history=None, + num_frames=None, + base_num_frames=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -308,6 +311,10 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + if num_frames > base_num_frames: + if overlap_history is None: + raise ValueError('You are supposed to specify the "overlap_history" to support the long video generation. 
17 and 37 are recommanded to set.') + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents def prepare_latents( self, @@ -362,7 +369,7 @@ def generate_timestep_matrix( min_ar_step = infer_step_num / gen_block if ar_step < min_ar_step: raise ValueError(f"ar_step should be at least {math.ceil(min_ar_step)} in your setting") - # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, causal_block_size, num_frames_block, base_num_frames_block) + step_template = torch.cat( [ torch.tensor([999], dtype=torch.int64, device=step_template.device), @@ -466,12 +473,12 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, - overlap_history: Optional[int] = 17, + overlap_history: Optional[int] = None, shift: float = 8.0, - addnoise_condition: float = 20.0, + addnoise_condition: float = 0, base_num_frames: int = 97, - ar_step: int = 5, - causal_block_size: Optional[int] = 5, + ar_step: int = 0, + causal_block_size: Optional[int] = 1, fps: int = 24, ): r""" @@ -537,13 +544,13 @@ def __call__( shift (`float`, *optional*, defaults to `8.0`): overlap_history (`int`, *optional*, defaults to `17`): Number of frames to overlap for smooth transitions in long videos - addnoise_condition (`float`, *optional*, defaults to `20.0`): + addnoise_condition (`float`, *optional*, defaults to `0`): Improves consistency in long video generation base_num_frames (`int`, *optional*, defaults to `97`): 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**) - ar_step (`int`, *optional*, defaults to `5`): + ar_step (`int`, *optional*, defaults to `0`): Controls asynchronous inference (0 for synchronous mode) - causal_block_size (`int`, *optional*, defaults to `5`): + causal_block_size (`int`, *optional*, defaults to `1`): Recommended when using asynchronous inference (--ar_step > 0) fps (`int`, *optional*, defaults to `24`): @@ -568,8 +575,16 @@ def __call__( prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs, + overlap_history, + num_frames, + base_num_frames, ) + if addnoise_condition > 60: + logger.warning( + f'You have set "addnoise_condition" as {addnoise_condition}. The value is too large which can cause inconsistency in long video generation. The value is recommanded to set 20.' + ) + if num_frames % self.vae_scale_factor_temporal != 1: logger.warning( f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." 
@@ -634,7 +649,6 @@ def __call__( sample_scheduler = FlowMatchUniPCMultistepScheduler.from_config(self.scheduler.config) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * num_latent_frames step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( num_latent_frames, timesteps, base_num_frames, ar_step, prefix_video_latents_length, causal_block_size ) @@ -712,7 +726,6 @@ def __call__( return_dict=False, generator=generator, )[0] - sample_schedulers_counter[idx] += 1 if callback_on_step_end is not None: callback_kwargs = {} @@ -772,7 +785,6 @@ def __call__( sample_scheduler = FlowMatchUniPCMultistepScheduler.from_config(self.scheduler.config) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * base_num_frames_iter step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( base_num_frames_iter, timesteps, @@ -857,7 +869,6 @@ def __call__( return_dict=False, generator=generator, )[0] - sample_schedulers_counter[idx] += 1 if callback_on_step_end is not None: callback_kwargs = {} From a38eaab590ccd97e580696f1359fac6d178f0a85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 26 May 2025 16:38:18 +0300 Subject: [PATCH 121/264] Refactor attention processing in `SkyReelsV2AttnProcessor2_0` to always convert query, key, and value to `torch.bfloat16`, simplifying the code and improving clarity. --- .../models/transformers/transformer_skyreels_v2.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 3fdc78c947c4..d4560757baa3 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -100,11 +100,10 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): hidden_states_img = hidden_states_img.type_as(query) if self._flag_ar_attention: - is_self_attention = encoder_hidden_states is hidden_states hidden_states = F.scaled_dot_product_attention( - query.to(torch.bfloat16) if is_self_attention else query, - key.to(torch.bfloat16) if is_self_attention else key, - value.to(torch.bfloat16) if is_self_attention else value, + query.to(torch.bfloat16), + key.to(torch.bfloat16), + value.to(torch.bfloat16), attn_mask=attention_mask, dropout_p=0.0, is_causal=False, From 150ea56c8809ed0df995f43bf39a81aa4a87c9b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 26 May 2025 16:57:02 +0300 Subject: [PATCH 122/264] Enhance example usage in `pipeline_skyreels_v2_diffusion_forcing.py` by adding VAE initialization and detailed prompt for video generation, improving clarity and usability of the documentation. 
--- .../pipeline_skyreels_v2_diffusion_forcing.py | 32 +++++++++++++++---- 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 88bfeb34ad63..4f0de4d76137 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -49,17 +49,37 @@ Examples: ```py >>> import torch - >>> import PIL.Image - >>> from diffusers import SkyReelsV2DiffusionForcingPipeline - >>> from diffusers.utils import export_to_video, load_image + >>> from diffusers import SkyReelsV2DiffusionForcingPipeline, AutoencoderKLWan + >>> from diffusers.utils import export_to_video >>> # Load the pipeline + >>> vae = AutoencoderKLWan.from_pretrained( + ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", + ... torch_dtype=torch.float32, + ... subfolder="vae" + ... ) >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained( - ... "HF_placeholder/SkyReels-V2-DF-1.3B-540P", torch_dtype=torch.float16 + ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", + ... torch_dtype=torch.bfloat16, ... ) >>> pipe = pipe.to("cuda") - - >>> # TODO + >>> pipe.transformer.set_ar_attention(causal_block_size=5) + + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + + >>> output = pipe( + ... prompt=prompt, + ... num_inference_steps=30, + ... height=544, + ... width=960, + ... guidance_scale=6.0, + ... num_frames=97, + ... ar_step=5, # Controls asynchronous inference (0 for synchronous mode) + ... generator=torch.Generator(device="cuda").manual_seed(0), + ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos + ... addnoise_condition=20, # Improves consistency in long video generation + ... ).frames[0] + >>> export_to_video(output, "video.mp4", fps=24, quality=8) ``` """ From ad7d4c40f033c528f078388d45db31ebd3f03aa7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 26 May 2025 16:59:44 +0300 Subject: [PATCH 123/264] Refactor import structure in `__init__.py` for SkyReelsV2 components and improve formatting in `pipeline_skyreels_v2_diffusion_forcing.py` to enhance code readability and maintainability. 
--- src/diffusers/pipelines/__init__.py | 8 ++++---- .../pipeline_skyreels_v2_diffusion_forcing.py | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index cd75edabf586..289fc8d298a4 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -368,10 +368,10 @@ ] _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline"] _import_structure["skyreels_v2"] = [ - "SkyReelsV2DiffusionForcingPipeline", - "SkyReelsV2DiffusionForcingImageToVideoPipeline", - "SkyReelsV2ImageToVideoPipeline", - "SkyReelsV2Pipeline", + "SkyReelsV2DiffusionForcingPipeline", + "SkyReelsV2DiffusionForcingImageToVideoPipeline", + "SkyReelsV2ImageToVideoPipeline", + "SkyReelsV2Pipeline", ] try: if not is_onnx_available(): diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 4f0de4d76137..d2c77642dc2c 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -56,7 +56,7 @@ >>> vae = AutoencoderKLWan.from_pretrained( ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", ... torch_dtype=torch.float32, - ... subfolder="vae" + ... subfolder="vae", ... ) >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained( ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", @@ -77,7 +77,7 @@ ... ar_step=5, # Controls asynchronous inference (0 for synchronous mode) ... generator=torch.Generator(device="cuda").manual_seed(0), ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos - ... addnoise_condition=20, # Improves consistency in long video generation + ... addnoise_condition=20, # Improves consistency in long video generation ... ).frames[0] >>> export_to_video(output, "video.mp4", fps=24, quality=8) ``` @@ -286,7 +286,6 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.check_inputs def check_inputs( self, prompt, @@ -331,9 +330,10 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") - if num_frames > base_num_frames: - if overlap_history is None: - raise ValueError('You are supposed to specify the "overlap_history" to support the long video generation. 17 and 37 are recommanded to set.') + if num_frames > base_num_frames and overlap_history is None: + raise ValueError( + 'You are supposed to specify the "overlap_history" to support the long video generation. 17 and 37 are recommanded to set.' + ) # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents def prepare_latents( From f1ee0245cbd6d5e4617771c39d1a97add64bd2b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 26 May 2025 17:06:14 +0300 Subject: [PATCH 124/264] Update `guidance_scale` parameter in `SkyReelsV2DiffusionForcingPipeline` from 5.0 to 6.0 to enhance video generation quality. 
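For readers unfamiliar with the `guidance_scale` knob tuned above: it is the classifier-free guidance weight, which mixes the conditional and unconditional predictions at every denoising step. The generic combine step looks roughly like the sketch below (this is the standard CFG formula, not SkyReels-specific code):

```python
import torch


def cfg_combine(noise_uncond: torch.Tensor, noise_cond: torch.Tensor, guidance_scale: float = 6.0) -> torch.Tensor:
    # guidance_scale = 1.0 disables guidance; larger values follow the prompt more closely.
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)


uncond = torch.randn(1, 16, 4, 8, 8)
cond = torch.randn(1, 16, 4, 8, 8)
print(cfg_combine(uncond, cond).shape)  # torch.Size([1, 16, 4, 8, 8])
```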
--- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index d2c77642dc2c..b2007892c2e0 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -479,7 +479,7 @@ def __call__( width: int = 832, num_frames: int = 97, num_inference_steps: int = 50, - guidance_scale: float = 5.0, + guidance_scale: float = 6.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, From 421e0dc2e6e42d7372690d754097ec54a8e2c283 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 26 May 2025 17:09:30 +0300 Subject: [PATCH 125/264] Update `guidance_scale` parameter in example documentation and class definition of `SkyReelsV2DiffusionForcingPipeline` to ensure consistency and improve video generation quality. --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index b2007892c2e0..a5d9387ac827 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -72,10 +72,8 @@ ... num_inference_steps=30, ... height=544, ... width=960, - ... guidance_scale=6.0, ... num_frames=97, ... ar_step=5, # Controls asynchronous inference (0 for synchronous mode) - ... generator=torch.Generator(device="cuda").manual_seed(0), ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos ... addnoise_condition=20, # Improves consistency in long video generation ... ).frames[0] @@ -521,7 +519,7 @@ def __call__( num_inference_steps (`int`, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - guidance_scale (`float`, defaults to `5.0`): + guidance_scale (`float`, defaults to `6.0`): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > From 4b688c484b610fe1b795baaf5bab7ae5b87a2813 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 26 May 2025 18:07:58 +0300 Subject: [PATCH 126/264] Update `causal_block_size` parameter in `SkyReelsV2DiffusionForcingPipeline` to default to `None`. 
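The default change above ties in with the docstrings a few patches back: `causal_block_size` is only meaningful when asynchronous inference is enabled via `ar_step > 0`, so leaving it at `None` on the synchronous path avoids surprising behaviour. Illustrative call patterns (argument names follow the pipeline signature in the diffs above; the concrete values are just the ones recommended in the docstrings):

```python
# Synchronous generation: no autoregressive step, no causal block size.
sync_kwargs = dict(ar_step=0, causal_block_size=None)

# Asynchronous generation: ar_step > 0 with a matching causal block size.
async_kwargs = dict(ar_step=5, causal_block_size=5)

# e.g. pipe(prompt=prompt, num_frames=97, **async_kwargs)  # hypothetical invocation
```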
--- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index a5d9387ac827..ffd28d0920c8 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -496,7 +496,7 @@ def __call__( addnoise_condition: float = 0, base_num_frames: int = 97, ar_step: int = 0, - causal_block_size: Optional[int] = 1, + causal_block_size: Optional[int] = None, fps: int = 24, ): r""" @@ -568,7 +568,7 @@ def __call__( 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**) ar_step (`int`, *optional*, defaults to `0`): Controls asynchronous inference (0 for synchronous mode) - causal_block_size (`int`, *optional*, defaults to `1`): + causal_block_size (`int`, *optional*, defaults to `None`): Recommended when using asynchronous inference (--ar_step > 0) fps (`int`, *optional*, defaults to `24`): From c6b539143a21873dae1ad86173949625d3b0564e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 26 May 2025 18:17:15 +0300 Subject: [PATCH 127/264] up --- .../schedulers/scheduling_flow_match_unipc_multistep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py b/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py index 4d9198dc3456..d78dbe752275 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_flow_match_unipc_multistep.py @@ -25,7 +25,7 @@ class FlowMatchUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): """ - `FlowMatchUniPCMultistepScheduler` is a ... + `FlowMatchUniPCMultistepScheduler` is a flow matching version of the UniPCMultistepScheduler. This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic methods the library implements for all schedulers such as loading and saving. From 3bf1e4a2dc9ac9eef893420416666561682df830 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 26 May 2025 18:42:35 +0300 Subject: [PATCH 128/264] Fix dtype conversion for `timestep_proj` in `SkyReelsV2Transformer3DModel` to *ensure* correct tensor operations. 
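The one-line fix above is worth spelling out because it is an easy bug to write: `Tensor.to(dtype)` is not in-place, it returns a new tensor, so the result has to be assigned back. A minimal reproduction of the difference:

```python
import torch

x = torch.randn(4, dtype=torch.float32)

x.to(torch.bfloat16)       # returns a converted copy; `x` itself is unchanged
print(x.dtype)             # torch.float32

x = x.to(torch.bfloat16)   # reassigning captures the converted tensor, as in the fix
print(x.dtype)             # torch.bfloat16
```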
--- src/diffusers/models/transformers/transformer_skyreels_v2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index d4560757baa3..24b5cd1b2e61 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -517,7 +517,7 @@ def forward( fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device) fps_emb = self.fps_embedding(fps).float() - timestep_proj.to(fps_emb.dtype) + timestep_proj = timestep_proj.to(fps_emb.dtype) self.fps_projection.to(fps_emb.dtype) if flag_df: timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1)).repeat( @@ -525,7 +525,7 @@ def forward( ) else: timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1)) - timestep_proj.to(hidden_states.dtype) + timestep_proj = timestep_proj.to(hidden_states.dtype) if flag_df: b, f = timestep.shape From f48363ca30bdf1e090b67f162a398afd6cfdbb7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 26 May 2025 18:45:26 +0300 Subject: [PATCH 129/264] Optimize causal mask generation by replacing repeated tensor with `repeat_interleave` for improved efficiency in `SkyReelsV2Transformer3DModel`. --- src/diffusers/models/transformers/transformer_skyreels_v2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 24b5cd1b2e61..c292f173676e 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -484,8 +484,7 @@ def forward( if self.config.flag_causal_attention: frame_num, height, width = grid_sizes block_num = frame_num // self.config.num_frame_per_block - range_tensor = torch.arange(block_num, device=hidden_states.device).view(-1, 1) - range_tensor = range_tensor.repeat(1, self.config.num_frame_per_block).flatten() + range_tensor = torch.arange(block_num, device=hidden_states.device).repeat_interleave(self.config.num_frame_per_block) causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1) causal_mask = causal_mask.repeat(1, height, width, 1, height, width) From 920d9564eb158000032725c4d5c29b262e2e6418 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 26 May 2025 18:46:00 +0300 Subject: [PATCH 130/264] style --- src/diffusers/models/transformers/transformer_skyreels_v2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index c292f173676e..46e58153d819 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -484,7 +484,9 @@ def forward( if self.config.flag_causal_attention: frame_num, height, width = grid_sizes block_num = frame_num // self.config.num_frame_per_block - range_tensor = torch.arange(block_num, device=hidden_states.device).repeat_interleave(self.config.num_frame_per_block) + range_tensor = torch.arange(block_num, device=hidden_states.device).repeat_interleave( + self.config.num_frame_per_block + ) causal_mask = range_tensor.unsqueeze(0) <= 
range_tensor.unsqueeze(1) # f, f causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1) causal_mask = causal_mask.repeat(1, height, width, 1, height, width) From db9cda94e24adbead3144ef197c14d744e1d8ce7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 27 May 2025 08:02:01 +0300 Subject: [PATCH 131/264] Enhance example documentation in `SkyReelsV2DiffusionForcingPipeline` with guidance scale and shift parameters for T2V and I2V. Remove unused `retrieve_latents` function to streamline the code. --- .../pipeline_skyreels_v2_diffusion_forcing.py | 27 +++++++------------ 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index ffd28d0920c8..e94d61d1538b 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -62,6 +62,8 @@ ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", ... torch_dtype=torch.bfloat16, ... ) + >>> shift = 8.0 # 8.0 for T2V, 3.0 for I2V + >>> pipe.scheduler = FlowMatchUniPCMultistepScheduler.from_config(pipe.scheduler.config, shift=shift) >>> pipe = pipe.to("cuda") >>> pipe.transformer.set_ar_attention(causal_block_size=5) @@ -72,6 +74,7 @@ ... num_inference_steps=30, ... height=544, ... width=960, + ... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V ... num_frames=97, ... ar_step=5, # Controls asynchronous inference (0 for synchronous mode) ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos @@ -99,20 +102,6 @@ def prompt_clean(text): return text -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): """ Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing. @@ -524,7 +513,7 @@ def __call__( `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**) num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -560,17 +549,21 @@ def __call__( max_sequence_length (`int`, *optional*, defaults to `512`): The maximum sequence length of the prompt. 
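Context for the mask construction touched in [PATCH 129/264] above: latent frames are grouped into causal blocks, and a frame may attend to every frame whose block index is not later than its own. A toy-sized sketch of that block-causal pattern using the same `repeat_interleave` idea (sizes are made up; the real mask is additionally expanded over the spatial grid):

```python
import torch

num_frame_per_block, block_num = 2, 3
block_idx = torch.arange(block_num).repeat_interleave(num_frame_per_block)  # tensor([0, 0, 1, 1, 2, 2])

# causal_mask[i, j] is True when frame j's block is not later than frame i's block.
causal_mask = block_idx.unsqueeze(0) <= block_idx.unsqueeze(1)
print(causal_mask.int())
# tensor([[1, 1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1]])
```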
shift (`float`, *optional*, defaults to `8.0`): + Flow matching scheduler parameter (**5.0 for I2V**, **8.0 for T2V**) overlap_history (`int`, *optional*, defaults to `17`): Number of frames to overlap for smooth transitions in long videos addnoise_condition (`float`, *optional*, defaults to `0`): - Improves consistency in long video generation + This is used to help smooth the long video generation by adding some noise to the clean condition. Too + large noise can cause the inconsistency as well. 20 is a recommended value, and you may try larger + ones, but it is recommended to not exceed 50. base_num_frames (`int`, *optional*, defaults to `97`): 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**) ar_step (`int`, *optional*, defaults to `0`): Controls asynchronous inference (0 for synchronous mode) causal_block_size (`int`, *optional*, defaults to `None`): - Recommended when using asynchronous inference (--ar_step > 0) + Recommended when using asynchronous inference (when ar_step > 0) fps (`int`, *optional*, defaults to `24`): + Frame rate of the generated video Examples: From ff6eeea4408d8878212d0c0b1870cb11ff1482d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 27 May 2025 08:10:51 +0300 Subject: [PATCH 132/264] Refactor sample scheduler creation in `SkyReelsV2DiffusionForcingPipeline` to use `deepcopy` for improved state management during inference steps. --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index e94d61d1538b..af7c7c87eaa7 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -15,6 +15,7 @@ import html import math import re +from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Union import ftfy @@ -657,7 +658,7 @@ def __call__( # 4. Prepare sample schedulers and timestep matrix sample_schedulers = [] for _ in range(num_latent_frames): - sample_scheduler = FlowMatchUniPCMultistepScheduler.from_config(self.scheduler.config) + sample_scheduler = deepcopy(self.scheduler) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( @@ -793,7 +794,7 @@ def __call__( # 4. Prepare sample schedulers and timestep matrix sample_schedulers = [] for _ in range(base_num_frames_iter): - sample_scheduler = FlowMatchUniPCMultistepScheduler.from_config(self.scheduler.config) + sample_scheduler = deepcopy(self.scheduler) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( From c0abccc950da8630b8fe0affe091cfadc8fc7877 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 27 May 2025 09:27:26 +0300 Subject: [PATCH 133/264] Enhance error handling and documentation in `SkyReelsV2DiffusionForcingPipeline` for `overlap_history` and `addnoise_condition` parameters to improve long video generation guidance. 
--- .../pipeline_skyreels_v2_diffusion_forcing.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index af7c7c87eaa7..76f0f78bd28d 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -319,6 +319,10 @@ def check_inputs( raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") if num_frames > base_num_frames and overlap_history is None: + raise ValueError( + "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. " + "Please specify a value for `overlap_history`. Recommended values are 17 or 37." + ) raise ValueError( 'You are supposed to specify the "overlap_history" to support the long video generation. 17 and 37 are recommanded to set.' ) @@ -551,8 +555,9 @@ def __call__( The maximum sequence length of the prompt. shift (`float`, *optional*, defaults to `8.0`): Flow matching scheduler parameter (**5.0 for I2V**, **8.0 for T2V**) - overlap_history (`int`, *optional*, defaults to `17`): - Number of frames to overlap for smooth transitions in long videos + overlap_history (`int`, *optional*, defaults to `None`): + Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes + short video generation mode, and no overlap is applied. 17 and 37 are recommended to set. addnoise_condition (`float`, *optional*, defaults to `0`): This is used to help smooth the long video generation by adding some noise to the clean condition. Too large noise can cause the inconsistency as well. 20 is a recommended value, and you may try larger @@ -594,7 +599,7 @@ def __call__( if addnoise_condition > 60: logger.warning( - f'You have set "addnoise_condition" as {addnoise_condition}. The value is too large which can cause inconsistency in long video generation. The value is recommanded to set 20.' + f"The value of 'addnoise_condition' is too large ({addnoise_condition}) and may cause inconsistencies in long video generation. A value of 20 is recommended." ) if num_frames % self.vae_scale_factor_temporal != 1: @@ -780,7 +785,9 @@ def __call__( if prefix_video_latents.shape[2] % causal_block_size != 0: truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size logger.warning( - f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment." + f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. " + f"This truncation ensures compatibility with the causal block size, which is required for proper processing. " + f"However, it may slightly affect the continuity of the generated video at the truncation boundary." ) prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] prefix_video_latents_length = prefix_video_latents.shape[2] From 35061d0a51a2d91fc9ba569678133234adb52ca7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 27 May 2025 12:49:35 +0300 Subject: [PATCH 134/264] Update documentation and progress bar handling in `SkyReelsV2DiffusionForcingPipeline` to clarify asynchronous inference settings and improve progress tracking during denoising steps. 
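The new input checks are worth restating outside the diff context: chunked long-video generation only works when consecutive chunks share `overlap_history` frames, and `addnoise_condition` trades conditioning fidelity for smoothness, with values above roughly 60 known to hurt consistency. The sketch below is a stand-alone restatement of the same guard logic; the function name is illustrative and not part of the pipeline API.

```py
def validate_long_video_args(num_frames, base_num_frames, overlap_history, addnoise_condition):
    """Mirror the long-video input checks added in this patch (illustrative only)."""
    if num_frames > base_num_frames and overlap_history is None:
        # Chunks beyond the base window must overlap so transitions stay smooth;
        # 17 and 37 are the recommended overlap lengths.
        raise ValueError(
            "`overlap_history` is required when `num_frames` exceeds `base_num_frames`. "
            "Recommended values are 17 or 37."
        )
    if addnoise_condition > 60:
        # Noising the clean conditioning frames smooths long videos, but too much
        # noise reintroduces inconsistency; 20 is the recommended value.
        print(
            f"addnoise_condition={addnoise_condition} is large and may cause "
            "inconsistencies in long video generation; 20 is recommended."
        )


# A 257-frame request against the default 97-frame base window needs an overlap:
validate_long_video_args(num_frames=257, base_num_frames=97, overlap_history=17, addnoise_condition=20)
```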
--- .../pipeline_skyreels_v2_diffusion_forcing.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 76f0f78bd28d..4b06cbb0b146 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -565,7 +565,11 @@ def __call__( base_num_frames (`int`, *optional*, defaults to `97`): 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**) ar_step (`int`, *optional*, defaults to `0`): - Controls asynchronous inference (0 for synchronous mode) + Controls asynchronous inference (0 for synchronous mode) You can set `ar_step=5` to enable asynchronous + inference. When asynchronous inference, `causal_block_size=5` is recommended while it is not supposed + to be set for synchronous generation. Asynchronous inference will take more steps to diffuse the whole + sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous + inference may improve the instruction following and visual consistent performance. causal_block_size (`int`, *optional*, defaults to `None`): Recommended when using asynchronous inference (when ar_step > 0) fps (`int`, *optional*, defaults to `24`): @@ -685,8 +689,9 @@ def __call__( ) # 6. Denoising loop - num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(step_matrix) + progress_bar_step = len(timesteps) / len(step_matrix) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(step_matrix): @@ -758,7 +763,7 @@ def __call__( if i == len(step_matrix) - 1 or ( (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 ): - progress_bar.update() + progress_bar.update(progress_bar_step) if XLA_AVAILABLE: xm.mark_step() @@ -830,8 +835,9 @@ def __call__( latents[:, :, :prefix_video_latents_length, :, :] = prefix_video_latents.to(transformer_dtype) # 6. Denoising loop - num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(step_matrix) + progress_bar_step = len(timesteps) / len(step_matrix) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(step_matrix): @@ -905,7 +911,7 @@ def __call__( if i == len(step_matrix) - 1 or ( (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 ): - progress_bar.update() + progress_bar.update(progress_bar_step) if XLA_AVAILABLE: xm.mark_step() From cede08c2843236c67957e92d7eca922688b05547 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 27 May 2025 13:05:11 +0300 Subject: [PATCH 135/264] Refine progress bar calculation in `SkyReelsV2DiffusionForcingPipeline` by rounding the step size to one decimal place for improved readability during denoising steps. 
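The progress-bar changes in this and the following patches all address the same mismatch: with asynchronous inference (`ar_step > 0`) the timestep matrix has more rows than the scheduler has inference steps, so ticking the bar once per matrix row overflows a bar whose total is `num_inference_steps`. The sketch below contrasts the fractional-update workaround introduced here with the simpler fix PATCH 137 settles on (setting the bar total to the matrix length); `tqdm` stands in for the pipeline's `self.progress_bar`, and the sizes are hypothetical.

```py
from tqdm import tqdm

num_inference_steps = 30   # scheduler timesteps (stand-in value)
step_matrix_len = 45       # rows of the diffusion-forcing step matrix (stand-in value)

# PATCHES 134/135: keep total=num_inference_steps and advance by a rounded fraction
# per matrix row, so the bar lands near 100% despite the extra rows.
fractional_step = round(num_inference_steps / step_matrix_len, 1)
with tqdm(total=num_inference_steps) as bar:
    for _ in range(step_matrix_len):
        bar.update(fractional_step)

# PATCH 137: make the bar total the matrix length and tick once per row,
# which removes the rounding drift entirely.
with tqdm(total=step_matrix_len) as bar:
    for _ in range(step_matrix_len):
        bar.update()
```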
--- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 4b06cbb0b146..75f533758c7d 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -691,7 +691,7 @@ def __call__( # 6. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(step_matrix) - progress_bar_step = len(timesteps) / len(step_matrix) + progress_bar_step = round(len(timesteps) / len(step_matrix), 1) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(step_matrix): @@ -837,7 +837,7 @@ def __call__( # 6. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(step_matrix) - progress_bar_step = len(timesteps) / len(step_matrix) + progress_bar_step = round(len(timesteps) / len(step_matrix), 1) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(step_matrix): From 5bc9a1be70551777c2048a8e1b8bc436aa2929b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 27 May 2025 16:47:06 +0300 Subject: [PATCH 136/264] Update import statements in `SkyReelsV2DiffusionForcingPipeline` documentation for improved clarity and organization. --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 75f533758c7d..449cdba6ae32 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -50,7 +50,11 @@ Examples: ```py >>> import torch - >>> from diffusers import SkyReelsV2DiffusionForcingPipeline, AutoencoderKLWan + >>> from diffusers import ( + ... SkyReelsV2DiffusionForcingPipeline, + ... FlowMatchUniPCMultistepScheduler, + ... AutoencoderKLWan, + ... ) >>> from diffusers.utils import export_to_video >>> # Load the pipeline From 5c658c9890210613480f965b08a837741e904800 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 28 May 2025 10:38:28 +0300 Subject: [PATCH 137/264] Refactor progress bar handling in `SkyReelsV2DiffusionForcingPipeline` to use total steps instead of calculated step size. --- .../pipeline_skyreels_v2_diffusion_forcing.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 449cdba6ae32..8951526ab02d 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -695,9 +695,8 @@ def __call__( # 6. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(step_matrix) - progress_bar_step = round(len(timesteps) / len(step_matrix), 1) - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=len(step_matrix)) as progress_bar: for i, t in enumerate(step_matrix): if self.interrupt: continue @@ -767,7 +766,7 @@ def __call__( if i == len(step_matrix) - 1 or ( (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 ): - progress_bar.update(progress_bar_step) + progress_bar.update() if XLA_AVAILABLE: xm.mark_step() @@ -841,9 +840,8 @@ def __call__( # 6. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(step_matrix) - progress_bar_step = round(len(timesteps) / len(step_matrix), 1) - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=len(step_matrix)) as progress_bar: for i, t in enumerate(step_matrix): if self.interrupt: continue @@ -915,7 +913,7 @@ def __call__( if i == len(step_matrix) - 1 or ( (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 ): - progress_bar.update(progress_bar_step) + progress_bar.update() if XLA_AVAILABLE: xm.mark_step() From b30a426448f37e3403d65c0bcd8e045c99b20935 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 28 May 2025 10:57:50 +0300 Subject: [PATCH 138/264] update templates for i2v, v2v --- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 134 +++++++++++------- ...eline_skyreels_v2_diffusion_forcing_v2v.py | 134 +++++++++++------- 2 files changed, 172 insertions(+), 96 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 8569197ff9d8..8951526ab02d 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -15,6 +15,7 @@ import html import math import re +from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Union import ftfy @@ -49,17 +50,42 @@ Examples: ```py >>> import torch - >>> import PIL.Image - >>> from diffusers import SkyReelsV2DiffusionForcingPipeline - >>> from diffusers.utils import export_to_video, load_image + >>> from diffusers import ( + ... SkyReelsV2DiffusionForcingPipeline, + ... FlowMatchUniPCMultistepScheduler, + ... AutoencoderKLWan, + ... ) + >>> from diffusers.utils import export_to_video >>> # Load the pipeline + >>> vae = AutoencoderKLWan.from_pretrained( + ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", + ... torch_dtype=torch.float32, + ... subfolder="vae", + ... ) >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained( - ... "HF_placeholder/SkyReels-V2-DF-1.3B-540P", torch_dtype=torch.float16 + ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", + ... torch_dtype=torch.bfloat16, ... ) + >>> shift = 8.0 # 8.0 for T2V, 3.0 for I2V + >>> pipe.scheduler = FlowMatchUniPCMultistepScheduler.from_config(pipe.scheduler.config, shift=shift) >>> pipe = pipe.to("cuda") - - >>> # TODO + >>> pipe.transformer.set_ar_attention(causal_block_size=5) + + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." 
+ + >>> output = pipe( + ... prompt=prompt, + ... num_inference_steps=30, + ... height=544, + ... width=960, + ... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V + ... num_frames=97, + ... ar_step=5, # Controls asynchronous inference (0 for synchronous mode) + ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos + ... addnoise_condition=20, # Improves consistency in long video generation + ... ).frames[0] + >>> export_to_video(output, "video.mp4", fps=24, quality=8) ``` """ @@ -81,20 +107,6 @@ def prompt_clean(text): return text -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): """ Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing. @@ -266,7 +278,6 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.check_inputs def check_inputs( self, prompt, @@ -276,6 +287,9 @@ def check_inputs( prompt_embeds=None, negative_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, + overlap_history=None, + num_frames=None, + base_num_frames=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -308,6 +322,15 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + if num_frames > base_num_frames and overlap_history is None: + raise ValueError( + "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. " + "Please specify a value for `overlap_history`. Recommended values are 17 or 37." + ) + raise ValueError( + 'You are supposed to specify the "overlap_history" to support the long video generation. 17 and 37 are recommanded to set.' 
+ ) + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents def prepare_latents( self, @@ -362,7 +385,7 @@ def generate_timestep_matrix( min_ar_step = infer_step_num / gen_block if ar_step < min_ar_step: raise ValueError(f"ar_step should be at least {math.ceil(min_ar_step)} in your setting") - # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, causal_block_size, num_frames_block, base_num_frames_block) + step_template = torch.cat( [ torch.tensor([999], dtype=torch.int64, device=step_template.device), @@ -452,7 +475,7 @@ def __call__( width: int = 832, num_frames: int = 97, num_inference_steps: int = 50, - guidance_scale: float = 5.0, + guidance_scale: float = 6.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -466,12 +489,12 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, - overlap_history: Optional[int] = 17, + overlap_history: Optional[int] = None, shift: float = 8.0, - addnoise_condition: float = 20.0, + addnoise_condition: float = 0, base_num_frames: int = 97, - ar_step: int = 5, - causal_block_size: Optional[int] = 5, + ar_step: int = 0, + causal_block_size: Optional[int] = None, fps: int = 24, ): r""" @@ -494,12 +517,12 @@ def __call__( num_inference_steps (`int`, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - guidance_scale (`float`, defaults to `5.0`): + guidance_scale (`float`, defaults to `6.0`): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**) num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -535,17 +558,26 @@ def __call__( max_sequence_length (`int`, *optional*, defaults to `512`): The maximum sequence length of the prompt. shift (`float`, *optional*, defaults to `8.0`): - overlap_history (`int`, *optional*, defaults to `17`): - Number of frames to overlap for smooth transitions in long videos - addnoise_condition (`float`, *optional*, defaults to `20.0`): - Improves consistency in long video generation + Flow matching scheduler parameter (**5.0 for I2V**, **8.0 for T2V**) + overlap_history (`int`, *optional*, defaults to `None`): + Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes + short video generation mode, and no overlap is applied. 17 and 37 are recommended to set. + addnoise_condition (`float`, *optional*, defaults to `0`): + This is used to help smooth the long video generation by adding some noise to the clean condition. Too + large noise can cause the inconsistency as well. 20 is a recommended value, and you may try larger + ones, but it is recommended to not exceed 50. 
base_num_frames (`int`, *optional*, defaults to `97`): 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**) - ar_step (`int`, *optional*, defaults to `5`): - Controls asynchronous inference (0 for synchronous mode) - causal_block_size (`int`, *optional*, defaults to `5`): - Recommended when using asynchronous inference (--ar_step > 0) + ar_step (`int`, *optional*, defaults to `0`): + Controls asynchronous inference (0 for synchronous mode) You can set `ar_step=5` to enable asynchronous + inference. When asynchronous inference, `causal_block_size=5` is recommended while it is not supposed + to be set for synchronous generation. Asynchronous inference will take more steps to diffuse the whole + sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous + inference may improve the instruction following and visual consistent performance. + causal_block_size (`int`, *optional*, defaults to `None`): + Recommended when using asynchronous inference (when ar_step > 0) fps (`int`, *optional*, defaults to `24`): + Frame rate of the generated video Examples: @@ -568,8 +600,16 @@ def __call__( prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs, + overlap_history, + num_frames, + base_num_frames, ) + if addnoise_condition > 60: + logger.warning( + f"The value of 'addnoise_condition' is too large ({addnoise_condition}) and may cause inconsistencies in long video generation. A value of 20 is recommended." + ) + if num_frames % self.vae_scale_factor_temporal != 1: logger.warning( f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." @@ -631,10 +671,9 @@ def __call__( # 4. Prepare sample schedulers and timestep matrix sample_schedulers = [] for _ in range(num_latent_frames): - sample_scheduler = FlowMatchUniPCMultistepScheduler.from_config(self.scheduler.config) + sample_scheduler = deepcopy(self.scheduler) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * num_latent_frames step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( num_latent_frames, timesteps, base_num_frames, ar_step, prefix_video_latents_length, causal_block_size ) @@ -654,10 +693,10 @@ def __call__( ) # 6. Denoising loop - num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(step_matrix) - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=len(step_matrix)) as progress_bar: for i, t in enumerate(step_matrix): if self.interrupt: continue @@ -712,7 +751,6 @@ def __call__( return_dict=False, generator=generator, )[0] - sample_schedulers_counter[idx] += 1 if callback_on_step_end is not None: callback_kwargs = {} @@ -755,7 +793,9 @@ def __call__( if prefix_video_latents.shape[2] % causal_block_size != 0: truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size logger.warning( - f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment." + f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. " + f"This truncation ensures compatibility with the causal block size, which is required for proper processing. 
" + f"However, it may slightly affect the continuity of the generated video at the truncation boundary." ) prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] prefix_video_latents_length = prefix_video_latents.shape[2] @@ -769,10 +809,9 @@ def __call__( # 4. Prepare sample schedulers and timestep matrix sample_schedulers = [] for _ in range(base_num_frames_iter): - sample_scheduler = FlowMatchUniPCMultistepScheduler.from_config(self.scheduler.config) + sample_scheduler = deepcopy(self.scheduler) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * base_num_frames_iter step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( base_num_frames_iter, timesteps, @@ -799,10 +838,10 @@ def __call__( latents[:, :, :prefix_video_latents_length, :, :] = prefix_video_latents.to(transformer_dtype) # 6. Denoising loop - num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(step_matrix) - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=len(step_matrix)) as progress_bar: for i, t in enumerate(step_matrix): if self.interrupt: continue @@ -857,7 +896,6 @@ def __call__( return_dict=False, generator=generator, )[0] - sample_schedulers_counter[idx] += 1 if callback_on_step_end is not None: callback_kwargs = {} diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index 8569197ff9d8..8951526ab02d 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -15,6 +15,7 @@ import html import math import re +from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Union import ftfy @@ -49,17 +50,42 @@ Examples: ```py >>> import torch - >>> import PIL.Image - >>> from diffusers import SkyReelsV2DiffusionForcingPipeline - >>> from diffusers.utils import export_to_video, load_image + >>> from diffusers import ( + ... SkyReelsV2DiffusionForcingPipeline, + ... FlowMatchUniPCMultistepScheduler, + ... AutoencoderKLWan, + ... ) + >>> from diffusers.utils import export_to_video >>> # Load the pipeline + >>> vae = AutoencoderKLWan.from_pretrained( + ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", + ... torch_dtype=torch.float32, + ... subfolder="vae", + ... ) >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained( - ... "HF_placeholder/SkyReels-V2-DF-1.3B-540P", torch_dtype=torch.float16 + ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", + ... torch_dtype=torch.bfloat16, ... ) + >>> shift = 8.0 # 8.0 for T2V, 3.0 for I2V + >>> pipe.scheduler = FlowMatchUniPCMultistepScheduler.from_config(pipe.scheduler.config, shift=shift) >>> pipe = pipe.to("cuda") - - >>> # TODO + >>> pipe.transformer.set_ar_attention(causal_block_size=5) + + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + + >>> output = pipe( + ... prompt=prompt, + ... num_inference_steps=30, + ... height=544, + ... width=960, + ... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V + ... 
num_frames=97, + ... ar_step=5, # Controls asynchronous inference (0 for synchronous mode) + ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos + ... addnoise_condition=20, # Improves consistency in long video generation + ... ).frames[0] + >>> export_to_video(output, "video.mp4", fps=24, quality=8) ``` """ @@ -81,20 +107,6 @@ def prompt_clean(text): return text -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" -): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return encoder_output.latent_dist.mode() - elif hasattr(encoder_output, "latents"): - return encoder_output.latents - else: - raise AttributeError("Could not access latents of provided encoder_output") - - class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): """ Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing. @@ -266,7 +278,6 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds - # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.check_inputs def check_inputs( self, prompt, @@ -276,6 +287,9 @@ def check_inputs( prompt_embeds=None, negative_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, + overlap_history=None, + num_frames=None, + base_num_frames=None, ): if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -308,6 +322,15 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + if num_frames > base_num_frames and overlap_history is None: + raise ValueError( + "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. " + "Please specify a value for `overlap_history`. Recommended values are 17 or 37." + ) + raise ValueError( + 'You are supposed to specify the "overlap_history" to support the long video generation. 17 and 37 are recommanded to set.' 
+ ) + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents def prepare_latents( self, @@ -362,7 +385,7 @@ def generate_timestep_matrix( min_ar_step = infer_step_num / gen_block if ar_step < min_ar_step: raise ValueError(f"ar_step should be at least {math.ceil(min_ar_step)} in your setting") - # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, causal_block_size, num_frames_block, base_num_frames_block) + step_template = torch.cat( [ torch.tensor([999], dtype=torch.int64, device=step_template.device), @@ -452,7 +475,7 @@ def __call__( width: int = 832, num_frames: int = 97, num_inference_steps: int = 50, - guidance_scale: float = 5.0, + guidance_scale: float = 6.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -466,12 +489,12 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, - overlap_history: Optional[int] = 17, + overlap_history: Optional[int] = None, shift: float = 8.0, - addnoise_condition: float = 20.0, + addnoise_condition: float = 0, base_num_frames: int = 97, - ar_step: int = 5, - causal_block_size: Optional[int] = 5, + ar_step: int = 0, + causal_block_size: Optional[int] = None, fps: int = 24, ): r""" @@ -494,12 +517,12 @@ def __call__( num_inference_steps (`int`, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - guidance_scale (`float`, defaults to `5.0`): + guidance_scale (`float`, defaults to `6.0`): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**) num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): @@ -535,17 +558,26 @@ def __call__( max_sequence_length (`int`, *optional*, defaults to `512`): The maximum sequence length of the prompt. shift (`float`, *optional*, defaults to `8.0`): - overlap_history (`int`, *optional*, defaults to `17`): - Number of frames to overlap for smooth transitions in long videos - addnoise_condition (`float`, *optional*, defaults to `20.0`): - Improves consistency in long video generation + Flow matching scheduler parameter (**5.0 for I2V**, **8.0 for T2V**) + overlap_history (`int`, *optional*, defaults to `None`): + Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes + short video generation mode, and no overlap is applied. 17 and 37 are recommended to set. + addnoise_condition (`float`, *optional*, defaults to `0`): + This is used to help smooth the long video generation by adding some noise to the clean condition. Too + large noise can cause the inconsistency as well. 20 is a recommended value, and you may try larger + ones, but it is recommended to not exceed 50. 
base_num_frames (`int`, *optional*, defaults to `97`): 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**) - ar_step (`int`, *optional*, defaults to `5`): - Controls asynchronous inference (0 for synchronous mode) - causal_block_size (`int`, *optional*, defaults to `5`): - Recommended when using asynchronous inference (--ar_step > 0) + ar_step (`int`, *optional*, defaults to `0`): + Controls asynchronous inference (0 for synchronous mode) You can set `ar_step=5` to enable asynchronous + inference. When asynchronous inference, `causal_block_size=5` is recommended while it is not supposed + to be set for synchronous generation. Asynchronous inference will take more steps to diffuse the whole + sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous + inference may improve the instruction following and visual consistent performance. + causal_block_size (`int`, *optional*, defaults to `None`): + Recommended when using asynchronous inference (when ar_step > 0) fps (`int`, *optional*, defaults to `24`): + Frame rate of the generated video Examples: @@ -568,8 +600,16 @@ def __call__( prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs, + overlap_history, + num_frames, + base_num_frames, ) + if addnoise_condition > 60: + logger.warning( + f"The value of 'addnoise_condition' is too large ({addnoise_condition}) and may cause inconsistencies in long video generation. A value of 20 is recommended." + ) + if num_frames % self.vae_scale_factor_temporal != 1: logger.warning( f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." @@ -631,10 +671,9 @@ def __call__( # 4. Prepare sample schedulers and timestep matrix sample_schedulers = [] for _ in range(num_latent_frames): - sample_scheduler = FlowMatchUniPCMultistepScheduler.from_config(self.scheduler.config) + sample_scheduler = deepcopy(self.scheduler) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * num_latent_frames step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( num_latent_frames, timesteps, base_num_frames, ar_step, prefix_video_latents_length, causal_block_size ) @@ -654,10 +693,10 @@ def __call__( ) # 6. Denoising loop - num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(step_matrix) - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=len(step_matrix)) as progress_bar: for i, t in enumerate(step_matrix): if self.interrupt: continue @@ -712,7 +751,6 @@ def __call__( return_dict=False, generator=generator, )[0] - sample_schedulers_counter[idx] += 1 if callback_on_step_end is not None: callback_kwargs = {} @@ -755,7 +793,9 @@ def __call__( if prefix_video_latents.shape[2] % causal_block_size != 0: truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size logger.warning( - f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment." + f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. " + f"This truncation ensures compatibility with the causal block size, which is required for proper processing. 
" + f"However, it may slightly affect the continuity of the generated video at the truncation boundary." ) prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] prefix_video_latents_length = prefix_video_latents.shape[2] @@ -769,10 +809,9 @@ def __call__( # 4. Prepare sample schedulers and timestep matrix sample_schedulers = [] for _ in range(base_num_frames_iter): - sample_scheduler = FlowMatchUniPCMultistepScheduler.from_config(self.scheduler.config) + sample_scheduler = deepcopy(self.scheduler) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * base_num_frames_iter step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( base_num_frames_iter, timesteps, @@ -799,10 +838,10 @@ def __call__( latents[:, :, :prefix_video_latents_length, :, :] = prefix_video_latents.to(transformer_dtype) # 6. Denoising loop - num_warmup_steps = len(step_matrix) - num_inference_steps * self.scheduler.order + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(step_matrix) - with self.progress_bar(total=num_inference_steps) as progress_bar: + with self.progress_bar(total=len(step_matrix)) as progress_bar: for i, t in enumerate(step_matrix): if self.interrupt: continue @@ -857,7 +896,6 @@ def __call__( return_dict=False, generator=generator, )[0] - sample_schedulers_counter[idx] += 1 if callback_on_step_end is not None: callback_kwargs = {} From 238d07d81cf482cbb744057c095ddaf4cbf64e56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 28 May 2025 11:11:50 +0300 Subject: [PATCH 139/264] Add `retrieve_latents` function to streamline latent retrieval in `SkyReelsV2DiffusionForcingPipeline`. Update video latent processing to utilize this new function for improved clarity and maintainability. --- .../pipeline_skyreels_v2_diffusion_forcing.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 8951526ab02d..352bad15bc9a 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -107,6 +107,20 @@ def prompt_clean(text): return text +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): """ Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing. 
@@ -787,8 +801,8 @@ def __call__( ) for i in range(n_iter): if video is not None: - prefix_latents_dist = self.vae.encode(video[:, :, -overlap_history:]).latent_dist - prefix_video_latents = (prefix_latents_dist.mode() - latents_mean) * latents_std + prefix_video_latents = retrieve_latents(self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax") + prefix_video_latents = (prefix_video_latents - latents_mean) * latents_std if prefix_video_latents.shape[2] % causal_block_size != 0: truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size From d3bd6382bc9a3bc07a3b09bc8b39f1269a76925d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 28 May 2025 11:12:03 +0300 Subject: [PATCH 140/264] Add `retrieve_latents` function to both i2v and v2v pipelines for consistent latent retrieval. Update video latent processing to utilize this function, enhancing clarity and maintainability across the SkyReelsV2DiffusionForcingPipeline implementations. --- ...peline_skyreels_v2_diffusion_forcing_i2v.py | 18 ++++++++++++++++-- ...peline_skyreels_v2_diffusion_forcing_v2v.py | 18 ++++++++++++++++-- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 8951526ab02d..352bad15bc9a 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -107,6 +107,20 @@ def prompt_clean(text): return text +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): """ Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing. 
@@ -787,8 +801,8 @@ def __call__( ) for i in range(n_iter): if video is not None: - prefix_latents_dist = self.vae.encode(video[:, :, -overlap_history:]).latent_dist - prefix_video_latents = (prefix_latents_dist.mode() - latents_mean) * latents_std + prefix_video_latents = retrieve_latents(self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax") + prefix_video_latents = (prefix_video_latents - latents_mean) * latents_std if prefix_video_latents.shape[2] % causal_block_size != 0: truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index 8951526ab02d..352bad15bc9a 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -107,6 +107,20 @@ def prompt_clean(text): return text +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): """ Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing. @@ -787,8 +801,8 @@ def __call__( ) for i in range(n_iter): if video is not None: - prefix_latents_dist = self.vae.encode(video[:, :, -overlap_history:]).latent_dist - prefix_video_latents = (prefix_latents_dist.mode() - latents_mean) * latents_std + prefix_video_latents = retrieve_latents(self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax") + prefix_video_latents = (prefix_video_latents - latents_mean) * latents_std if prefix_video_latents.shape[2] % causal_block_size != 0: truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size From 2aab1dee0ce5ab84eeca1837b445692f8cd09454 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 28 May 2025 12:10:13 +0300 Subject: [PATCH 141/264] Remove redundant ValueError for `overlap_history` in `SkyReelsV2DiffusionForcingPipeline` to streamline error handling and improve user guidance for long video generation. --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 352bad15bc9a..660549e4a658 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -341,9 +341,6 @@ def check_inputs( "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. " "Please specify a value for `overlap_history`. Recommended values are 17 or 37." 
) - raise ValueError( - 'You are supposed to specify the "overlap_history" to support the long video generation. 17 and 37 are recommanded to set.' - ) # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents def prepare_latents( From 8ab5bb1e0bc2e813605ea091c2bfa074005a8743 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 28 May 2025 13:08:02 +0300 Subject: [PATCH 142/264] Update default video dimensions and flow matching scheduler parameter in `SkyReelsV2DiffusionForcingPipeline` to enhance video generation capabilities. --- .../pipeline_skyreels_v2_diffusion_forcing.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 660549e4a658..242536dce42a 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -482,8 +482,8 @@ def __call__( self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, + height: int = 544, + width: int = 960, num_frames: int = 97, num_inference_steps: int = 50, guidance_scale: float = 6.0, @@ -519,9 +519,9 @@ def __call__( The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - height (`int`, defaults to `480`): + height (`int`, defaults to `544`): The height of the generated video. - width (`int`, defaults to `832`): + width (`int`, defaults to `960`): The width of the generated video. num_frames (`int`, defaults to `97`): The number of frames in the generated video. @@ -569,7 +569,7 @@ def __call__( max_sequence_length (`int`, *optional*, defaults to `512`): The maximum sequence length of the prompt. shift (`float`, *optional*, defaults to `8.0`): - Flow matching scheduler parameter (**5.0 for I2V**, **8.0 for T2V**) + Flow matching scheduler parameter (**3.0 for I2V**, **8.0 for T2V**) overlap_history (`int`, *optional*, defaults to `None`): Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes short video generation mode, and no overlap is applied. 17 and 37 are recommended to set. From 323ec66ea9e5a7a576d7c7800e4a28b398fc577e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 28 May 2025 13:08:29 +0300 Subject: [PATCH 143/264] Refactor `SkyReelsV2DiffusionForcingPipeline` to support Image-to-Video (i2v) generation. Update class name, add image encoding functionality, and adjust parameters for improved video generation. Enhance error handling for image inputs and update documentation accordingly. 
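With the rename to `SkyReelsV2DiffusionForcingImageToVideoPipeline` and the I2V defaults introduced here (height 544, width 960, `guidance_scale=5.0`, `shift=3.0`), a minimal end-to-end call looks roughly as follows. This is a sketch assembled from the draft docstrings in this patch series: the repository id keeps the placeholder path used there, `first_frame.png` is a hypothetical conditioning image, and the final published weights may differ.

```py
import torch
from diffusers import FlowMatchUniPCMultistepScheduler, SkyReelsV2DiffusionForcingImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Placeholder repo id carried over from the draft examples in this patch series.
pipe = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained(
    "/SkyReels-V2-DF-1.3B-540P-Diffusers",
    torch_dtype=torch.bfloat16,
)
# I2V settings from the parameter docs: shift=3.0 and guidance_scale=5.0
# (versus 8.0 and 6.0 for T2V).
pipe.scheduler = FlowMatchUniPCMultistepScheduler.from_config(pipe.scheduler.config, shift=3.0)
pipe = pipe.to("cuda")

image = load_image("first_frame.png")  # hypothetical conditioning frame
output = pipe(
    image=image,
    prompt="A cat and a dog baking a cake together in a kitchen.",
    height=544,
    width=960,
    num_frames=97,
    num_inference_steps=30,
    guidance_scale=5.0,
    shift=3.0,
).frames[0]
export_to_video(output, "i2v_video.mp4", fps=24)
```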
--- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 203 ++++++++++++++---- 1 file changed, 163 insertions(+), 40 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 352bad15bc9a..c8b093671dbb 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -16,10 +16,15 @@ import math import re from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import PIL import ftfy import torch +from diffusers.image_processor import PipelineImageInput +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from transformers import CLIPImageProcessor, CLIPVisionModel from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback @@ -27,8 +32,6 @@ from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel from ...schedulers import FlowMatchUniPCMultistepScheduler from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring -from ...utils.torch_utils import randn_tensor -from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import SkyReelsV2PipelineOutput @@ -51,7 +54,7 @@ ```py >>> import torch >>> from diffusers import ( - ... SkyReelsV2DiffusionForcingPipeline, + ... SkyReelsV2DiffusionForcingImageToVideoPipeline, ... FlowMatchUniPCMultistepScheduler, ... AutoencoderKLWan, ... ) @@ -63,7 +66,7 @@ ... torch_dtype=torch.float32, ... subfolder="vae", ... ) - >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained( + >>> pipe = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained( ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", ... torch_dtype=torch.bfloat16, ... ) @@ -121,9 +124,9 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): +class SkyReelsV2DiffusionForcingImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): """ - Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing. + Pipeline for Image-to-Video (i2v) generation using SkyReels-V2 with diffusion forcing. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a specific device, etc.). @@ -135,6 +138,11 @@ class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): text_encoder ([`UMT5EncoderModel`]): [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. + image_encoder ([`CLIPVisionModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically + the + [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) + variant. transformer ([`SkyReelsV2Transformer3DModel`]): Conditional Transformer to denoise the encoded image latents. 
scheduler ([`FlowMatchUniPCMultistepScheduler`]): @@ -143,13 +151,15 @@ class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. """ - model_cpu_offload_seq = "text_encoder->transformer->vae" + model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, + image_encoder: CLIPVisionModel, + image_processor: CLIPImageProcessor, transformer: SkyReelsV2Transformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMatchUniPCMultistepScheduler, @@ -160,20 +170,23 @@ def __init__( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, + image_encoder=image_encoder, transformer=transformer, scheduler=scheduler, + image_processor=image_processor, ) self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.image_processor = image_processor # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, + max_sequence_length: int = 512, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): @@ -210,6 +223,17 @@ def _get_t5_prompt_embeds( return prompt_embeds + # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image + def encode_image( + self, + image: PipelineImageInput, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + image = self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt def encode_prompt( self, @@ -296,15 +320,28 @@ def check_inputs( self, prompt, negative_prompt, + image, height, width, prompt_embeds=None, negative_prompt_embeds=None, + image_embeds=None, callback_on_step_end_tensor_inputs=None, overlap_history=None, num_frames=None, base_num_frames=None, ): + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -341,13 +378,11 @@ def check_inputs( "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. " "Please specify a value for `overlap_history`. Recommended values are 17 or 37." 
) - raise ValueError( - 'You are supposed to specify the "overlap_history" to support the long video generation. 17 and 37 are recommanded to set.' - ) - # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents + # Copied from diffusers.pipelines.wan.pipeline_wan.WanImageToVideoPipeline.prepare_latents def prepare_latents( self, + image: PipelineImageInput, batch_size: int, num_channels_latents: int = 16, height: int = 480, @@ -357,27 +392,74 @@ def prepare_latents( device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if latents is not None: - return latents.to(device=device, dtype=dtype) - + last_image: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - shape = ( - batch_size, - num_channels_latents, - num_latent_frames, - int(height) // self.vae_scale_factor_spatial, - int(width) // self.vae_scale_factor_spatial, - ) + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - return latents + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + image = image.unsqueeze(2) + if last_image is None: + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + ) + else: + last_image = last_image.unsqueeze(2) + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], + dim=2, + ) + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + ] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + + latent_condition = latent_condition.to(dtype) + latent_condition = (latent_condition - latents_mean) * latents_std + + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + + if last_image is None: + mask_lat_size[:, :, list(range(1, num_frames))] = 0 + else: + mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + 
mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(latent_condition.device) + + return latents, torch.concat([mask_lat_size, latent_condition], dim=1) + + # Copied from diffusers.pipelines.skyreels_v2.pipeline_skyreels_v2_diffusion_forcing.SkyReelsV2DiffusionForcingPipeline.generate_timestep_matrix def generate_timestep_matrix( self, num_frames: int, @@ -483,18 +565,21 @@ def attention_kwargs(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]], + image: PipelineImageInput, + prompt: Union[str, List[str]] = None, negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, + height: int = 544, + width: int = 960, num_frames: int = 97, num_inference_steps: int = 50, - guidance_scale: float = 6.0, + guidance_scale: float = 5.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + last_image: Optional[torch.Tensor] = None, output_type: Optional[str] = "np", return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -504,7 +589,7 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, overlap_history: Optional[int] = None, - shift: float = 8.0, + shift: float = 3.0, addnoise_condition: float = 0, base_num_frames: int = 97, ar_step: int = 0, @@ -515,6 +600,8 @@ def __call__( The call function to the pipeline for generation. Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. @@ -522,16 +609,16 @@ def __call__( The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - height (`int`, defaults to `480`): + height (`int`, defaults to `544`): The height of the generated video. - width (`int`, defaults to `832`): + width (`int`, defaults to `960`): The width of the generated video. num_frames (`int`, defaults to `97`): The number of frames in the generated video. num_inference_steps (`int`, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - guidance_scale (`float`, defaults to `6.0`): + guidance_scale (`float`, defaults to `5.0`): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > @@ -552,6 +639,12 @@ def __call__( negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `negative_prompt` input argument. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. 
Can be used to easily tweak image inputs (weighting). If not provided, + image embeddings are generated from the `image` input argument. + last_image (`torch.Tensor`, *optional*): + Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, + image embeddings are generated from the `image` input argument. output_type (`str`, *optional*, defaults to `"np"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): @@ -571,8 +664,8 @@ def __call__( `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int`, *optional*, defaults to `512`): The maximum sequence length of the prompt. - shift (`float`, *optional*, defaults to `8.0`): - Flow matching scheduler parameter (**5.0 for I2V**, **8.0 for T2V**) + shift (`float`, *optional*, defaults to `3.0`): + Flow matching scheduler parameter (**3.0 for I2V**, **8.0 for T2V**) overlap_history (`int`, *optional*, defaults to `None`): Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes short video generation mode, and no overlap is applied. 17 and 37 are recommended to set. @@ -609,10 +702,12 @@ def __call__( self.check_inputs( prompt, negative_prompt, + image, height, width, prompt_embeds, negative_prompt_embeds, + image_embeds, callback_on_step_end_tensor_inputs, overlap_history, num_frames, @@ -658,11 +753,21 @@ def __call__( device=device, ) + + # Encode image embedding transformer_dtype = self.transformer.dtype prompt_embeds = prompt_embeds.to(transformer_dtype) if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + if image_embeds is None: + if last_image is None: + image_embeds = self.encode_image(image, device) + else: + image_embeds = self.encode_image([image, last_image], device) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) timesteps = self.scheduler.timesteps @@ -693,8 +798,14 @@ def __call__( ) # 5. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + if last_image is not None: + last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( + device, dtype=torch.float32 + ) + latents, condition = self.prepare_latents( + image, batch_size * num_videos_per_prompt, num_channels_latents, height, @@ -704,6 +815,7 @@ def __call__( device, generator, latents, + last_image, ) # 6. Denoising loop @@ -738,6 +850,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, @@ -748,6 +861,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, @@ -836,8 +950,14 @@ def __call__( ) # 5. 
Prepare latent variables - num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + if last_image is not None: + last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( + device, dtype=torch.float32 + ) + latents, condition = self.prepare_latents( + image, batch_size * num_videos_per_prompt, num_channels_latents, height, @@ -847,6 +967,7 @@ def __call__( device, generator, None if i > 0 else latents, + last_image, ) if prefix_video_latents is not None: latents[:, :, :prefix_video_latents_length, :, :] = prefix_video_latents.to(transformer_dtype) @@ -883,6 +1004,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, @@ -893,6 +1015,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, From ce804adf99c929c7c620938606a933dcf406477e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 28 May 2025 17:55:59 +0300 Subject: [PATCH 144/264] Improve organization for image-last_image condition. --- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 92 +++++++++++-------- 1 file changed, 53 insertions(+), 39 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index c8b093671dbb..627c9a415c43 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -753,26 +753,18 @@ def __call__( device=device, ) - # Encode image embedding transformer_dtype = self.transformer.dtype prompt_embeds = prompt_embeds.to(transformer_dtype) if negative_prompt_embeds is not None: negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) - if image_embeds is None: - if last_image is None: - image_embeds = self.encode_image(image, device) - else: - image_embeds = self.encode_image([image, last_image], device) - image_embeds = image_embeds.repeat(batch_size, 1, 1) - image_embeds = image_embeds.to(transformer_dtype) - # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) timesteps = self.scheduler.timesteps - prefix_video_latents_length = 0 + last_video = None + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 base_num_frames = ( (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 @@ -787,16 +779,6 @@ def __call__( if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: # Short video generation - # 4. Prepare sample schedulers and timestep matrix - sample_schedulers = [] - for _ in range(num_latent_frames): - sample_scheduler = deepcopy(self.scheduler) - sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) - sample_schedulers.append(sample_scheduler) - step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - num_latent_frames, timesteps, base_num_frames, ar_step, prefix_video_latents_length, causal_block_size - ) - # 5. 
Prepare latent variables num_channels_latents = self.vae.config.z_dim image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) @@ -817,6 +799,27 @@ def __call__( latents, last_image, ) + prefix_video_latents_length = condition.shape[2] + + latents[:, :, :prefix_video_latents_length, :, :] = condition[:, :, :prefix_video_latents_length, :, :].to(transformer_dtype) + if last_image is not None: + latents = torch.cat([latents, condition[:, :, prefix_video_latents_length:, :, :].to(transformer_dtype)], dim=2) + base_num_frames += prefix_video_latents_length + num_latent_frames += prefix_video_latents_length + + # 4. Prepare sample schedulers and timestep matrix + sample_schedulers = [] + for _ in range(num_latent_frames): + sample_scheduler = deepcopy(self.scheduler) + sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + sample_schedulers.append(sample_scheduler) + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + num_latent_frames, timesteps, base_num_frames, ar_step, prefix_video_latents_length, causal_block_size + ) + + if last_image is not None: + step_matrix[:, -prefix_video_latents_length:] = 0 + step_update_mask[:, -prefix_video_latents_length:] = False # 6. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -850,7 +853,6 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, - encoder_hidden_states_image=image_embeds, flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, @@ -861,7 +863,6 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, - encoder_hidden_states_image=image_embeds, flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, @@ -934,6 +935,27 @@ def __call__( else: base_num_frames_iter = base_num_frames + latents, condition = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + (base_num_frames_iter - 1) * self.vae_scale_factor_temporal + 1, + torch.float32, + device, + generator, + None if i > 0 else latents, + last_image, + ) + + prefix_video_latents_length = condition.shape[2] + + latents[:, :, :prefix_video_latents_length, :, :] = condition[:, :, :prefix_video_latents_length, :, :].to(transformer_dtype) + if last_video is not None and i == n_iter - 1: + latents = torch.cat([latents, condition[:, :, prefix_video_latents_length:, :, :].to(transformer_dtype)], dim=2) + base_num_frames_iter += prefix_video_latents_length + # 4. Prepare sample schedulers and timestep matrix sample_schedulers = [] for _ in range(base_num_frames_iter): @@ -949,6 +971,10 @@ def __call__( causal_block_size, ) + if last_video is not None and i == n_iter - 1: + step_matrix[:, -last_video_latent_length:] = 0 + step_update_mask[:, -last_video_latent_length:] = False + # 5. 
Prepare latent variables num_channels_latents = self.vae.config.z_dim image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) @@ -956,21 +982,6 @@ def __call__( last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( device, dtype=torch.float32 ) - latents, condition = self.prepare_latents( - image, - batch_size * num_videos_per_prompt, - num_channels_latents, - height, - width, - (base_num_frames_iter - 1) * self.vae_scale_factor_temporal + 1, - torch.float32, - device, - generator, - None if i > 0 else latents, - last_image, - ) - if prefix_video_latents is not None: - latents[:, :, :prefix_video_latents_length, :, :] = prefix_video_latents.to(transformer_dtype) # 6. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -1004,7 +1015,6 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, - encoder_hidden_states_image=image_embeds, flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, @@ -1015,7 +1025,6 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, - encoder_hidden_states_image=image_embeds, flag_df=True, fps=fps_embeds, attention_kwargs=attention_kwargs, @@ -1055,6 +1064,9 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + if last_video is not None and i == n_iter - 1: + latents = latents[:, :, :-last_video_latent_length] + if not output_type == "latent": latents = latents.to(self.vae.dtype) latents = latents / latents_std + latents_mean @@ -1070,6 +1082,8 @@ def __call__( if not output_type == "latent": if overlap_history is None: + if last_video is not None: + latents = latents[:, :, :-last_video_latent_length] latents = latents.to(self.vae.dtype) latents_mean = ( torch.tensor(self.vae.config.latents_mean) From ff972069a986129e977a84968126f9916d02a0a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 28 May 2025 20:01:37 +0300 Subject: [PATCH 145/264] Refactor `SkyReelsV2DiffusionForcingImageToVideoPipeline` to improve latent preparation and video condition handling integration. 
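
With this change, `condition` holds only the VAE latents of the first (and, when given, last) frame; the zero-padded video and mask channels are gone. Roughly, the caller splices it into the noise latents as sketched below with dummy shapes (a reviewer-note illustration of the intent, not an excerpt from the pipeline):

```py
import torch

latents = torch.randn(1, 16, 25, 68, 120)   # noise latents: (B, C, T_latent, H', W')
condition = torch.randn(1, 16, 2, 68, 120)  # first- and last-frame latents, back to back
prefix_len = condition.shape[2] // 2        # first half conditions the start of the clip

latents[:, :, :prefix_len] = condition[:, :, :prefix_len]
# the last-frame latents are appended; the timestep matrix later pins them to t=0
latents = torch.cat([latents, condition[:, :, prefix_len:]], dim=2)
```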
--- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 98 +++++++++---------- 1 file changed, 45 insertions(+), 53 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 627c9a415c43..2ecb15505a11 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -18,14 +18,14 @@ from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import PIL import ftfy +import PIL import torch +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel + from diffusers.image_processor import PipelineImageInput from diffusers.utils.torch_utils import randn_tensor from diffusers.video_processor import VideoProcessor -from transformers import CLIPImageProcessor, CLIPVisionModel -from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import WanLoraLoaderMixin @@ -186,7 +186,7 @@ def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_videos_per_prompt: int = 1, - max_sequence_length: int = 512, + max_sequence_length: int = 226, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): @@ -379,7 +379,6 @@ def check_inputs( "Please specify a value for `overlap_history`. Recommended values are 17 or 37." ) - # Copied from diffusers.pipelines.wan.pipeline_wan.WanImageToVideoPipeline.prepare_latents def prepare_latents( self, image: PipelineImageInput, @@ -411,16 +410,12 @@ def prepare_latents( latents = latents.to(device=device, dtype=dtype) image = image.unsqueeze(2) - if last_image is None: - video_condition = torch.cat( - [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 - ) - else: + if last_image is not None: last_image = last_image.unsqueeze(2) - video_condition = torch.cat( - [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], - dim=2, - ) + video_condition = torch.cat([image, last_image], dim=2) + else: + video_condition = image + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) latents_mean = ( @@ -444,20 +439,7 @@ def prepare_latents( latent_condition = latent_condition.to(dtype) latent_condition = (latent_condition - latents_mean) * latents_std - mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) - - if last_image is None: - mask_lat_size[:, :, list(range(1, num_frames))] = 0 - else: - mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0 - first_frame_mask = mask_lat_size[:, :, 0:1] - first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) - mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) - mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) - mask_lat_size = mask_lat_size.transpose(1, 2) - mask_lat_size = mask_lat_size.to(latent_condition.device) - - return latents, torch.concat([mask_lat_size, latent_condition], dim=1) + return latents, latent_condition # Copied from diffusers.pipelines.skyreels_v2.pipeline_skyreels_v2_diffusion_forcing.SkyReelsV2DiffusionForcingPipeline.generate_timestep_matrix def generate_timestep_matrix( @@ -763,8 +745,6 @@ def 
__call__( self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) timesteps = self.scheduler.timesteps - last_video = None - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 base_num_frames = ( (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 @@ -799,11 +779,15 @@ def __call__( latents, last_image, ) - prefix_video_latents_length = condition.shape[2] + prefix_video_latents_length = condition.shape[2] // 2 - latents[:, :, :prefix_video_latents_length, :, :] = condition[:, :, :prefix_video_latents_length, :, :].to(transformer_dtype) + latents[:, :, :prefix_video_latents_length, :, :] = condition[:, :, :prefix_video_latents_length, :, :].to( + transformer_dtype + ) if last_image is not None: - latents = torch.cat([latents, condition[:, :, prefix_video_latents_length:, :, :].to(transformer_dtype)], dim=2) + latents = torch.cat( + [latents, condition[:, :, prefix_video_latents_length:, :, :].to(transformer_dtype)], dim=2 + ) base_num_frames += prefix_video_latents_length num_latent_frames += prefix_video_latents_length @@ -916,7 +900,9 @@ def __call__( ) for i in range(n_iter): if video is not None: - prefix_video_latents = retrieve_latents(self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax") + prefix_video_latents = retrieve_latents( + self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax" + ) prefix_video_latents = (prefix_video_latents - latents_mean) * latents_std if prefix_video_latents.shape[2] % causal_block_size != 0: @@ -935,6 +921,16 @@ def __call__( else: base_num_frames_iter = base_num_frames + # 5. Prepare latent variables + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to( + device, dtype=torch.float32 + ) + if last_image is not None: + last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( + device, dtype=torch.float32 + ) + latents, condition = self.prepare_latents( image, batch_size * num_videos_per_prompt, @@ -949,11 +945,15 @@ def __call__( last_image, ) - prefix_video_latents_length = condition.shape[2] + prefix_video_latents_length = condition.shape[2] // 2 - latents[:, :, :prefix_video_latents_length, :, :] = condition[:, :, :prefix_video_latents_length, :, :].to(transformer_dtype) - if last_video is not None and i == n_iter - 1: - latents = torch.cat([latents, condition[:, :, prefix_video_latents_length:, :, :].to(transformer_dtype)], dim=2) + latents[:, :, :prefix_video_latents_length, :, :] = condition[ + :, :, :prefix_video_latents_length, :, : + ].to(transformer_dtype) + if last_image is not None and i == n_iter - 1: + latents = torch.cat( + [latents, condition[:, :, prefix_video_latents_length:, :, :].to(transformer_dtype)], dim=2 + ) base_num_frames_iter += prefix_video_latents_length # 4. Prepare sample schedulers and timestep matrix @@ -971,17 +971,9 @@ def __call__( causal_block_size, ) - if last_video is not None and i == n_iter - 1: - step_matrix[:, -last_video_latent_length:] = 0 - step_update_mask[:, -last_video_latent_length:] = False - - # 5. 
Prepare latent variables - num_channels_latents = self.vae.config.z_dim - image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) - if last_image is not None: - last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( - device, dtype=torch.float32 - ) + if last_image is not None and i == n_iter - 1: + step_matrix[:, -prefix_video_latents_length:] = 0 + step_update_mask[:, -prefix_video_latents_length:] = False # 6. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order @@ -1064,8 +1056,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - if last_video is not None and i == n_iter - 1: - latents = latents[:, :, :-last_video_latent_length] + if last_image is not None and i == n_iter - 1: + latents = latents[:, :, :-prefix_video_latents_length] if not output_type == "latent": latents = latents.to(self.vae.dtype) @@ -1082,8 +1074,8 @@ def __call__( if not output_type == "latent": if overlap_history is None: - if last_video is not None: - latents = latents[:, :, :-last_video_latent_length] + if last_image is not None: + latents = latents[:, :, :-prefix_video_latents_length] latents = latents.to(self.vae.dtype) latents_mean = ( torch.tensor(self.vae.config.latents_mean) From 5d702cfb41d144371c9a445bb6dfdb56f5d728b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 28 May 2025 20:01:55 +0300 Subject: [PATCH 146/264] style --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 +++- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 242536dce42a..60477ff8d0ca 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -798,7 +798,9 @@ def __call__( ) for i in range(n_iter): if video is not None: - prefix_video_latents = retrieve_latents(self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax") + prefix_video_latents = retrieve_latents( + self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax" + ) prefix_video_latents = (prefix_video_latents - latents_mean) * latents_std if prefix_video_latents.shape[2] % causal_block_size != 0: diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index 352bad15bc9a..e74b2c32ae9b 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -801,7 +801,9 @@ def __call__( ) for i in range(n_iter): if video is not None: - prefix_video_latents = retrieve_latents(self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax") + prefix_video_latents = retrieve_latents( + self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax" + ) prefix_video_latents = (prefix_video_latents - latents_mean) * latents_std if prefix_video_latents.shape[2] % causal_block_size != 0: From 9d35809253ef1ca9ba1bd520ba03ffd8a161e159 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 28 May 2025 21:45:13 +0300 Subject: [PATCH 147/264] style --- 
.../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 3 ++- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 60477ff8d0ca..ad75922ca1cd 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -60,11 +60,12 @@ >>> # Load the pipeline >>> vae = AutoencoderKLWan.from_pretrained( ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", - ... torch_dtype=torch.float32, ... subfolder="vae", + ... torch_dtype=torch.float32, ... ) >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained( ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", + ... vae=vae, ... torch_dtype=torch.bfloat16, ... ) >>> shift = 8.0 # 8.0 for T2V, 3.0 for I2V diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 2ecb15505a11..1f17cc665b2f 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -63,11 +63,12 @@ >>> # Load the pipeline >>> vae = AutoencoderKLWan.from_pretrained( ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", - ... torch_dtype=torch.float32, ... subfolder="vae", + ... torch_dtype=torch.float32, ... ) >>> pipe = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained( ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", + ... vae=vae, ... torch_dtype=torch.bfloat16, ... ) >>> shift = 8.0 # 8.0 for T2V, 3.0 for I2V From 0f915f66b3b5d691b8db447bd5a04b47b6891ec6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 28 May 2025 21:52:18 +0300 Subject: [PATCH 148/264] Add example usage of PIL for image input in `SkyReelsV2DiffusionForcingImageToVideoPipeline` documentation. --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 1f17cc665b2f..f0b596825f0d 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -59,6 +59,7 @@ ... AutoencoderKLWan, ... ) >>> from diffusers.utils import export_to_video + >>> from PIL import Image >>> # Load the pipeline >>> vae = AutoencoderKLWan.from_pretrained( @@ -77,8 +78,10 @@ >>> pipe.transformer.set_ar_attention(causal_block_size=5) >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + >>> image = Image.open("path/to/image.png") >>> output = pipe( + ... image=image, ... prompt=prompt, ... num_inference_steps=30, ... 
height=544, From 9a6746bbc15d179a732986149e1c2e77371993f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 28 May 2025 21:52:34 +0300 Subject: [PATCH 149/264] Refactor `SkyReelsV2DiffusionForcingPipeline` to `SkyReelsV2DiffusionForcingVideoToVideoPipeline`, enhancing support for Video-to-Video (v2v) generation. Introduce video input handling, update latent preparation logic, and improve error handling for input parameters. --- ...eline_skyreels_v2_diffusion_forcing_v2v.py | 157 +++++++++++++++--- 1 file changed, 131 insertions(+), 26 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index e74b2c32ae9b..f2b0b32a1572 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -13,6 +13,7 @@ # limitations under the License. import html +import inspect import math import re from copy import deepcopy @@ -20,6 +21,7 @@ import ftfy import torch +from PIL import Image from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback @@ -51,35 +53,39 @@ ```py >>> import torch >>> from diffusers import ( - ... SkyReelsV2DiffusionForcingPipeline, + ... SkyReelsV2DiffusionForcingVideoToVideoPipeline, ... FlowMatchUniPCMultistepScheduler, ... AutoencoderKLWan, ... ) >>> from diffusers.utils import export_to_video + >>> from PIL import Image >>> # Load the pipeline >>> vae = AutoencoderKLWan.from_pretrained( ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", - ... torch_dtype=torch.float32, ... subfolder="vae", + ... torch_dtype=torch.float32, ... ) - >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained( + >>> pipe = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained( ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", + ... vae=vae, ... torch_dtype=torch.bfloat16, ... ) - >>> shift = 8.0 # 8.0 for T2V, 3.0 for I2V + >>> shift = 8.0 >>> pipe.scheduler = FlowMatchUniPCMultistepScheduler.from_config(pipe.scheduler.config, shift=shift) >>> pipe = pipe.to("cuda") >>> pipe.transformer.set_ar_attention(causal_block_size=5) >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + >>> video = Image.open("path/to/video.mp4") >>> output = pipe( + ... video=video, ... prompt=prompt, ... num_inference_steps=30, ... height=544, ... width=960, - ... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V + ... guidance_scale=6.0, ... num_frames=97, ... ar_step=5, # Controls asynchronous inference (0 for synchronous mode) ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos @@ -107,6 +113,66 @@ def prompt_clean(text): return text +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 
+ + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" @@ -121,9 +187,9 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): +class SkyReelsV2DiffusionForcingVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): """ - Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing. + Pipeline for Video-to-Video (v2v) generation using SkyReels-V2 with diffusion forcing. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a specific device, etc.). 
@@ -298,6 +364,8 @@ def check_inputs( negative_prompt, height, width, + video=None, + latents=None, prompt_embeds=None, negative_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, @@ -336,46 +404,72 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + if video is not None and latents is not None: + raise ValueError("Only one of `video` or `latents` should be provided") + if num_frames > base_num_frames and overlap_history is None: raise ValueError( "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. " "Please specify a value for `overlap_history`. Recommended values are 17 or 37." ) - raise ValueError( - 'You are supposed to specify the "overlap_history" to support the long video generation. 17 and 37 are recommanded to set.' - ) - # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents def prepare_latents( self, - batch_size: int, + video: Optional[torch.Tensor] = None, + batch_size: int = 1, num_channels_latents: int = 16, height: int = 480, width: int = 832, - num_frames: int = 81, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + generator: Optional[torch.Generator] = None, latents: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if latents is not None: - return latents.to(device=device, dtype=dtype) + timestep: Optional[torch.Tensor] = None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + num_latent_frames = ( + (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1) + ) shape = ( batch_size, num_channels_latents, num_latent_frames, - int(height) // self.vae_scale_factor_spatial, - int(width) // self.vae_scale_factor_spatial, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ + if latents is None: + if isinstance(generator, list): + init_latents = [ + retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size) + ] + else: + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, dtype ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + init_latents = (init_latents - latents_mean) * latents_std + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if hasattr(self.scheduler, "add_noise"): + latents = self.scheduler.add_noise(init_latents, noise, timestep) + else: + latents = self.scheduelr.scale_noise(init_latents, timestep, noise) + else: + latents = latents.to(device) + return latents def generate_timestep_matrix( @@ -483,7 +577,8 @@ def attention_kwargs(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]], + video: Union[Image.Image, List[Image.Image]] = None, + prompt: Union[str, List[str]] = None, negative_prompt: Union[str, List[str]] = None, height: int = 480, width: int = 832, @@ -515,6 +610,8 @@ def __call__( The call function to the pipeline for generation. Args: + video (`Image.Image` or `List[Image.Image]`, *optional*): + The video to guide the generation. If not defined, one has to pass `latents`. prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. @@ -611,6 +708,8 @@ def __call__( negative_prompt, height, width, + video, + latents, prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs, @@ -692,9 +791,15 @@ def __call__( num_latent_frames, timesteps, base_num_frames, ar_step, prefix_video_latents_length, causal_block_size ) + if latents is None: + video = self.video_processor.preprocess_video(video, height=height, width=width).to( + device, dtype=torch.float32 + ) + # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( + video, batch_size * num_videos_per_prompt, num_channels_latents, height, From b8799635cd1e94bf6f69ea61e827bb6a90ee2444 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 29 May 2025 09:11:33 +0300 Subject: [PATCH 150/264] Refactor `SkyReelsV2DiffusionForcingImageToVideoPipeline` by removing the `image_encoder` and `image_processor` dependencies. Update the CPU offload sequence accordingly. 
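
With the CLIP image encoder and processor dropped, the image-to-video diffusion-forcing pipeline is built from the same five components as the text-to-video variant. A rough usage sketch, assuming the class is exported from `diffusers` as in the docstring examples and using the same placeholder checkpoint path:

```py
import torch
from diffusers import SkyReelsV2DiffusionForcingImageToVideoPipeline

pipe = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained(
    "/SkyReels-V2-DF-1.3B-540P-Diffusers",  # placeholder path
    torch_dtype=torch.bfloat16,
)
# offload now hops text_encoder -> transformer -> vae only (no image_encoder stage)
pipe.enable_model_cpu_offload()
```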
--- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 25 ++----------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index f0b596825f0d..ad6ab3788442 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -21,7 +21,7 @@ import ftfy import PIL import torch -from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel +from transformers import AutoTokenizer, UMT5EncoderModel from diffusers.image_processor import PipelineImageInput from diffusers.utils.torch_utils import randn_tensor @@ -142,11 +142,6 @@ class SkyReelsV2DiffusionForcingImageToVideoPipeline(DiffusionPipeline, WanLoraL text_encoder ([`UMT5EncoderModel`]): [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. - image_encoder ([`CLIPVisionModel`]): - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically - the - [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) - variant. transformer ([`SkyReelsV2Transformer3DModel`]): Conditional Transformer to denoise the encoded image latents. scheduler ([`FlowMatchUniPCMultistepScheduler`]): @@ -155,15 +150,13 @@ class SkyReelsV2DiffusionForcingImageToVideoPipeline(DiffusionPipeline, WanLoraL Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
""" - model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae" + model_cpu_offload_seq = "text_encoder->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, - image_encoder: CLIPVisionModel, - image_processor: CLIPImageProcessor, transformer: SkyReelsV2Transformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMatchUniPCMultistepScheduler, @@ -174,16 +167,13 @@ def __init__( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, - image_encoder=image_encoder, transformer=transformer, scheduler=scheduler, - image_processor=image_processor, ) self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - self.image_processor = image_processor # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( @@ -227,17 +217,6 @@ def _get_t5_prompt_embeds( return prompt_embeds - # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image - def encode_image( - self, - image: PipelineImageInput, - device: Optional[torch.device] = None, - ): - device = device or self._execution_device - image = self.image_processor(images=image, return_tensors="pt").to(device) - image_embeds = self.image_encoder(**image, output_hidden_states=True) - return image_embeds.hidden_states[-2] - # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt def encode_prompt( self, From 7f3589479c105cbb224d1708dd882fa3299ba3be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 29 May 2025 16:21:12 +0300 Subject: [PATCH 151/264] Refactor `SkyReelsV2DiffusionForcingImageToVideoPipeline` to enhance latent preparation logic and condition handling. Update image input type to `Optional`, streamline video condition processing, and improve handling of `last_image` during latent generation. 
--- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 94 ++++++++++--------- 1 file changed, 51 insertions(+), 43 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index ad6ab3788442..75f0d96babde 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -364,7 +364,7 @@ def check_inputs( def prepare_latents( self, - image: PipelineImageInput, + image: Optional[PipelineImageInput], batch_size: int, num_channels_latents: int = 16, height: int = 480, @@ -375,6 +375,7 @@ def prepare_latents( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, last_image: Optional[torch.Tensor] = None, + condition: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 latent_height = height // self.vae_scale_factor_spatial @@ -392,37 +393,38 @@ def prepare_latents( else: latents = latents.to(device=device, dtype=dtype) - image = image.unsqueeze(2) - if last_image is not None: - last_image = last_image.unsqueeze(2) - video_condition = torch.cat([image, last_image], dim=2) - else: - video_condition = image + if image is not None: + image = image.unsqueeze(2) + if last_image is not None: + last_image = last_image.unsqueeze(2) + video_condition = torch.cat([image, last_image], dim=2) + else: + video_condition = image - video_condition = video_condition.to(device=device, dtype=self.vae.dtype) + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype - ) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) - if isinstance(generator, list): - latent_condition = [ - retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator - ] - latent_condition = torch.cat(latent_condition) - else: - latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") - latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + ] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) - latent_condition = latent_condition.to(dtype) - latent_condition = (latent_condition - latents_mean) * latents_std + latent_condition = latent_condition.to(dtype) + condition = (latent_condition - latents_mean) * latents_std - return latents, latent_condition + return latents, condition # Copied from 
diffusers.pipelines.skyreels_v2.pipeline_skyreels_v2_diffusion_forcing.SkyReelsV2DiffusionForcingPipeline.generate_timestep_matrix def generate_timestep_matrix( @@ -718,7 +720,6 @@ def __call__( device=device, ) - # Encode image embedding transformer_dtype = self.transformer.dtype prompt_embeds = prompt_embeds.to(transformer_dtype) if negative_prompt_embeds is not None: @@ -762,7 +763,9 @@ def __call__( latents, last_image, ) - prefix_video_latents_length = condition.shape[2] // 2 + prefix_video_latents_length = condition.shape[2] + if last_image is not None: + prefix_video_latents_length = prefix_video_latents_length // 2 latents[:, :, :prefix_video_latents_length, :, :] = condition[:, :, :prefix_video_latents_length, :, :].to( transformer_dtype @@ -896,7 +899,7 @@ def __call__( f"However, it may slightly affect the continuity of the generated video at the truncation boundary." ) prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] - prefix_video_latents_length = prefix_video_latents.shape[2] + condition = prefix_video_latents finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames left_frame_num = num_latent_frames - finished_frame_num @@ -906,16 +909,17 @@ def __call__( # 5. Prepare latent variables num_channels_latents = self.vae.config.z_dim - image = self.video_processor.preprocess(image, height=height, width=width).to( - device, dtype=torch.float32 - ) - if last_image is not None: - last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( + if i == 0: + image = self.video_processor.preprocess(image, height=height, width=width).to( device, dtype=torch.float32 ) + if last_image is not None: + last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( + device, dtype=torch.float32 + ) latents, condition = self.prepare_latents( - image, + image if i == 0 else None, batch_size * num_videos_per_prompt, num_channels_latents, height, @@ -924,18 +928,22 @@ def __call__( torch.float32, device, generator, - None if i > 0 else latents, - last_image, + latents if i == 0 else None, + last_image if i == 0 else None, + condition if i != 0 else None, ) - prefix_video_latents_length = condition.shape[2] // 2 + prefix_video_latents_length = condition.shape[2] + if i == 0 and last_image is not None: + end_video_latents = condition[:, :, prefix_video_latents_length:, :, :] + prefix_video_latents_length = prefix_video_latents_length // 2 latents[:, :, :prefix_video_latents_length, :, :] = condition[ :, :, :prefix_video_latents_length, :, : ].to(transformer_dtype) - if last_image is not None and i == n_iter - 1: + if last_image is not None and i + 1 == n_iter: latents = torch.cat( - [latents, condition[:, :, prefix_video_latents_length:, :, :].to(transformer_dtype)], dim=2 + [latents, end_video_latents.to(transformer_dtype)], dim=2 ) base_num_frames_iter += prefix_video_latents_length @@ -954,7 +962,7 @@ def __call__( causal_block_size, ) - if last_image is not None and i == n_iter - 1: + if last_image is not None and i + 1 == n_iter: step_matrix[:, -prefix_video_latents_length:] = 0 step_update_mask[:, -prefix_video_latents_length:] = False @@ -1039,7 +1047,7 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - if last_image is not None and i == n_iter - 1: + if last_image is not None and i + 1 == n_iter: latents = latents[:, :, :-prefix_video_latents_length] if not output_type == "latent": From a97d4d80226ba016a3157804ad88ae97c23654e7 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 29 May 2025 16:21:25 +0300 Subject: [PATCH 152/264] Enhance `SkyReelsV2DiffusionForcingPipeline` by refining latent preparation for long video generation. Introduce new parameters for video handling, overlap history, and causal block size. Update logic to accommodate both short and long video scenarios, ensuring compatibility and improved processing. --- .../pipeline_skyreels_v2_diffusion_forcing.py | 158 ++++++++++-------- 1 file changed, 84 insertions(+), 74 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index ad75922ca1cd..e13c421a537d 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -343,23 +343,63 @@ def check_inputs( "Please specify a value for `overlap_history`. Recommended values are 17 or 37." ) - # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents def prepare_latents( self, batch_size: int, num_channels_latents: int = 16, height: int = 480, width: int = 832, - num_frames: int = 81, + num_frames: Optional[int] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, + transformer_dtype: Optional[torch.dtype] = None, + base_num_frames: Optional[int] = None, + video: Optional[torch.Tensor] = None, + overlap_history: Optional[int] = None, + causal_block_size: Optional[int] = None, + overlap_history_frames: Optional[int] = None, + long_video_iter: Optional[int] = None, ) -> torch.Tensor: if latents is not None: return latents.to(device=device, dtype=dtype) - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + if video is not None: # long video generation at the iterations other than the first one + prefix_video_latents = retrieve_latents( + self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax" + ) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(device, self.vae.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, self.vae.dtype + ) + prefix_video_latents = (prefix_video_latents - latents_mean) * latents_std + + if prefix_video_latents.shape[2] % causal_block_size != 0: + truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size + logger.warning( + f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. " + f"This truncation ensures compatibility with the causal block size, which is required for proper processing. " + f"However, it may slightly affect the continuity of the generated video at the truncation boundary." 
+ ) + prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] + prefix_video_latents_length = prefix_video_latents.shape[2] + + finished_frame_num = long_video_iter * (base_num_frames - overlap_history_frames) + overlap_history_frames + left_frame_num = num_latent_frames - finished_frame_num + num_frames = min(left_frame_num + overlap_history_frames, base_num_frames) + elif base_num_frames is not None: # long video generation at the first iteration + num_latent_frames = base_num_frames + prefix_video_latents_length = 0 + else: # short video generation + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + prefix_video_latents_length = 0 + shape = ( batch_size, num_channels_latents, @@ -374,7 +414,11 @@ def prepare_latents( ) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - return latents + + if prefix_video_latents_length > 0: + latents[:, :, :prefix_video_latents_length, :, :] = prefix_video_latents.to(transformer_dtype) + + return latents, num_latent_frames, prefix_video_latents_length def generate_timestep_matrix( self, @@ -665,14 +709,6 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) timesteps = self.scheduler.timesteps - prefix_video_latents_length = 0 - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - base_num_frames = ( - (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 - if base_num_frames is not None - else num_latent_frames - ) - if causal_block_size is None: causal_block_size = self.transformer.config.num_frame_per_block fps_embeds = [fps] * prompt_embeds.shape[0] @@ -680,19 +716,9 @@ def __call__( if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: # Short video generation - # 4. Prepare sample schedulers and timestep matrix - sample_schedulers = [] - for _ in range(num_latent_frames): - sample_scheduler = deepcopy(self.scheduler) - sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) - sample_schedulers.append(sample_scheduler) - step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - num_latent_frames, timesteps, base_num_frames, ar_step, prefix_video_latents_length, causal_block_size - ) - # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( + latents, num_latent_frames, prefix_video_latents_length = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, height, @@ -704,6 +730,16 @@ def __call__( latents, ) + # 4. Prepare sample schedulers and timestep matrix + sample_schedulers = [] + for _ in range(num_latent_frames): + sample_scheduler = deepcopy(self.scheduler) + sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + sample_schedulers.append(sample_scheduler) + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + num_latent_frames, timesteps, num_latent_frames, ar_step, prefix_video_latents_length, causal_block_size + ) + # 6. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(step_matrix) @@ -782,75 +818,49 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - else: # Long video generation overlap_history_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1 + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + base_num_frames = (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 if base_num_frames is not None else num_latent_frames n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 video = None - prefix_video_latents = None - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(device, self.vae.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - device, self.vae.dtype - ) - for i in range(n_iter): - if video is not None: - prefix_video_latents = retrieve_latents( - self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax" - ) - prefix_video_latents = (prefix_video_latents - latents_mean) * latents_std - - if prefix_video_latents.shape[2] % causal_block_size != 0: - truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size - logger.warning( - f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. " - f"This truncation ensures compatibility with the causal block size, which is required for proper processing. " - f"However, it may slightly affect the continuity of the generated video at the truncation boundary." - ) - prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] - prefix_video_latents_length = prefix_video_latents.shape[2] - - finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames - left_frame_num = num_latent_frames - finished_frame_num - base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) - else: - base_num_frames_iter = base_num_frames + for long_video_iter in range(n_iter): + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents, num_latent_frames, prefix_video_latents_length = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents if long_video_iter == 0 else None, + transformer_dtype=transformer_dtype, + video=video, + overlap_history=overlap_history, + base_num_frames=base_num_frames, + causal_block_size=causal_block_size, + overlap_history_frames=overlap_history_frames, + long_video_iter=long_video_iter, + ) # 4. Prepare sample schedulers and timestep matrix sample_schedulers = [] - for _ in range(base_num_frames_iter): + for _ in range(num_latent_frames): sample_scheduler = deepcopy(self.scheduler) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - base_num_frames_iter, + num_latent_frames, timesteps, - base_num_frames_iter, + num_latent_frames, ar_step, prefix_video_latents_length, causal_block_size, ) - # 5. 
Prepare latent variables - num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( - batch_size * num_videos_per_prompt, - num_channels_latents, - height, - width, - (base_num_frames_iter - 1) * self.vae_scale_factor_temporal + 1, - torch.float32, - device, - generator, - None if i > 0 else latents, - ) - if prefix_video_latents is not None: - latents[:, :, :prefix_video_latents_length, :, :] = prefix_video_latents.to(transformer_dtype) - # 6. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(step_matrix) From e2bfbfac7412b276bdf9b9ad23858df296ea0f08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 29 May 2025 18:00:19 +0300 Subject: [PATCH 153/264] refactor --- .../pipeline_skyreels_v2_diffusion_forcing.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index e13c421a537d..3f205f318e10 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -349,12 +349,10 @@ def prepare_latents( num_channels_latents: int = 16, height: int = 480, width: int = 832, - num_frames: Optional[int] = None, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, - transformer_dtype: Optional[torch.dtype] = None, base_num_frames: Optional[int] = None, video: Optional[torch.Tensor] = None, overlap_history: Optional[int] = None, @@ -415,10 +413,7 @@ def prepare_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - if prefix_video_latents_length > 0: - latents[:, :, :prefix_video_latents_length, :, :] = prefix_video_latents.to(transformer_dtype) - - return latents, num_latent_frames, prefix_video_latents_length + return latents, num_latent_frames, prefix_video_latents, prefix_video_latents_length def generate_timestep_matrix( self, @@ -828,7 +823,7 @@ def __call__( for long_video_iter in range(n_iter): # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels - latents, num_latent_frames, prefix_video_latents_length = self.prepare_latents( + latents, num_latent_frames, prefix_video_latents, prefix_video_latents_length = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, height, @@ -846,6 +841,9 @@ def __call__( long_video_iter=long_video_iter, ) + if prefix_video_latents_length > 0: + latents[:, :, :prefix_video_latents_length, :, :] = prefix_video_latents.to(transformer_dtype) + # 4. 
Prepare sample schedulers and timestep matrix sample_schedulers = [] for _ in range(num_latent_frames): From 594082ee684e014bb1b90a06dedec9788e6a84d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 29 May 2025 18:07:26 +0300 Subject: [PATCH 154/264] fix num_frames --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 3f205f318e10..250f0ef1d8c1 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -349,6 +349,7 @@ def prepare_latents( num_channels_latents: int = 16, height: int = 480, width: int = 832, + num_frames: int = 97, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -713,7 +714,7 @@ def __call__( # Short video generation # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels - latents, num_latent_frames, prefix_video_latents_length = self.prepare_latents( + latents, num_latent_frames, _, prefix_video_latents_length = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, height, @@ -828,11 +829,11 @@ def __call__( num_channels_latents, height, width, + num_frames, torch.float32, device, generator, latents if long_video_iter == 0 else None, - transformer_dtype=transformer_dtype, video=video, overlap_history=overlap_history, base_num_frames=base_num_frames, From c4c9c0a673b1ad1a33e07d8bc15d4d3b0f69993a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 29 May 2025 18:12:52 +0300 Subject: [PATCH 155/264] fix prefix_video_latents --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 250f0ef1d8c1..7859ed46a66a 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -392,12 +392,14 @@ def prepare_latents( finished_frame_num = long_video_iter * (base_num_frames - overlap_history_frames) + overlap_history_frames left_frame_num = num_latent_frames - finished_frame_num num_frames = min(left_frame_num + overlap_history_frames, base_num_frames) - elif base_num_frames is not None: # long video generation at the first iteration + elif base_num_frames is not None: # long video generation at the first iteration num_latent_frames = base_num_frames prefix_video_latents_length = 0 + prefix_video_latents = None else: # short video generation num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 prefix_video_latents_length = 0 + prefix_video_latents = None shape = ( batch_size, From 79960dee985ae188dff3e6654c45ce4c5d0c7d7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 29 May 2025 18:32:20 +0300 Subject: [PATCH 156/264] up --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py 
b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 7859ed46a66a..f2b5c1afb9f6 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -364,6 +364,9 @@ def prepare_latents( if latents is not None: return latents.to(device=device, dtype=dtype) + prefix_video_latents = None + prefix_video_latents_length = 0 + if video is not None: # long video generation at the iterations other than the first one prefix_video_latents = retrieve_latents( self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax" @@ -394,12 +397,8 @@ def prepare_latents( num_frames = min(left_frame_num + overlap_history_frames, base_num_frames) elif base_num_frames is not None: # long video generation at the first iteration num_latent_frames = base_num_frames - prefix_video_latents_length = 0 - prefix_video_latents = None else: # short video generation num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - prefix_video_latents_length = 0 - prefix_video_latents = None shape = ( batch_size, From f6cd857a3ff2bc0dd81ae5f05e542b31b027e5a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 29 May 2025 18:32:31 +0300 Subject: [PATCH 157/264] refactor --- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 110 ++++++++++-------- 1 file changed, 61 insertions(+), 49 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 75f0d96babde..766c6c0afbde 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -369,18 +369,57 @@ def prepare_latents( num_channels_latents: int = 16, height: int = 480, width: int = 832, - num_frames: int = 81, + num_frames: int = 97, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, - last_image: Optional[torch.Tensor] = None, - condition: Optional[torch.Tensor] = None, + base_num_frames: Optional[int] = None, + video: Optional[torch.Tensor] = None, + overlap_history: Optional[int] = None, + causal_block_size: Optional[int] = None, + overlap_history_frames: Optional[int] = None, + long_video_iter: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 latent_height = height // self.vae_scale_factor_spatial latent_width = width // self.vae_scale_factor_spatial + prefix_video_latents_length = 0 + + if video is not None: # long video generation at the iterations other than the first one + condition = retrieve_latents( + self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax" + ) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(device, self.vae.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, self.vae.dtype + ) + condition = (condition - latents_mean) * latents_std + + if condition.shape[2] % causal_block_size != 0: + truncate_len_latents = condition.shape[2] % causal_block_size + logger.warning( + f"The length of prefix video latents is truncated by 
{truncate_len_latents} frames for the causal block size alignment. " + f"This truncation ensures compatibility with the causal block size, which is required for proper processing. " + f"However, it may slightly affect the continuity of the generated video at the truncation boundary." + ) + condition = condition[:, :, :-truncate_len_latents] + prefix_video_latents_length = condition.shape[2] + + finished_frame_num = long_video_iter * (base_num_frames - overlap_history_frames) + overlap_history_frames + left_frame_num = num_latent_frames - finished_frame_num + num_frames = min(left_frame_num + overlap_history_frames, base_num_frames) + elif base_num_frames is not None: # long video generation at the first iteration + num_latent_frames = base_num_frames + else: # short video generation + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -423,8 +462,9 @@ def prepare_latents( latent_condition = latent_condition.to(dtype) condition = (latent_condition - latents_mean) * latents_std + prefix_video_latents_length = condition.shape[2] - return latents, condition + return latents, num_latent_frames, condition, prefix_video_latents_length # Copied from diffusers.pipelines.skyreels_v2.pipeline_skyreels_v2_diffusion_forcing.SkyReelsV2DiffusionForcingPipeline.generate_timestep_matrix def generate_timestep_matrix( @@ -729,13 +769,6 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) timesteps = self.scheduler.timesteps - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - base_num_frames = ( - (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 - if base_num_frames is not None - else num_latent_frames - ) - if causal_block_size is None: causal_block_size = self.transformer.config.num_frame_per_block fps_embeds = [fps] * prompt_embeds.shape[0] @@ -750,7 +783,7 @@ def __call__( last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( device, dtype=torch.float32 ) - latents, condition = self.prepare_latents( + latents, num_latent_frames, condition, prefix_video_latents_length = self.prepare_latents( image, batch_size * num_videos_per_prompt, num_channels_latents, @@ -763,7 +796,7 @@ def __call__( latents, last_image, ) - prefix_video_latents_length = condition.shape[2] + if last_image is not None: prefix_video_latents_length = prefix_video_latents_length // 2 @@ -875,7 +908,6 @@ def __call__( overlap_history_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1 n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 video = None - prefix_video_latents = None latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) @@ -884,29 +916,7 @@ def __call__( latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( device, self.vae.dtype ) - for i in range(n_iter): - if video is not None: - prefix_video_latents = retrieve_latents( - self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax" - ) - prefix_video_latents = (prefix_video_latents - latents_mean) * latents_std - - if prefix_video_latents.shape[2] % causal_block_size != 0: - truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size - logger.warning( - f"The length of prefix 
video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. " - f"This truncation ensures compatibility with the causal block size, which is required for proper processing. " - f"However, it may slightly affect the continuity of the generated video at the truncation boundary." - ) - prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] - condition = prefix_video_latents - - finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames - left_frame_num = num_latent_frames - finished_frame_num - base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) - else: - base_num_frames_iter = base_num_frames - + for long_video_iter in range(n_iter): # 5. Prepare latent variables num_channels_latents = self.vae.config.z_dim if i == 0: @@ -918,30 +928,32 @@ def __call__( device, dtype=torch.float32 ) - latents, condition = self.prepare_latents( - image if i == 0 else None, + latents, num_latent_frames, condition, prefix_video_latents_length = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, height, width, - (base_num_frames_iter - 1) * self.vae_scale_factor_temporal + 1, + num_frames, torch.float32, device, generator, - latents if i == 0 else None, - last_image if i == 0 else None, - condition if i != 0 else None, + latents if long_video_iter == 0 else None, + video=video, + overlap_history=overlap_history, + base_num_frames=base_num_frames, + causal_block_size=causal_block_size, + overlap_history_frames=overlap_history_frames, + long_video_iter=long_video_iter, ) - prefix_video_latents_length = condition.shape[2] - if i == 0 and last_image is not None: + if long_video_iter == 0 and last_image is not None: + prefix_video_latents_length = condition.shape[2] // 2 end_video_latents = condition[:, :, prefix_video_latents_length:, :, :] - prefix_video_latents_length = prefix_video_latents_length // 2 latents[:, :, :prefix_video_latents_length, :, :] = condition[ :, :, :prefix_video_latents_length, :, : ].to(transformer_dtype) - if last_image is not None and i + 1 == n_iter: + if last_image is not None and long_video_iter + 1 == n_iter: latents = torch.cat( [latents, end_video_latents.to(transformer_dtype)], dim=2 ) @@ -962,7 +974,7 @@ def __call__( causal_block_size, ) - if last_image is not None and i + 1 == n_iter: + if last_image is not None and long_video_iter + 1 == n_iter: step_matrix[:, -prefix_video_latents_length:] = 0 step_update_mask[:, -prefix_video_latents_length:] = False @@ -1047,7 +1059,7 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - if last_image is not None and i + 1 == n_iter: + if last_image is not None and long_video_iter + 1 == n_iter: latents = latents[:, :, :-prefix_video_latents_length] if not output_type == "latent": From 3ce9b05c106e3b6ca4a79ae27e9861ab032d37dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 29 May 2025 18:32:42 +0300 Subject: [PATCH 158/264] Fix typo in scheduler method call within `SkyReelsV2DiffusionForcingVideoToVideoPipeline` to ensure proper noise scaling during latent generation. 
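Note: the hunk that follows repairs a `self.scheduelr` -> `self.scheduler` typo inside the branch that chooses between the two noise-mixing APIs. A minimal, self-contained sketch of that dispatch (the helper name and standalone form are illustrative, not part of the pipeline), assuming a diffusers-style scheduler object:

```py
import torch


def noise_initial_latents(
    scheduler, init_latents: torch.Tensor, noise: torch.Tensor, timestep: torch.Tensor
) -> torch.Tensor:
    """Blend fresh noise into encoded video latents at the given timestep."""
    # DDPM-style schedulers expose `add_noise(original_samples, noise, timesteps)`;
    # flow-matching schedulers expose `scale_noise(sample, timestep, noise)` instead.
    # Note the different argument order between the two calls.
    if hasattr(scheduler, "add_noise"):
        return scheduler.add_noise(init_latents, noise, timestep)
    return scheduler.scale_noise(init_latents, timestep, noise)
```

Because the attribute is only resolved when the `else` branch runs, the misspelled name could only fail at runtime with schedulers that lack `add_noise`.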
--- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index f2b0b32a1572..80aebc4f87eb 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -466,7 +466,7 @@ def prepare_latents( if hasattr(self.scheduler, "add_noise"): latents = self.scheduler.add_noise(init_latents, noise, timestep) else: - latents = self.scheduelr.scale_noise(init_latents, timestep, noise) + latents = self.scheduler.scale_noise(init_latents, timestep, noise) else: latents = latents.to(device) From f1b8508d9336b0c79c9494efb9b5a183eacc7608 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 29 May 2025 19:57:50 +0300 Subject: [PATCH 159/264] up --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index f2b5c1afb9f6..209aa0c365c3 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -727,6 +727,8 @@ def __call__( latents, ) + base_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 if base_num_frames is not None else num_latent_frames + # 4. Prepare sample schedulers and timestep matrix sample_schedulers = [] for _ in range(num_latent_frames): @@ -734,7 +736,7 @@ def __call__( sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - num_latent_frames, timesteps, num_latent_frames, ar_step, prefix_video_latents_length, causal_block_size + num_latent_frames, timesteps, base_num_frames, ar_step, prefix_video_latents_length, causal_block_size ) # 6. Denoising loop From aad0feba40a369f1740c650f36101c77b258a91f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 29 May 2025 19:58:34 +0300 Subject: [PATCH 160/264] Enhance `SkyReelsV2DiffusionForcingImageToVideoPipeline` by adding support for `last_image` parameter and refining latent frame calculations. Update preprocessing logic. 
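Note: the frame bookkeeping refined here reduces to two expressions that appear verbatim in the hunks: the pixel-to-latent frame conversion and the number of sliding windows needed for long video generation. A small sketch, assuming a temporal VAE scale factor of 4 for illustration (the pipeline reads the real value from the VAE config) and illustrative frame counts:

```py
def latent_frame_count(num_frames: int, vae_scale_factor_temporal: int = 4) -> int:
    # e.g. 97 pixel frames -> 25 latent frames when the VAE compresses time by 4.
    return (num_frames - 1) // vae_scale_factor_temporal + 1


def long_video_window_count(num_latent_frames: int, base_num_frames: int, overlap_history_frames: int) -> int:
    # Mirrors the `n_iter` expression in the diffs: the first window covers
    # `base_num_frames` latent frames, and every later window re-uses
    # `overlap_history_frames` of them while contributing the rest as new frames.
    return 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1


# Illustrative numbers: num_frames=257, base_num_frames=97, overlap_history=17 (all pixel frames).
assert latent_frame_count(257) == 65
assert latent_frame_count(97) == 25
assert latent_frame_count(17) == 5
assert long_video_window_count(65, 25, 5) == 3  # 25 + 2 * (25 - 5) latent frames
```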
--- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 766c6c0afbde..e838dcbc1af9 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -374,6 +374,7 @@ def prepare_latents( device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, + last_image: Optional[torch.Tensor] = None, base_num_frames: Optional[int] = None, video: Optional[torch.Tensor] = None, overlap_history: Optional[int] = None, @@ -803,6 +804,7 @@ def __call__( latents[:, :, :prefix_video_latents_length, :, :] = condition[:, :, :prefix_video_latents_length, :, :].to( transformer_dtype ) + base_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 if base_num_frames is not None else num_latent_frames if last_image is not None: latents = torch.cat( [latents, condition[:, :, prefix_video_latents_length:, :, :].to(transformer_dtype)], dim=2 @@ -905,6 +907,8 @@ def __call__( else: # Long video generation + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + base_num_frames = (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 if base_num_frames is not None else num_latent_frames overlap_history_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1 n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 video = None @@ -916,17 +920,16 @@ def __call__( latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( device, self.vae.dtype ) + image = self.video_processor.preprocess(image, height=height, width=width).to( + device, dtype=torch.float32 + ) + if last_image is not None: + last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( + device, dtype=torch.float32 + ) for long_video_iter in range(n_iter): # 5. 
Prepare latent variables num_channels_latents = self.vae.config.z_dim - if i == 0: - image = self.video_processor.preprocess(image, height=height, width=width).to( - device, dtype=torch.float32 - ) - if last_image is not None: - last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( - device, dtype=torch.float32 - ) latents, num_latent_frames, condition, prefix_video_latents_length = self.prepare_latents( batch_size * num_videos_per_prompt, From 09586472d01d4f1e55180bec640d035bd2bb99a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 30 May 2025 09:01:51 +0300 Subject: [PATCH 161/264] add statistics --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 209aa0c365c3..283673126b51 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -946,6 +946,14 @@ def __call__( if not output_type == "latent": latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) latents = latents / latents_std + latents_mean videos = self.vae.decode(latents, return_dict=False)[0] if video is None: From fcfc7f492ef6ec40cd0a48f9dbd172731fa052f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 30 May 2025 09:03:06 +0300 Subject: [PATCH 162/264] Refine latent frame handling in `SkyReelsV2DiffusionForcingImageToVideoPipeline` by correcting variable names and reintroducing latent mean and standard deviation calculations. Update logic for frame preparation and sampling to ensure accurate video generation. 
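Note: the latent statistics reintroduced below form a simple normalize/denormalize pair; `latents_std` is stored as a reciprocal in these hunks, so the forward direction multiplies and the inverse divides. A self-contained sketch, with `vae_config` standing in for `self.vae.config` (assumed to expose `latents_mean`, `latents_std`, and `z_dim` as `AutoencoderKLWan` does):

```py
import torch


def normalize_wan_latents(latents: torch.Tensor, vae_config) -> torch.Tensor:
    # Applied right after `vae.encode` so the transformer works in a normalized space.
    latents_mean = (
        torch.tensor(vae_config.latents_mean)
        .view(1, vae_config.z_dim, 1, 1, 1)
        .to(latents.device, latents.dtype)
    )
    latents_std = 1.0 / torch.tensor(vae_config.latents_std).view(1, vae_config.z_dim, 1, 1, 1).to(
        latents.device, latents.dtype
    )
    return (latents - latents_mean) * latents_std


def denormalize_wan_latents(latents: torch.Tensor, vae_config) -> torch.Tensor:
    # Inverse transform, applied right before `vae.decode`.
    latents_mean = (
        torch.tensor(vae_config.latents_mean)
        .view(1, vae_config.z_dim, 1, 1, 1)
        .to(latents.device, latents.dtype)
    )
    latents_std = 1.0 / torch.tensor(vae_config.latents_std).view(1, vae_config.z_dim, 1, 1, 1).to(
        latents.device, latents.dtype
    )
    return latents / latents_std + latents_mean
```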
--- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index e838dcbc1af9..6ddb91c3bb92 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -415,7 +415,7 @@ def prepare_latents( finished_frame_num = long_video_iter * (base_num_frames - overlap_history_frames) + overlap_history_frames left_frame_num = num_latent_frames - finished_frame_num - num_frames = min(left_frame_num + overlap_history_frames, base_num_frames) + num_latent_frames = min(left_frame_num + overlap_history_frames, base_num_frames) elif base_num_frames is not None: # long video generation at the first iteration num_latent_frames = base_num_frames else: # short video generation @@ -912,14 +912,6 @@ def __call__( overlap_history_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1 n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 video = None - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(device, self.vae.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - device, self.vae.dtype - ) image = self.video_processor.preprocess(image, height=height, width=width).to( device, dtype=torch.float32 ) @@ -932,6 +924,7 @@ def __call__( num_channels_latents = self.vae.config.z_dim latents, num_latent_frames, condition, prefix_video_latents_length = self.prepare_latents( + image if long_video_iter == 0 else None, batch_size * num_videos_per_prompt, num_channels_latents, height, @@ -941,9 +934,10 @@ def __call__( device, generator, latents if long_video_iter == 0 else None, + last_image, + base_num_frames=base_num_frames, video=video, overlap_history=overlap_history, - base_num_frames=base_num_frames, causal_block_size=causal_block_size, overlap_history_frames=overlap_history_frames, long_video_iter=long_video_iter, @@ -960,18 +954,18 @@ def __call__( latents = torch.cat( [latents, end_video_latents.to(transformer_dtype)], dim=2 ) - base_num_frames_iter += prefix_video_latents_length + num_latent_frames += prefix_video_latents_length # 4. 
Prepare sample schedulers and timestep matrix sample_schedulers = [] - for _ in range(base_num_frames_iter): + for _ in range(num_latent_frames): sample_scheduler = deepcopy(self.scheduler) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - base_num_frames_iter, + num_latent_frames, timesteps, - base_num_frames_iter, + num_latent_frames, ar_step, prefix_video_latents_length, causal_block_size, @@ -1067,6 +1061,14 @@ def __call__( if not output_type == "latent": latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) latents = latents / latents_std + latents_mean videos = self.vae.decode(latents, return_dict=False)[0] if video is None: From b197ffb1eb419a679ff0deef0558f98a134a6066 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 30 May 2025 09:33:57 +0300 Subject: [PATCH 163/264] up --- .../pipeline_skyreels_v2_diffusion_forcing.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 283673126b51..64b43112f2b5 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -364,6 +364,10 @@ def prepare_latents( if latents is not None: return latents.to(device=device, dtype=dtype) + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + prefix_video_latents = None prefix_video_latents_length = 0 @@ -394,7 +398,7 @@ def prepare_latents( finished_frame_num = long_video_iter * (base_num_frames - overlap_history_frames) + overlap_history_frames left_frame_num = num_latent_frames - finished_frame_num - num_frames = min(left_frame_num + overlap_history_frames, base_num_frames) + num_latent_frames = min(left_frame_num + overlap_history_frames, base_num_frames) elif base_num_frames is not None: # long video generation at the first iteration num_latent_frames = base_num_frames else: # short video generation @@ -404,8 +408,8 @@ def prepare_latents( batch_size, num_channels_latents, num_latent_frames, - int(height) // self.vae_scale_factor_spatial, - int(width) // self.vae_scale_factor_spatial, + latent_height, + latent_width, ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( From 54f1aa5b3a316ccff34f3b11b122a2a7bff9c9b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 30 May 2025 12:23:59 +0300 Subject: [PATCH 164/264] refactor --- ...eline_skyreels_v2_diffusion_forcing_v2v.py | 360 +++++++----------- 1 file changed, 128 insertions(+), 232 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index 80aebc4f87eb..e27af570ad05 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ 
b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -21,7 +21,6 @@ import ftfy import torch -from PIL import Image from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback @@ -58,7 +57,6 @@ ... AutoencoderKLWan, ... ) >>> from diffusers.utils import export_to_video - >>> from PIL import Image >>> # Load the pipeline >>> vae = AutoencoderKLWan.from_pretrained( @@ -71,21 +69,19 @@ ... vae=vae, ... torch_dtype=torch.bfloat16, ... ) - >>> shift = 8.0 + >>> shift = 8.0 # 8.0 for T2V, 3.0 for I2V >>> pipe.scheduler = FlowMatchUniPCMultistepScheduler.from_config(pipe.scheduler.config, shift=shift) >>> pipe = pipe.to("cuda") >>> pipe.transformer.set_ar_attention(causal_block_size=5) >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." - >>> video = Image.open("path/to/video.mp4") >>> output = pipe( - ... video=video, ... prompt=prompt, ... num_inference_steps=30, ... height=544, ... width=960, - ... guidance_scale=6.0, + ... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V ... num_frames=97, ... ar_step=5, # Controls asynchronous inference (0 for synchronous mode) ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos @@ -420,57 +416,94 @@ def prepare_latents( num_channels_latents: int = 16, height: int = 480, width: int = 832, + num_frames: int = 97, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, timestep: Optional[torch.Tensor] = None, - ): - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) + base_num_frames: Optional[int] = None, + overlap_history: Optional[int] = None, + causal_block_size: Optional[int] = None, + overlap_history_frames: Optional[int] = None, + long_video_iter: Optional[int] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) num_latent_frames = ( (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1) ) + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + prefix_video_latents = None + prefix_video_latents_length = 0 + + if video is not None: # long video generation at the iterations other than the first one + prefix_video_latents = retrieve_latents( + self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax" + ) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(device, self.vae.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, self.vae.dtype + ) + prefix_video_latents = (prefix_video_latents - latents_mean) * latents_std + + if prefix_video_latents.shape[2] % causal_block_size != 0: + truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size + logger.warning( + f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. " + f"This truncation ensures compatibility with the causal block size, which is required for proper processing. " + f"However, it may slightly affect the continuity of the generated video at the truncation boundary." + ) + prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] + prefix_video_latents_length = prefix_video_latents.shape[2] + + finished_frame_num = long_video_iter * (base_num_frames - overlap_history_frames) + overlap_history_frames + left_frame_num = num_latent_frames - finished_frame_num + num_latent_frames = min(left_frame_num + overlap_history_frames, base_num_frames) + elif base_num_frames is not None: # long video generation at the first iteration + num_latent_frames = base_num_frames + shape = ( batch_size, num_channels_latents, num_latent_frames, - height // self.vae_scale_factor_spatial, - width // self.vae_scale_factor_spatial, + latent_height, + latent_width, ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) - if latents is None: - if isinstance(generator, list): - init_latents = [ - retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size) - ] - else: - init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video] + init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video] - init_latents = torch.cat(init_latents, dim=0).to(dtype) + init_latents = torch.cat(init_latents, dim=0).to(dtype) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - device, dtype - ) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, dtype + ) - init_latents = (init_latents - latents_mean) * latents_std + init_latents = (init_latents - latents_mean) * latents_std - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - if hasattr(self.scheduler, "add_noise"): - latents = self.scheduler.add_noise(init_latents, noise, timestep) - else: - latents = self.scheduler.scale_noise(init_latents, timestep, noise) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if hasattr(self.scheduler, "add_noise"): + latents = self.scheduler.add_noise(init_latents, latents, timestep) else: - latents = latents.to(device) + latents = self.scheduler.scale_noise(init_latents, timestep, latents) - return latents + return latents, num_latent_frames, prefix_video_latents, prefix_video_latents_length def generate_timestep_matrix( self, @@ -577,11 +610,10 @@ def attention_kwargs(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - video: Union[Image.Image, List[Image.Image]] = None, - prompt: Union[str, List[str]] = None, + prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, + height: int = 544, + width: int = 960, num_frames: int = 97, num_inference_steps: int = 50, guidance_scale: float = 6.0, @@ -610,8 +642,6 @@ def __call__( The call function to the pipeline for generation. Args: - video (`Image.Image` or `List[Image.Image]`, *optional*): - The video to guide the generation. If not defined, one has to pass `latents`. prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. @@ -619,9 +649,9 @@ def __call__( The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - height (`int`, defaults to `480`): + height (`int`, defaults to `544`): The height of the generated video. - width (`int`, defaults to `832`): + width (`int`, defaults to `960`): The width of the generated video. num_frames (`int`, defaults to `97`): The number of frames in the generated video. @@ -669,7 +699,7 @@ def __call__( max_sequence_length (`int`, *optional*, defaults to `512`): The maximum sequence length of the prompt. 
shift (`float`, *optional*, defaults to `8.0`): - Flow matching scheduler parameter (**5.0 for I2V**, **8.0 for T2V**) + Flow matching scheduler parameter (**3.0 for I2V**, **8.0 for T2V**) overlap_history (`int`, *optional*, defaults to `None`): Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes short video generation mode, and no overlap is applied. 17 and 37 are recommended to set. @@ -708,8 +738,6 @@ def __call__( negative_prompt, height, width, - video, - latents, prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs, @@ -766,40 +794,25 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) timesteps = self.scheduler.timesteps - prefix_video_latents_length = 0 + if causal_block_size is None: + causal_block_size = self.transformer.config.num_frame_per_block + fps_embeds = [fps] * prompt_embeds.shape[0] + fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] + + # Long video generation + overlap_history_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1 num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 base_num_frames = ( (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 if base_num_frames is not None else num_latent_frames ) - - if causal_block_size is None: - causal_block_size = self.transformer.config.num_frame_per_block - fps_embeds = [fps] * prompt_embeds.shape[0] - fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] - - if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: - # Short video generation - # 4. Prepare sample schedulers and timestep matrix - sample_schedulers = [] - for _ in range(num_latent_frames): - sample_scheduler = deepcopy(self.scheduler) - sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) - sample_schedulers.append(sample_scheduler) - step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - num_latent_frames, timesteps, base_num_frames, ar_step, prefix_video_latents_length, causal_block_size - ) - - if latents is None: - video = self.video_processor.preprocess_video(video, height=height, width=width).to( - device, dtype=torch.float32 - ) - + n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 + video = None + for long_video_iter in range(n_iter): # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( - video, + latents, num_latent_frames, prefix_video_latents, prefix_video_latents_length = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, height, @@ -808,7 +821,31 @@ def __call__( torch.float32, device, generator, - latents, + latents if long_video_iter == 0 else None, + video=video, + overlap_history=overlap_history, + base_num_frames=base_num_frames, + causal_block_size=causal_block_size, + overlap_history_frames=overlap_history_frames, + long_video_iter=long_video_iter, + ) + + if prefix_video_latents_length > 0: + latents[:, :, :prefix_video_latents_length, :, :] = prefix_video_latents.to(transformer_dtype) + + # 4. 
Prepare sample schedulers and timestep matrix + sample_schedulers = [] + for _ in range(num_latent_frames): + sample_scheduler = deepcopy(self.scheduler) + sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + sample_schedulers.append(sample_scheduler) + step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( + num_latent_frames, + timesteps, + num_latent_frames, + ar_step, + prefix_video_latents_length, + causal_block_size, ) # 6. Denoising loop @@ -890,165 +927,24 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - else: - # Long video generation - overlap_history_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1 - n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 - video = None - prefix_video_latents = None - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(device, self.vae.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - device, self.vae.dtype - ) - for i in range(n_iter): - if video is not None: - prefix_video_latents = retrieve_latents( - self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax" - ) - prefix_video_latents = (prefix_video_latents - latents_mean) * latents_std - - if prefix_video_latents.shape[2] % causal_block_size != 0: - truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size - logger.warning( - f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. " - f"This truncation ensures compatibility with the causal block size, which is required for proper processing. " - f"However, it may slightly affect the continuity of the generated video at the truncation boundary." - ) - prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] - prefix_video_latents_length = prefix_video_latents.shape[2] - - finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames - left_frame_num = num_latent_frames - finished_frame_num - base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) - else: - base_num_frames_iter = base_num_frames - - # 4. Prepare sample schedulers and timestep matrix - sample_schedulers = [] - for _ in range(base_num_frames_iter): - sample_scheduler = deepcopy(self.scheduler) - sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) - sample_schedulers.append(sample_scheduler) - step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - base_num_frames_iter, - timesteps, - base_num_frames_iter, - ar_step, - prefix_video_latents_length, - causal_block_size, - ) - - # 5. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( - batch_size * num_videos_per_prompt, - num_channels_latents, - height, - width, - (base_num_frames_iter - 1) * self.vae_scale_factor_temporal + 1, - torch.float32, - device, - generator, - None if i > 0 else latents, + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) ) - if prefix_video_latents is not None: - latents[:, :, :prefix_video_latents_length, :, :] = prefix_video_latents.to(transformer_dtype) - - # 6. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - self._num_timesteps = len(step_matrix) - - with self.progress_bar(total=len(step_matrix)) as progress_bar: - for i, t in enumerate(step_matrix): - if self.interrupt: - continue - - self._current_timestep = t - valid_interval_start, valid_interval_end = valid_interval[i] - latent_model_input = ( - latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() - ) - timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() - - if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_length: - noise_factor = 0.001 * addnoise_condition - latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] = ( - latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] - * (1.0 - noise_factor) - + torch.randn_like( - latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] - ) - * noise_factor - ) - timestep[:, valid_interval_start:prefix_video_latents_length] = addnoise_condition - - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - flag_df=True, - fps=fps_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - if self.do_classifier_free_guidance: - noise_uncond = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - flag_df=True, - fps=fps_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - - update_mask_i = step_update_mask[i] - for idx in range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[:, :, idx, :, :] = sample_schedulers[idx].step( - noise_pred[:, :, idx - valid_interval_start, :, :], - t[idx], - latents[:, :, idx, :, :], - return_dict=False, - generator=generator, - )[0] - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop( - "negative_prompt_embeds", negative_prompt_embeds - ) - - # call the callback, if provided - if i == len(step_matrix) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - if not output_type == "latent": - latents = latents.to(self.vae.dtype) - latents = latents / latents_std + latents_mean - videos = self.vae.decode(latents, return_dict=False)[0] - if video is None: - video = videos - else: - video = torch.cat([video, videos[:, :, overlap_history:]], dim=2) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + videos = self.vae.decode(latents, return_dict=False)[0] + if video is None: + video = videos else: - video = latents + video = torch.cat([video, videos[:, :, overlap_history:]], dim=2) + else: + video = latents self._current_timestep = None From 0a457933759f93191221cb379df0f78dc253446a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 
30 May 2025 12:24:12 +0300 Subject: [PATCH 165/264] up --- .../pipeline_skyreels_v2_diffusion_forcing.py | 12 ++++++++-- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 24 ++++++++++--------- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 64b43112f2b5..a2893aaadb1a 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -731,7 +731,11 @@ def __call__( latents, ) - base_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 if base_num_frames is not None else num_latent_frames + base_num_frames = ( + (num_frames - 1) // self.vae_scale_factor_temporal + 1 + if base_num_frames is not None + else num_latent_frames + ) # 4. Prepare sample schedulers and timestep matrix sample_schedulers = [] @@ -825,7 +829,11 @@ def __call__( # Long video generation overlap_history_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1 num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - base_num_frames = (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 if base_num_frames is not None else num_latent_frames + base_num_frames = ( + (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 + if base_num_frames is not None + else num_latent_frames + ) n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 video = None for long_video_iter in range(n_iter): diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 6ddb91c3bb92..398ce6f1efc2 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -389,9 +389,7 @@ def prepare_latents( prefix_video_latents_length = 0 if video is not None: # long video generation at the iterations other than the first one - condition = retrieve_latents( - self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax" - ) + condition = retrieve_latents(self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax") latents_mean = ( torch.tensor(self.vae.config.latents_mean) @@ -804,7 +802,11 @@ def __call__( latents[:, :, :prefix_video_latents_length, :, :] = condition[:, :, :prefix_video_latents_length, :, :].to( transformer_dtype ) - base_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 if base_num_frames is not None else num_latent_frames + base_num_frames = ( + (num_frames - 1) // self.vae_scale_factor_temporal + 1 + if base_num_frames is not None + else num_latent_frames + ) if last_image is not None: latents = torch.cat( [latents, condition[:, :, prefix_video_latents_length:, :, :].to(transformer_dtype)], dim=2 @@ -908,13 +910,15 @@ def __call__( else: # Long video generation num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - base_num_frames = (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 if base_num_frames is not None else num_latent_frames + base_num_frames = ( + (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 + if base_num_frames is not None + else num_latent_frames + ) overlap_history_frames = (overlap_history - 1) // 
self.vae_scale_factor_temporal + 1 n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 video = None - image = self.video_processor.preprocess(image, height=height, width=width).to( - device, dtype=torch.float32 - ) + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) if last_image is not None: last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( device, dtype=torch.float32 @@ -951,9 +955,7 @@ def __call__( :, :, :prefix_video_latents_length, :, : ].to(transformer_dtype) if last_image is not None and long_video_iter + 1 == n_iter: - latents = torch.cat( - [latents, end_video_latents.to(transformer_dtype)], dim=2 - ) + latents = torch.cat([latents, end_video_latents.to(transformer_dtype)], dim=2) num_latent_frames += prefix_video_latents_length # 4. Prepare sample schedulers and timestep matrix From 37649b2e257356173991d3bcb2dabb9381743465 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 30 May 2025 16:16:19 +0300 Subject: [PATCH 166/264] Refactor `SkyReelsV2DiffusionForcingVideoToVideoPipeline` to improve latent handling by enforcing tensor input for video, updating frame preparation logic, and adjusting default frame count. Enhance preprocessing and postprocessing steps for better integration. --- ...eline_skyreels_v2_diffusion_forcing_v2v.py | 116 +++++++----------- 1 file changed, 46 insertions(+), 70 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index e27af570ad05..872a0aac5692 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -18,6 +18,7 @@ import re from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Union +from PIL import Image import ftfy import torch @@ -411,7 +412,7 @@ def check_inputs( def prepare_latents( self, - video: Optional[torch.Tensor] = None, + video: torch.Tensor, batch_size: int = 1, num_channels_latents: int = 16, height: int = 480, @@ -437,39 +438,32 @@ def prepare_latents( latent_height = height // self.vae_scale_factor_spatial latent_width = width // self.vae_scale_factor_spatial - prefix_video_latents = None - prefix_video_latents_length = 0 + prefix_video_latents = [retrieve_latents(self.vae.encode(vid[:, :, -overlap_history:].unsqueeze(0)), sample_mode="argmax") for vid in video] + prefix_video_latents = torch.cat(prefix_video_latents, dim=0).to(dtype) - if video is not None: # long video generation at the iterations other than the first one - prefix_video_latents = retrieve_latents( - self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax" - ) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(device, self.vae.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + device, self.vae.dtype + ) + prefix_video_latents = (prefix_video_latents - latents_mean) * latents_std - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(device, self.vae.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - device, self.vae.dtype + if 
prefix_video_latents.shape[2] % causal_block_size != 0: + truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size + logger.warning( + f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. " + f"This truncation ensures compatibility with the causal block size, which is required for proper processing. " + f"However, it may slightly affect the continuity of the generated video at the truncation boundary." ) - prefix_video_latents = (prefix_video_latents - latents_mean) * latents_std - - if prefix_video_latents.shape[2] % causal_block_size != 0: - truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size - logger.warning( - f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. " - f"This truncation ensures compatibility with the causal block size, which is required for proper processing. " - f"However, it may slightly affect the continuity of the generated video at the truncation boundary." - ) - prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] - prefix_video_latents_length = prefix_video_latents.shape[2] + prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] + prefix_video_latents_length = prefix_video_latents.shape[2] - finished_frame_num = long_video_iter * (base_num_frames - overlap_history_frames) + overlap_history_frames - left_frame_num = num_latent_frames - finished_frame_num - num_latent_frames = min(left_frame_num + overlap_history_frames, base_num_frames) - elif base_num_frames is not None: # long video generation at the first iteration - num_latent_frames = base_num_frames + finished_frame_num = long_video_iter * (base_num_frames - overlap_history_frames) + overlap_history_frames + left_frame_num = num_latent_frames - finished_frame_num + num_latent_frames = min(left_frame_num + overlap_history_frames, base_num_frames) shape = ( batch_size, @@ -484,27 +478,11 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) - init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video] - - init_latents = torch.cat(init_latents, dim=0).to(dtype) - - latents_mean = ( - torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - device, dtype - ) - - init_latents = (init_latents - latents_mean) * latents_std - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - if hasattr(self.scheduler, "add_noise"): - latents = self.scheduler.add_noise(init_latents, latents, timestep) - else: - latents = self.scheduler.scale_noise(init_latents, timestep, latents) return latents, num_latent_frames, prefix_video_latents, prefix_video_latents_length + # Copied from diffusers.pipelines.skyreels_v2.pipeline_skyreels_v2_diffusion_forcing.SkyReelsV2DiffusionForcingPipeline.generate_timestep_matrix def generate_timestep_matrix( self, num_frames: int, @@ -610,11 +588,12 @@ def attention_kwargs(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]], + video: List[Image.Image], + prompt: Union[str, List[str]] = None, negative_prompt: Union[str, List[str]] = None, height: int = 544, width: int = 960, - num_frames: int = 97, + num_frames: int = 120, num_inference_steps: int = 50, guidance_scale: float = 6.0, num_videos_per_prompt: Optional[int] = 1, @@ -642,18 +621,20 @@ def __call__( The call function to the pipeline for generation. Args: + video (`List[Image.Image]`): + The video to guide the video generation. prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. instead. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass + The prompt or prompts not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). height (`int`, defaults to `544`): The height of the generated video. width (`int`, defaults to `960`): The width of the generated video. - num_frames (`int`, defaults to `97`): + num_frames (`int`, defaults to `120`): The number of frames in the generated video. num_inference_steps (`int`, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -732,12 +713,18 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + # 1. Check inputs. 
Raise error if not correct self.check_inputs( prompt, negative_prompt, height, width, + video, + latents, prompt_embeds, negative_prompt_embeds, callback_on_step_end_tensor_inputs, @@ -794,6 +781,11 @@ def __call__( self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) timesteps = self.scheduler.timesteps + if latents is None: + video = self.video_processor.preprocess_video(video, height=height, width=width).to( + device, dtype=torch.float32 + ) + if causal_block_size is None: causal_block_size = self.transformer.config.num_frame_per_block fps_embeds = [fps] * prompt_embeds.shape[0] @@ -808,11 +800,11 @@ def __call__( else num_latent_frames ) n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 - video = None for long_video_iter in range(n_iter): # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents, num_latent_frames, prefix_video_latents, prefix_video_latents_length = self.prepare_latents( + video, batch_size * num_videos_per_prompt, num_channels_latents, height, @@ -822,7 +814,6 @@ def __call__( device, generator, latents if long_video_iter == 0 else None, - video=video, overlap_history=overlap_history, base_num_frames=base_num_frames, causal_block_size=causal_block_size, @@ -939,28 +930,13 @@ def __call__( ).to(latents.device, latents.dtype) latents = latents / latents_std + latents_mean videos = self.vae.decode(latents, return_dict=False)[0] - if video is None: - video = videos - else: - video = torch.cat([video, videos[:, :, overlap_history:]], dim=2) + video = torch.cat([video, videos[:, :, overlap_history:]], dim=2) else: video = latents self._current_timestep = None if not output_type == "latent": - if overlap_history is None: - latents = latents.to(self.vae.dtype) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( - 1, self.vae.config.z_dim, 1, 1, 1 - ).to(latents.device, latents.dtype) - latents = latents / latents_std + latents_mean - video = self.vae.decode(latents, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) else: video = latents From 46c6e722fb8c37d0905c150e2e333ae032bf1581 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 30 May 2025 16:16:54 +0300 Subject: [PATCH 167/264] style --- .../pipeline_skyreels_v2_diffusion_forcing_v2v.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index 872a0aac5692..88d583007ac3 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -18,10 +18,10 @@ import re from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Union -from PIL import Image import ftfy import torch +from PIL import Image from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback @@ -438,7 +438,10 @@ def prepare_latents( latent_height = height // self.vae_scale_factor_spatial latent_width = width // self.vae_scale_factor_spatial - prefix_video_latents = [retrieve_latents(self.vae.encode(vid[:, :, 
-overlap_history:].unsqueeze(0)), sample_mode="argmax") for vid in video] + prefix_video_latents = [ + retrieve_latents(self.vae.encode(vid[:, :, -overlap_history:].unsqueeze(0)), sample_mode="argmax") + for vid in video + ] prefix_video_latents = torch.cat(prefix_video_latents, dim=0).to(dtype) latents_mean = ( From 0edb263f3dee3fc967d54a7e7e4281ee36245003 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 31 May 2025 12:44:02 +0300 Subject: [PATCH 168/264] --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 398ce6f1efc2..30779ddf0bb6 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -796,8 +796,8 @@ def __call__( last_image, ) - if last_image is not None: - prefix_video_latents_length = prefix_video_latents_length // 2 + #if last_image is not None: + # prefix_video_latents_length = prefix_video_latents_length // 2 latents[:, :, :prefix_video_latents_length, :, :] = condition[:, :, :prefix_video_latents_length, :, :].to( transformer_dtype From 4d724df49f84f2d5888f0010b32a6a2e1b0e374d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 31 May 2025 13:47:57 +0300 Subject: [PATCH 169/264] fix vae output indexing --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 30779ddf0bb6..8a3a194465ee 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -799,7 +799,7 @@ def __call__( #if last_image is not None: # prefix_video_latents_length = prefix_video_latents_length // 2 - latents[:, :, :prefix_video_latents_length, :, :] = condition[:, :, :prefix_video_latents_length, :, :].to( + latents[:, :, :prefix_video_latents_length, :, :] = condition[0, :, :, :, :].to( transformer_dtype ) base_num_frames = ( @@ -809,7 +809,7 @@ def __call__( ) if last_image is not None: latents = torch.cat( - [latents, condition[:, :, prefix_video_latents_length:, :, :].to(transformer_dtype)], dim=2 + [latents, condition[1, :, :, :, :].to(transformer_dtype)], dim=2 ) base_num_frames += prefix_video_latents_length num_latent_frames += prefix_video_latents_length From 79dbd0e833716968ac0dab2e670cdb82c328dd8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 31 May 2025 14:00:20 +0300 Subject: [PATCH 170/264] upup --- .../pipeline_skyreels_v2_diffusion_forcing_i2v.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 8a3a194465ee..6855f8ce4169 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -799,7 +799,9 @@ def __call__( #if last_image is not None: # 
prefix_video_latents_length = prefix_video_latents_length // 2 - latents[:, :, :prefix_video_latents_length, :, :] = condition[0, :, :, :, :].to( + channel_dim = condition.shape[1] + + latents[:, :, :prefix_video_latents_length, :, :] = condition[:, :channel_dim // 2, :, :, :].to( transformer_dtype ) base_num_frames = ( @@ -809,7 +811,7 @@ def __call__( ) if last_image is not None: latents = torch.cat( - [latents, condition[1, :, :, :, :].to(transformer_dtype)], dim=2 + [latents, condition[:, channel_dim // 2:, :, :, :].to(transformer_dtype)], dim=2 ) base_num_frames += prefix_video_latents_length num_latent_frames += prefix_video_latents_length From 22c761ebecd6ac9c375b73a47250f6a996934dfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 31 May 2025 14:05:15 +0300 Subject: [PATCH 171/264] --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 6855f8ce4169..d437b7c3d8aa 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -800,6 +800,7 @@ def __call__( # prefix_video_latents_length = prefix_video_latents_length // 2 channel_dim = condition.shape[1] + print(latents.shape, condition.shape, prefix_video_latents_length) latents[:, :, :prefix_video_latents_length, :, :] = condition[:, :channel_dim // 2, :, :, :].to( transformer_dtype From fbf5cc1b51b5f3b9f16b2b9bfed3751565356dab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 31 May 2025 14:09:27 +0300 Subject: [PATCH 172/264] --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index d437b7c3d8aa..7b7e7a9e263a 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -812,7 +812,7 @@ def __call__( ) if last_image is not None: latents = torch.cat( - [latents, condition[:, channel_dim // 2:, :, :, :].to(transformer_dtype)], dim=2 + [latents, condition[:, channel_dim // 2:, :, :, :].to(transformer_dtype)], dim=1 ) base_num_frames += prefix_video_latents_length num_latent_frames += prefix_video_latents_length From 92c4e8c45e394e440c0b5d7d050c3658600bd663 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 31 May 2025 14:11:58 +0300 Subject: [PATCH 173/264] up --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 7b7e7a9e263a..48f7e85962f7 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -802,7 +802,7 @@ def __call__( channel_dim = condition.shape[1] print(latents.shape, condition.shape, prefix_video_latents_length) - latents[:, :, :prefix_video_latents_length, :, :] = 
condition[:, :channel_dim // 2, :, :, :].to( + latents[:, :channel_dim // 2, :prefix_video_latents_length, :, :] = condition[:, :channel_dim // 2, :, :, :].to( transformer_dtype ) base_num_frames = ( From bb9ca6f41fc7036c14a6d777d97b898f18032710 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 31 May 2025 14:59:31 +0300 Subject: [PATCH 174/264] Fix tensor concatenation and repetition logic in `SkyReelsV2DiffusionForcingImageToVideoPipeline` to ensure correct dimensionality for video conditions and latent conditions. --- .../pipeline_skyreels_v2_diffusion_forcing_i2v.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 48f7e85962f7..b756a24d80b4 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -435,7 +435,7 @@ def prepare_latents( image = image.unsqueeze(2) if last_image is not None: last_image = last_image.unsqueeze(2) - video_condition = torch.cat([image, last_image], dim=2) + video_condition = torch.cat([image, last_image], dim=0) else: video_condition = image @@ -457,7 +457,7 @@ def prepare_latents( latent_condition = torch.cat(latent_condition) else: latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") - latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + latent_condition = latent_condition.repeat_interleave(batch_size, dim=0) latent_condition = latent_condition.to(dtype) condition = (latent_condition - latents_mean) * latents_std @@ -799,10 +799,9 @@ def __call__( #if last_image is not None: # prefix_video_latents_length = prefix_video_latents_length // 2 - channel_dim = condition.shape[1] print(latents.shape, condition.shape, prefix_video_latents_length) - latents[:, :channel_dim // 2, :prefix_video_latents_length, :, :] = condition[:, :channel_dim // 2, :, :, :].to( + latents[:, :, :prefix_video_latents_length, :, :] = condition[:condition.shape[0]//2].to( transformer_dtype ) base_num_frames = ( @@ -812,7 +811,7 @@ def __call__( ) if last_image is not None: latents = torch.cat( - [latents, condition[:, channel_dim // 2:, :, :, :].to(transformer_dtype)], dim=1 + [latents, condition[condition.shape[0]//2:].to(transformer_dtype)], dim=2 ) base_num_frames += prefix_video_latents_length num_latent_frames += prefix_video_latents_length From 18e525fa48cb249a018f6561807e5fda6fd17dd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 31 May 2025 21:42:25 +0300 Subject: [PATCH 175/264] Refactor latent retrieval logic in `SkyReelsV2DiffusionForcingVideoToVideoPipeline` to handle tensor dimensions more robustly, ensuring compatibility with both 3D and 4D video inputs. 
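As a rough illustration of the intent (not the pipeline code itself): each per-sample clip may arrive either unbatched as `(C, F, H, W)` or already batched as `(B, C, F, H, W)`, so the prefix slice is taken on the frame axis after normalizing to the 5D layout the VAE expects. The `vae_encode` callable below is a stand-in for `retrieve_latents(self.vae.encode(...), sample_mode="argmax")`.

```py
import torch

def encode_prefix(vid: torch.Tensor, overlap_history: int, vae_encode):
    # Unbatched clip (C, F, H, W): add the batch dim the VAE expects.
    if vid.dim() == 4:
        vid = vid.unsqueeze(0)
    # Keep only the trailing `overlap_history` frames (frame axis is dim 2).
    return vae_encode(vid[:, :, -overlap_history:])
```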
--- .../pipeline_skyreels_v2_diffusion_forcing_v2v.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index 88d583007ac3..b86ac568e968 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -433,13 +433,16 @@ def prepare_latents( return latents.to(device=device, dtype=dtype) num_latent_frames = ( - (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1) + (video.shape[2] - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.shape[2] ) latent_height = height // self.vae_scale_factor_spatial latent_width = width // self.vae_scale_factor_spatial prefix_video_latents = [ - retrieve_latents(self.vae.encode(vid[:, :, -overlap_history:].unsqueeze(0)), sample_mode="argmax") + retrieve_latents( + self.vae.encode(vid.unsqueeze(0)[:, :, -overlap_history:] if vid.dim() == 4 else vid[:, :, -overlap_history:]), + sample_mode="argmax", + ) for vid in video ] prefix_video_latents = torch.cat(prefix_video_latents, dim=0).to(dtype) From 528a81187c791616f54105add0ad974059ffca4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 1 Jun 2025 15:43:20 +0300 Subject: [PATCH 176/264] Enhance logging in `SkyReelsV2DiffusionForcing` pipelines by adding iteration print statements for better debugging. Clean up unused code related to prefix video latents length calculation in `SkyReelsV2DiffusionForcingImageToVideoPipeline`. --- .../pipeline_skyreels_v2_diffusion_forcing.py | 1 + .../pipeline_skyreels_v2_diffusion_forcing_i2v.py | 14 +++----------- .../pipeline_skyreels_v2_diffusion_forcing_v2v.py | 1 + 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index a2893aaadb1a..a29359f89eff 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -837,6 +837,7 @@ def __call__( n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 video = None for long_video_iter in range(n_iter): + print(f"long_video_iter:{long_video_iter}") # 5. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents, num_latent_frames, prefix_video_latents, prefix_video_latents_length = self.prepare_latents( diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index b756a24d80b4..b0b927284820 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -796,11 +796,6 @@ def __call__( last_image, ) - #if last_image is not None: - # prefix_video_latents_length = prefix_video_latents_length // 2 - - print(latents.shape, condition.shape, prefix_video_latents_length) - latents[:, :, :prefix_video_latents_length, :, :] = condition[:condition.shape[0]//2].to( transformer_dtype ) @@ -926,9 +921,9 @@ def __call__( device, dtype=torch.float32 ) for long_video_iter in range(n_iter): + print(f"long_video_iter:{long_video_iter}") # 5. Prepare latent variables num_channels_latents = self.vae.config.z_dim - latents, num_latent_frames, condition, prefix_video_latents_length = self.prepare_latents( image if long_video_iter == 0 else None, batch_size * num_videos_per_prompt, @@ -950,12 +945,9 @@ def __call__( ) if long_video_iter == 0 and last_image is not None: - prefix_video_latents_length = condition.shape[2] // 2 - end_video_latents = condition[:, :, prefix_video_latents_length:, :, :] + end_video_latents = condition[condition.shape[0]//2:] - latents[:, :, :prefix_video_latents_length, :, :] = condition[ - :, :, :prefix_video_latents_length, :, : - ].to(transformer_dtype) + latents[:, :, :prefix_video_latents_length, :, :] = condition[:condition.shape[0]//2].to(transformer_dtype) if last_image is not None and long_video_iter + 1 == n_iter: latents = torch.cat([latents, end_video_latents.to(transformer_dtype)], dim=2) num_latent_frames += prefix_video_latents_length diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index b86ac568e968..9bcd69109023 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -807,6 +807,7 @@ def __call__( ) n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 for long_video_iter in range(n_iter): + print(f"long_video_iter:{long_video_iter}") # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels latents, num_latent_frames, prefix_video_latents, prefix_video_latents_length = self.prepare_latents( From 7814f7d8e8ca6fccac66dd7afd661e1b57fb1f13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 1 Jun 2025 15:50:32 +0300 Subject: [PATCH 177/264] Update latent handling in `SkyReelsV2DiffusionForcingImageToVideoPipeline` to conditionally set latents based on video iteration state, improving flexibility for video input processing. 
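In sketch form (names follow the pipeline, the helper itself is only illustrative): on the first iteration with a `last_image`, the encoded condition stacks first- and last-frame latents along the batch axis, so only the first half seeds the prefix while the second half is set aside for the final chunk; on later iterations the condition is the re-encoded overlap and is written in full.

```py
import torch

def seed_prefix_latents(latents, condition, prefix_len, long_video_iter, has_last_image, dtype=torch.float32):
    # First chunk with a last_image: condition = [first-frame latents | last-frame latents] on dim 0.
    if long_video_iter == 0 and has_last_image:
        half = condition.shape[0] // 2
        end_video_latents = condition[half:]                    # kept for the final chunk
        latents[:, :, :prefix_len] = condition[:half].to(dtype)
        return latents, end_video_latents
    # Later chunks: condition is the re-encoded overlap of the video generated so far.
    latents[:, :, :prefix_len] = condition.to(dtype)
    return latents, None
```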
--- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index b0b927284820..4e1b7ae82469 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -946,8 +946,10 @@ def __call__( if long_video_iter == 0 and last_image is not None: end_video_latents = condition[condition.shape[0]//2:] + latents[:, :, :prefix_video_latents_length, :, :] = condition[:condition.shape[0]//2].to(transformer_dtype) + else: + latents[:, :, :prefix_video_latents_length, :, :] = condition.to(transformer_dtype) - latents[:, :, :prefix_video_latents_length, :, :] = condition[:condition.shape[0]//2].to(transformer_dtype) if last_image is not None and long_video_iter + 1 == n_iter: latents = torch.cat([latents, end_video_latents.to(transformer_dtype)], dim=2) num_latent_frames += prefix_video_latents_length From 6d1d1e916543cda9da4ff0e47b6f9a26491f538c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 2 Jun 2025 14:00:17 +0300 Subject: [PATCH 178/264] Refactor `SkyReelsV2TimeTextImageEmbedding` to utilize `get_1d_sincos_pos_embed_from_grid` for timestep projection. --- .../models/transformers/transformer_skyreels_v2.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 46e58153d819..431cac24af77 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -29,6 +29,7 @@ from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm +from ..embeddings import get_1d_sincos_pos_embed_from_grid logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -161,7 +162,9 @@ def __init__( ): super().__init__() - self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + #self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_freq_dim = time_freq_dim + self.timesteps_proj = get_1d_sincos_pos_embed_from_grid self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) self.act_fn = nn.SiLU() self.time_proj = nn.Linear(dim, time_proj_dim) @@ -177,9 +180,7 @@ def forward( encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, ): - original_timestep_shape = timestep.shape - timestep = self.timesteps_proj(timestep.reshape(-1)) - timestep = timestep.reshape(*original_timestep_shape, -1) + timestep = self.timesteps_proj(self.time_freq_dim, timestep, output_type="pt") time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: @@ -498,7 +499,7 @@ def forward( temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( timestep, encoder_hidden_states, encoder_hidden_states_image ) - timestep_proj = timestep_proj.unflatten(-1, (6, -1)) + timestep_proj = timestep_proj.unflatten(1, (6, -1)) if encoder_hidden_states_image is not 
None: encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) From 82d86e46b4a5c32d838dc91f200f0ec129ad8fbe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 2 Jun 2025 14:53:05 +0300 Subject: [PATCH 179/264] Enhance `get_1d_sincos_pos_embed_from_grid` function to include an optional parameter `flip_sin_to_cos` for flipping sine and cosine embeddings, improving flexibility in positional embedding generation. --- src/diffusers/models/embeddings.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index c25e9997e3fb..2a2e25517c15 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -319,7 +319,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"): return emb -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"): +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False): """ This function generates 1D positional embeddings from a grid. @@ -352,6 +352,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"): emb_cos = torch.cos(out) # (M, D/2) emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, embed_dim // 2:], emb[:, :embed_dim // 2]], dim=1) + return emb From 2dce7512c4df442e00ea7c21d21dd1abbc6c9ce9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 2 Jun 2025 14:53:18 +0300 Subject: [PATCH 180/264] Update timestep projection in `SkyReelsV2TimeTextImageEmbedding` to include `flip_sin_to_cos` parameter, enhancing the flexibility of time embedding generation. --- src/diffusers/models/transformers/transformer_skyreels_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 431cac24af77..52e9e0a48cce 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -180,7 +180,7 @@ def forward( encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, ): - timestep = self.timesteps_proj(self.time_freq_dim, timestep, output_type="pt") + timestep = self.timesteps_proj(self.time_freq_dim, timestep, output_type="pt", flip_sin_to_cos=True) time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: From a4aa0bad67c8b03efbe4869cfd4e1467c7c471f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 2 Jun 2025 16:43:13 +0300 Subject: [PATCH 181/264] Refactor tensor type handling in `SkyReelsV2AttnProcessor2_0` and `SkyReelsV2TransformerBlock` to ensure consistent use of `torch.float32` and `torch.bfloat16`, improving integration. 
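The gist, sketched under the assumption that upstream projections and rotary math stay in `float32` while the attention kernel itself runs in `bfloat16` (an experiment in this commit, not a final dtype policy):

```py
import torch
import torch.nn.functional as F

def sdpa_bf16(query, key, value, attn_mask=None):
    # Run scaled_dot_product_attention in bfloat16 for speed/memory,
    # leaving it to the caller to cast the result back if needed.
    return F.scaled_dot_product_attention(
        query.to(torch.bfloat16),
        key.to(torch.bfloat16),
        value.to(torch.bfloat16),
        attn_mask=attn_mask,
        dropout_p=0.0,
        is_causal=False,
    )
```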
--- .../transformers/transformer_skyreels_v2.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 52e9e0a48cce..b89c2fb0cb04 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -70,14 +70,14 @@ def __call__( if attn.norm_k is not None: key = attn.norm_k(key) - query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) - key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) - value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2).to(torch.float32) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2).to(torch.float32) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2).to(torch.float32) if rotary_emb is not None: def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): - x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2))) + x_rotated = torch.view_as_complex(hidden_states.to(torch.float32).unflatten(3, (-1, 2))) x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) return x_out.type_as(hidden_states) @@ -110,10 +110,16 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): is_causal=False, ) else: - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = F.scaled_dot_product_attention( + query.to(torch.bfloat16), + key.to(torch.bfloat16), + value.to(torch.bfloat16), + dropout_p=0.0, + is_causal=False, + ) hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) - hidden_states = hidden_states.type_as(query) + #hidden_states = hidden_states.type_as(query) if hidden_states_img is not None: hidden_states = hidden_states + hidden_states_img @@ -324,7 +330,7 @@ def forward( ff_output = self.ffn(norm_hidden_states) hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) - return hidden_states + return hidden_states.to(torch.bfloat16) def set_ar_attention(self): self.attn1.processor.set_ar_attention() @@ -567,7 +573,7 @@ def forward( batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 ) hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) - output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3).float() if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer From 63814a0f89112d048f14e5697034de2706e1b8e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 2 Jun 2025 17:59:20 +0300 Subject: [PATCH 182/264] Update tensor type in `SkyReelsV2RotaryPosEmbed` to use `torch.float32` for frequency calculations, ensuring consistency in data types across the model. 
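For reference, a minimal sketch of building the rotary table in `float32` with the existing helper; the `dim`/`pos` values here are illustrative only (the model derives the temporal/height/width splits from its config):

```py
import torch
from diffusers.models.embeddings import get_1d_rotary_pos_embed

# Complex rotary frequencies computed in float32 (previously float64), so they
# match the float32 math used in the attention processor.
freqs = get_1d_rotary_pos_embed(
    64,            # illustrative per-axis rotary dim (must be even)
    1024,          # illustrative max sequence length
    theta=10000,
    use_real=False,
    repeat_interleave_real=False,
    freqs_dtype=torch.float32,
)
```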
--- src/diffusers/models/transformers/transformer_skyreels_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index b89c2fb0cb04..983f5c0119fc 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -217,7 +217,7 @@ def __init__( freqs = [] for dim in [t_dim, h_dim, w_dim]: freq = get_1d_rotary_pos_embed( - dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64 + dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float32 ) freqs.append(freq) self.freqs = torch.cat(freqs, dim=1) From a74248f0c716cfd715f9b503a0dc0feeeb79977b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 2 Jun 2025 18:15:43 +0300 Subject: [PATCH 183/264] Refactor `SkyReelsV2TimeTextImageEmbedding` to utilize automatic mixed precision for timestep projection. --- .../transformers/transformer_skyreels_v2.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 983f5c0119fc..07f10de82f43 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -186,13 +186,14 @@ def forward( encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, ): - timestep = self.timesteps_proj(self.time_freq_dim, timestep, output_type="pt", flip_sin_to_cos=True) - - time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype - if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: - timestep = timestep.to(time_embedder_dtype) - temb = self.time_embedder(timestep).type_as(encoder_hidden_states) - timestep_proj = self.time_proj(self.act_fn(temb)) + with torch.amp.autocast("cuda", dtype=torch.float32): + timestep = self.timesteps_proj(self.time_freq_dim, timestep, output_type="pt", flip_sin_to_cos=True) + + time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) encoder_hidden_states = self.text_embedder(encoder_hidden_states) if encoder_hidden_states_image is not None: From efccb9e944a9f29f153b23083467615f976c1013 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 3 Jun 2025 07:31:49 +0300 Subject: [PATCH 184/264] down --- .../transformers/transformer_skyreels_v2.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 07f10de82f43..31f92ba6310b 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -168,7 +168,6 @@ def __init__( ): super().__init__() - #self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) self.time_freq_dim = time_freq_dim self.timesteps_proj = get_1d_sincos_pos_embed_from_grid self.time_embedder = 
TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) @@ -186,14 +185,13 @@ def forward( encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, ): - with torch.amp.autocast("cuda", dtype=torch.float32): - timestep = self.timesteps_proj(self.time_freq_dim, timestep, output_type="pt", flip_sin_to_cos=True) - - time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype - if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: - timestep = timestep.to(time_embedder_dtype) - temb = self.time_embedder(timestep).type_as(encoder_hidden_states) - timestep_proj = self.time_proj(self.act_fn(temb)) + timestep = self.timesteps_proj(self.time_freq_dim, timestep, output_type="pt", flip_sin_to_cos=True) + + time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) encoder_hidden_states = self.text_embedder(encoder_hidden_states) if encoder_hidden_states_image is not None: From b836618fd36037c94a887c74be462de329945d79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 3 Jun 2025 08:15:12 +0300 Subject: [PATCH 185/264] down --- .../transformers/transformer_skyreels_v2.py | 29 +++++++------------ 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 31f92ba6310b..f1415db1eb0d 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -70,9 +70,9 @@ def __call__( if attn.norm_k is not None: key = attn.norm_k(key) - query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2).to(torch.float32) - key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2).to(torch.float32) - value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2).to(torch.float32) + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) if rotary_emb is not None: @@ -100,26 +100,17 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) - if self._flag_ar_attention: hidden_states = F.scaled_dot_product_attention( - query.to(torch.bfloat16), - key.to(torch.bfloat16), - value.to(torch.bfloat16), - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=False, - ) - else: - hidden_states = F.scaled_dot_product_attention( - query.to(torch.bfloat16), - key.to(torch.bfloat16), - value.to(torch.bfloat16), + query, + key, + value, + attn_mask=attention_mask if self._flag_ar_attention else None, dropout_p=0.0, is_causal=False, ) hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) - #hidden_states = hidden_states.type_as(query) + hidden_states = hidden_states.type_as(query) if hidden_states_img is not None: hidden_states = hidden_states + hidden_states_img @@ -329,7 +320,7 @@ def forward( ff_output = self.ffn(norm_hidden_states) hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) - return hidden_states.to(torch.bfloat16) + return hidden_states def 
set_ar_attention(self): self.attn1.processor.set_ar_attention() @@ -572,7 +563,7 @@ def forward( batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 ) hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) - output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3).float() + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer From 11cd6fb2398ffde314aff795cd9ae5164323a20f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 3 Jun 2025 08:16:50 +0300 Subject: [PATCH 186/264] style --- src/diffusers/models/embeddings.py | 2 +- .../models/transformers/transformer_skyreels_v2.py | 8 ++++++-- .../pipeline_skyreels_v2_diffusion_forcing_i2v.py | 12 ++++++------ .../pipeline_skyreels_v2_diffusion_forcing_v2v.py | 4 +++- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 2a2e25517c15..cc74f56c1356 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -355,7 +355,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin # flip sine and cosine embeddings if flip_sin_to_cos: - emb = torch.cat([emb[:, embed_dim // 2:], emb[:, :embed_dim // 2]], dim=1) + emb = torch.cat([emb[:, embed_dim // 2 :], emb[:, : embed_dim // 2]], dim=1) return emb diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index f1415db1eb0d..145c14a20b94 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -25,11 +25,15 @@ from ..attention import FeedForward from ..attention_processor import Attention from ..cache_utils import CacheMixin -from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from ..embeddings import ( + PixArtAlphaTextProjection, + TimestepEmbedding, + get_1d_rotary_pos_embed, + get_1d_sincos_pos_embed_from_grid, +) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm -from ..embeddings import get_1d_sincos_pos_embed_from_grid logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 4e1b7ae82469..a33108bcd46c 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -796,7 +796,7 @@ def __call__( last_image, ) - latents[:, :, :prefix_video_latents_length, :, :] = condition[:condition.shape[0]//2].to( + latents[:, :, :prefix_video_latents_length, :, :] = condition[: condition.shape[0] // 2].to( transformer_dtype ) base_num_frames = ( @@ -805,9 +805,7 @@ def __call__( else num_latent_frames ) if last_image is not None: - latents = torch.cat( - [latents, condition[condition.shape[0]//2:].to(transformer_dtype)], dim=2 - ) + latents = torch.cat([latents, condition[condition.shape[0] // 2 :].to(transformer_dtype)], dim=2) base_num_frames += prefix_video_latents_length num_latent_frames += prefix_video_latents_length @@ -945,8 +943,10 @@ def __call__( ) if long_video_iter == 0 and last_image is 
not None: - end_video_latents = condition[condition.shape[0]//2:] - latents[:, :, :prefix_video_latents_length, :, :] = condition[:condition.shape[0]//2].to(transformer_dtype) + end_video_latents = condition[condition.shape[0] // 2 :] + latents[:, :, :prefix_video_latents_length, :, :] = condition[: condition.shape[0] // 2].to( + transformer_dtype + ) else: latents[:, :, :prefix_video_latents_length, :, :] = condition.to(transformer_dtype) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index 9bcd69109023..75f43229c883 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -440,7 +440,9 @@ def prepare_latents( prefix_video_latents = [ retrieve_latents( - self.vae.encode(vid.unsqueeze(0)[:, :, -overlap_history:] if vid.dim() == 4 else vid[:, :, -overlap_history:]), + self.vae.encode( + vid.unsqueeze(0)[:, :, -overlap_history:] if vid.dim() == 4 else vid[:, :, -overlap_history:] + ), sample_mode="argmax", ) for vid in video From 786f145568be35ec6399961edec6d9b21c34f870 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 3 Jun 2025 09:16:04 +0300 Subject: [PATCH 187/264] Add debug tensor tracking to `SkyReelsV2Transformer3DModel` for enhanced debugging and output analysis; update `Transformer2DModelOutput` to include debug tensors. --- src/diffusers/models/modeling_outputs.py | 4 ++ .../transformers/transformer_skyreels_v2.py | 53 +++++++++++++++++-- 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/modeling_outputs.py b/src/diffusers/models/modeling_outputs.py index 0120a34d9052..9df5cefe5dff 100644 --- a/src/diffusers/models/modeling_outputs.py +++ b/src/diffusers/models/modeling_outputs.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Optional, Dict, Any from ..utils import BaseOutput @@ -26,6 +27,9 @@ class Transformer2DModelOutput(BaseOutput): sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability distributions for the unnoised latent pixels. + debug_tensors (`Optional[Dict[str, Any]]`, *optional*): + A dictionary containing intermediate tensors and their shapes for debugging purposes. 
""" sample: "torch.Tensor" # noqa: F821 + debug_tensors: Optional[Dict[str, Any]] = None diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 145c14a20b94..eded5ea4be11 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -456,6 +456,14 @@ def forward( return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + debug_tensors = {} + debug_tensors["initial_hidden_states"] = hidden_states.detach().clone() + debug_tensors["initial_hidden_states_shape"] = list(hidden_states.shape) + debug_tensors["initial_timestep"] = timestep.detach().clone() + debug_tensors["initial_timestep_shape"] = list(timestep.shape) + debug_tensors["initial_encoder_hidden_states"] = encoder_hidden_states.detach().clone() + debug_tensors["initial_encoder_hidden_states_shape"] = list(encoder_hidden_states.shape) + if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) @@ -478,9 +486,15 @@ def forward( post_patch_width = width // p_w rotary_emb = self.rope(hidden_states) + debug_tensors["rotary_emb"] = rotary_emb.detach().clone() + debug_tensors["rotary_emb_shape"] = list(rotary_emb.shape) hidden_states = self.patch_embedding(hidden_states) + debug_tensors["hidden_states_after_patch_embedding"] = hidden_states.detach().clone() + debug_tensors["hidden_states_after_patch_embedding_shape"] = list(hidden_states.shape) grid_sizes = torch.tensor(hidden_states.shape[2:], dtype=torch.long) + debug_tensors["grid_sizes"] = grid_sizes.detach().clone() + debug_tensors["grid_sizes_shape"] = list(grid_sizes.shape) if self.config.flag_causal_attention: frame_num, height, width = grid_sizes @@ -495,10 +509,19 @@ def forward( causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) hidden_states = hidden_states.flatten(2).transpose(1, 2) + debug_tensors["hidden_states_after_flatten_transpose"] = hidden_states.detach().clone() + debug_tensors["hidden_states_after_flatten_transpose_shape"] = list(hidden_states.shape) temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( timestep, encoder_hidden_states, encoder_hidden_states_image ) + debug_tensors["temb_after_time_embedding"] = temb.detach().clone() + debug_tensors["temb_after_time_embedding_shape"] = list(temb.shape) + debug_tensors["timestep_proj_after_time_projection"] = timestep_proj.detach().clone() + debug_tensors["timestep_proj_after_time_projection_shape"] = list(timestep_proj.shape) + debug_tensors["encoder_hidden_states_after_text_embedding"] = encoder_hidden_states.detach().clone() + debug_tensors["encoder_hidden_states_after_text_embedding_shape"] = list(encoder_hidden_states.shape) + timestep_proj = timestep_proj.unflatten(1, (6, -1)) if encoder_hidden_states_image is not None: @@ -506,7 +529,9 @@ def forward( # 4. 
Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: - for block in self.blocks: + for i, block in enumerate(self.blocks): + debug_tensors[f"block_{i}_input_hidden_states"] = hidden_states.detach().clone() + debug_tensors[f"block_{i}_input_hidden_states_shape"] = list(hidden_states.shape) hidden_states = self._gradient_checkpointing_func( block, hidden_states, @@ -515,10 +540,14 @@ def forward( rotary_emb, causal_mask if self.config.flag_causal_attention else None, ) + debug_tensors[f"block_{i}_output_hidden_states"] = hidden_states.detach().clone() + debug_tensors[f"block_{i}_output_hidden_states_shape"] = list(hidden_states.shape) if self.config.inject_sample_info: fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device) fps_emb = self.fps_embedding(fps).float() + debug_tensors["fps_emb"] = fps_emb.detach().clone() + debug_tensors["fps_emb_shape"] = list(fps_emb.shape) timestep_proj = timestep_proj.to(fps_emb.dtype) self.fps_projection.to(fps_emb.dtype) if flag_df: @@ -527,6 +556,8 @@ def forward( ) else: timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1)) + debug_tensors["timestep_proj_after_fps_projection"] = timestep_proj.detach().clone() + debug_tensors["timestep_proj_after_fps_projection_shape"] = list(timestep_proj.shape) timestep_proj = timestep_proj.to(hidden_states.dtype) if flag_df: @@ -537,7 +568,9 @@ def forward( timestep_proj = timestep_proj.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3) timestep_proj = timestep_proj.transpose(1, 2).contiguous() - for block in self.blocks: + for i, block in enumerate(self.blocks): + debug_tensors[f"block_{i}_input_hidden_states"] = hidden_states.detach().clone() + debug_tensors[f"block_{i}_input_hidden_states_shape"] = list(hidden_states.shape) hidden_states = block( hidden_states, encoder_hidden_states, @@ -545,8 +578,13 @@ def forward( rotary_emb, causal_mask if self.config.flag_causal_attention else None, ) + debug_tensors[f"block_{i}_output_hidden_states"] = hidden_states.detach().clone() + debug_tensors[f"block_{i}_output_hidden_states_shape"] = list(hidden_states.shape) - # 5. 
Output norm, projection & unpatchify + debug_tensors["temb_before_output_modulation"] = temb.detach().clone() + debug_tensors["temb_before_output_modulation_shape"] = list(temb.shape) + debug_tensors["scale_shift_table"] = self.scale_shift_table.detach().clone() + debug_tensors["scale_shift_table_shape"] = list(self.scale_shift_table.shape) if temb.dim() == 2: shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) elif temb.dim() == 3: @@ -561,13 +599,20 @@ def forward( scale = scale.to(hidden_states.device) hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + debug_tensors["hidden_states_after_output_modulation"] = hidden_states.detach().clone() + debug_tensors["hidden_states_after_output_modulation_shape"] = list(hidden_states.shape) + hidden_states = self.proj_out(hidden_states) + debug_tensors["hidden_states_after_proj_out"] = hidden_states.detach().clone() + debug_tensors["hidden_states_after_proj_out_shape"] = list(hidden_states.shape) hidden_states = hidden_states.reshape( batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 ) hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + debug_tensors["output_after_unpatchify"] = output.detach().clone() + debug_tensors["output_after_unpatchify_shape"] = list(output.shape) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer @@ -576,7 +621,7 @@ def forward( if not return_dict: return (output,) - return Transformer2DModelOutput(sample=output) + return Transformer2DModelOutput(sample=output, debug_tensors=debug_tensors) def set_ar_attention(self, causal_block_size): self.register_to_config(num_frame_per_block=causal_block_size, flag_causal_attention=True) From b597f9eb2e3416c5d376f35f5098690fe00ebf1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 3 Jun 2025 10:21:51 +0300 Subject: [PATCH 188/264] up --- src/diffusers/models/transformers/transformer_skyreels_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index eded5ea4be11..fd89d6664109 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -522,7 +522,7 @@ def forward( debug_tensors["encoder_hidden_states_after_text_embedding"] = encoder_hidden_states.detach().clone() debug_tensors["encoder_hidden_states_after_text_embedding_shape"] = list(encoder_hidden_states.shape) - timestep_proj = timestep_proj.unflatten(1, (6, -1)) + timestep_proj = timestep_proj.unflatten(-1, (6, -1)) if encoder_hidden_states_image is not None: encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) From 6caffc904f0b6f592d1a1be1276bd0955669a1a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 3 Jun 2025 10:31:14 +0300 Subject: [PATCH 189/264] Refactor indentation in `SkyReelsV2AttnProcessor2_0` to improve code readability and maintain consistency in style. 
--- .../transformers/transformer_skyreels_v2.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index fd89d6664109..a81fb8f02f2a 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -104,14 +104,14 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) - hidden_states = F.scaled_dot_product_attention( - query, - key, - value, - attn_mask=attention_mask if self._flag_ar_attention else None, - dropout_p=0.0, - is_causal=False, - ) + hidden_states = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask if self._flag_ar_attention else None, + dropout_p=0.0, + is_causal=False, + ) hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) hidden_states = hidden_states.type_as(query) From 848acfc890f1ce96e7cd84105ec1fa5032c3fd3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Tue, 3 Jun 2025 19:28:01 +0300 Subject: [PATCH 190/264] Convert query, key, and value tensors to bfloat16 in `SkyReelsV2AttnProcessor2_0` for improved performance. --- .../models/transformers/transformer_skyreels_v2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index a81fb8f02f2a..197e7ceb2a23 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -105,9 +105,9 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): hidden_states_img = hidden_states_img.type_as(query) hidden_states = F.scaled_dot_product_attention( - query, - key, - value, + query.to(torch.bfloat16), + key.to(torch.bfloat16), + value.to(torch.bfloat16), attn_mask=attention_mask if self._flag_ar_attention else None, dropout_p=0.0, is_causal=False, From a8e01ba34ca824c1832c7aa2755b5c2e4de7f922 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 4 Jun 2025 14:01:28 +0300 Subject: [PATCH 191/264] Add debug print statements in `SkyReelsV2TransformerBlock` to track tensor shapes and values for improved debugging and analysis. 
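These prints are temporary debugging aids; a slightly tidier, equivalent pattern (hypothetical helper, not part of this patch) would be:

```py
import torch

def log_shapes(label: str, **tensors: torch.Tensor) -> None:
    # Print each tensor's shape and dtype under a short label, e.g.
    # log_shapes("block_0", hidden_states=hidden_states, temb=temb)
    parts = [f"{name}: {tuple(t.shape)}, {t.dtype}" for name, t in tensors.items()]
    print(f"{label} | " + "; ".join(parts))
```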
--- src/diffusers/models/transformers/transformer_skyreels_v2.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 197e7ceb2a23..0956a97a4553 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -305,6 +305,9 @@ def forward( e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e] + print(f"scale_msa: {scale_msa.shape}, shift_msa: {shift_msa.shape}, gate_msa: {gate_msa.shape}, c_scale_msa: {c_scale_msa.shape}, c_shift_msa: {c_shift_msa.shape}, c_gate_msa: {c_gate_msa.shape}") + print(f"scale_msa: {scale_msa}, shift_msa: {shift_msa}, gate_msa: {gate_msa}, c_scale_msa: {c_scale_msa}, c_shift_msa: {c_shift_msa}, c_gate_msa: {c_gate_msa}") + print(f"hidden_states: {hidden_states.shape}, encoder_hidden_states: {encoder_hidden_states.shape}, temb: {temb.shape}, rotary_emb: {rotary_emb.shape}") # 1. Self-attention norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) attn_output = self.attn1( From aef60a1b89d0b43138cf6726aeeabaf2d30d56b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 4 Jun 2025 17:23:45 +0300 Subject: [PATCH 192/264] debug --- .../transformers/transformer_skyreels_v2.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 0956a97a4553..56abb8ee1a4a 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -304,17 +304,17 @@ def forward( elif temb.dim() == 4: e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e] - - print(f"scale_msa: {scale_msa.shape}, shift_msa: {shift_msa.shape}, gate_msa: {gate_msa.shape}, c_scale_msa: {c_scale_msa.shape}, c_shift_msa: {c_shift_msa.shape}, c_gate_msa: {c_gate_msa.shape}") - print(f"scale_msa: {scale_msa}, shift_msa: {shift_msa}, gate_msa: {gate_msa}, c_scale_msa: {c_scale_msa}, c_shift_msa: {c_shift_msa}, c_gate_msa: {c_gate_msa}") - print(f"hidden_states: {hidden_states.shape}, encoder_hidden_states: {encoder_hidden_states.shape}, temb: {temb.shape}, rotary_emb: {rotary_emb.shape}") - # 1. Self-attention + debug_dict = {} + debug_dict['x'] = hidden_states + # 1. Self-attention norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + debug_dict['mul_add_add_compile'] = norm_hidden_states attn_output = self.attn1( hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask ) + debug_dict['self_attn'] = attn_output hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) - + debug_dict['mul_add_compile'] = hidden_states # 2. 
Cross-attention norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) @@ -326,8 +326,8 @@ def forward( ) ff_output = self.ffn(norm_hidden_states) hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) - - return hidden_states + debug_dict['cross_attn_ffn'] = hidden_states + return hidden_states, debug_dict def set_ar_attention(self): self.attn1.processor.set_ar_attention() @@ -535,7 +535,7 @@ def forward( for i, block in enumerate(self.blocks): debug_tensors[f"block_{i}_input_hidden_states"] = hidden_states.detach().clone() debug_tensors[f"block_{i}_input_hidden_states_shape"] = list(hidden_states.shape) - hidden_states = self._gradient_checkpointing_func( + hidden_states, debug_dict = self._gradient_checkpointing_func( block, hidden_states, encoder_hidden_states, @@ -543,6 +543,7 @@ def forward( rotary_emb, causal_mask if self.config.flag_causal_attention else None, ) + debug_tensors[f"block_{i}_debug_dict"] = debug_dict debug_tensors[f"block_{i}_output_hidden_states"] = hidden_states.detach().clone() debug_tensors[f"block_{i}_output_hidden_states_shape"] = list(hidden_states.shape) if self.config.inject_sample_info: @@ -574,13 +575,14 @@ def forward( for i, block in enumerate(self.blocks): debug_tensors[f"block_{i}_input_hidden_states"] = hidden_states.detach().clone() debug_tensors[f"block_{i}_input_hidden_states_shape"] = list(hidden_states.shape) - hidden_states = block( + hidden_states, debug_dict = block( hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask if self.config.flag_causal_attention else None, ) + debug_tensors[f"block_{i}_debug_dict"] = debug_dict debug_tensors[f"block_{i}_output_hidden_states"] = hidden_states.detach().clone() debug_tensors[f"block_{i}_output_hidden_states_shape"] = list(hidden_states.shape) From f70627a4752cfbf8551cbe188f92e459432b322b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 4 Jun 2025 17:27:10 +0300 Subject: [PATCH 193/264] --- src/diffusers/models/transformers/transformer_skyreels_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 56abb8ee1a4a..93fbfa202122 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -583,6 +583,7 @@ def forward( causal_mask if self.config.flag_causal_attention else None, ) debug_tensors[f"block_{i}_debug_dict"] = debug_dict + return hidden_states, debug_dict debug_tensors[f"block_{i}_output_hidden_states"] = hidden_states.detach().clone() debug_tensors[f"block_{i}_output_hidden_states_shape"] = list(hidden_states.shape) From 7a98f19de22d8235da040a7d502f49a99bfc071d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 4 Jun 2025 17:58:41 +0300 Subject: [PATCH 194/264] debug --- .../models/transformers/transformer_skyreels_v2.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 93fbfa202122..f7b6f1f35c26 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -307,14 +307,16 @@ def forward( debug_dict = {} 
debug_dict['x'] = hidden_states # 1. Self-attention - norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) - debug_dict['mul_add_add_compile'] = norm_hidden_states + norm_x = self.norm1(hidden_states.float()) + debug_dict['norm_x'] = norm_x + norm_hidden_states = (norm_x * (1 + scale_msa) + shift_msa).type_as(hidden_states) + debug_dict['mul_add_add'] = norm_hidden_states attn_output = self.attn1( hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask ) - debug_dict['self_attn'] = attn_output + #debug_dict['self_attn'] = attn_output hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) - debug_dict['mul_add_compile'] = hidden_states + #debug_dict['mul_add_compile'] = hidden_states # 2. Cross-attention norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) @@ -326,7 +328,7 @@ def forward( ) ff_output = self.ffn(norm_hidden_states) hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) - debug_dict['cross_attn_ffn'] = hidden_states + #debug_dict['cross_attn_ffn'] = hidden_states return hidden_states, debug_dict def set_ar_attention(self): From 17e931a02c2cafbef1d010afae1fd1aade2f9c54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Wed, 4 Jun 2025 20:07:14 +0300 Subject: [PATCH 195/264] Remove commented-out debug tensor tracking from `SkyReelsV2TransformerBlock` --- src/diffusers/models/modeling_outputs.py | 2 +- .../transformers/transformer_skyreels_v2.py | 110 +++++++++--------- 2 files changed, 56 insertions(+), 56 deletions(-) diff --git a/src/diffusers/models/modeling_outputs.py b/src/diffusers/models/modeling_outputs.py index 9df5cefe5dff..8555393c33eb 100644 --- a/src/diffusers/models/modeling_outputs.py +++ b/src/diffusers/models/modeling_outputs.py @@ -32,4 +32,4 @@ class Transformer2DModelOutput(BaseOutput): """ sample: "torch.Tensor" # noqa: F821 - debug_tensors: Optional[Dict[str, Any]] = None + #debug_tensors: Optional[Dict[str, Any]] = None diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index f7b6f1f35c26..d7d53e120804 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -304,13 +304,13 @@ def forward( elif temb.dim() == 4: e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e] - debug_dict = {} - debug_dict['x'] = hidden_states + #debug_dict = {} + #debug_dict['x'] = hidden_states # 1. 
Self-attention norm_x = self.norm1(hidden_states.float()) - debug_dict['norm_x'] = norm_x + #debug_dict['norm_x'] = norm_x norm_hidden_states = (norm_x * (1 + scale_msa) + shift_msa).type_as(hidden_states) - debug_dict['mul_add_add'] = norm_hidden_states + #debug_dict['mul_add_add'] = norm_hidden_states attn_output = self.attn1( hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask ) @@ -329,7 +329,7 @@ def forward( ff_output = self.ffn(norm_hidden_states) hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) #debug_dict['cross_attn_ffn'] = hidden_states - return hidden_states, debug_dict + return hidden_states def set_ar_attention(self): self.attn1.processor.set_ar_attention() @@ -461,13 +461,13 @@ def forward( return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - debug_tensors = {} - debug_tensors["initial_hidden_states"] = hidden_states.detach().clone() - debug_tensors["initial_hidden_states_shape"] = list(hidden_states.shape) - debug_tensors["initial_timestep"] = timestep.detach().clone() - debug_tensors["initial_timestep_shape"] = list(timestep.shape) - debug_tensors["initial_encoder_hidden_states"] = encoder_hidden_states.detach().clone() - debug_tensors["initial_encoder_hidden_states_shape"] = list(encoder_hidden_states.shape) + #debug_tensors = {} + #debug_tensors["initial_hidden_states"] = hidden_states.detach().clone() + #debug_tensors["initial_hidden_states_shape"] = list(hidden_states.shape) + #debug_tensors["initial_timestep"] = timestep.detach().clone() + #debug_tensors["initial_timestep_shape"] = list(timestep.shape) + #debug_tensors["initial_encoder_hidden_states"] = encoder_hidden_states.detach().clone() + #debug_tensors["initial_encoder_hidden_states_shape"] = list(encoder_hidden_states.shape) if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() @@ -491,15 +491,15 @@ def forward( post_patch_width = width // p_w rotary_emb = self.rope(hidden_states) - debug_tensors["rotary_emb"] = rotary_emb.detach().clone() - debug_tensors["rotary_emb_shape"] = list(rotary_emb.shape) + #debug_tensors["rotary_emb"] = rotary_emb.detach().clone() + #debug_tensors["rotary_emb_shape"] = list(rotary_emb.shape) hidden_states = self.patch_embedding(hidden_states) - debug_tensors["hidden_states_after_patch_embedding"] = hidden_states.detach().clone() - debug_tensors["hidden_states_after_patch_embedding_shape"] = list(hidden_states.shape) + #debug_tensors["hidden_states_after_patch_embedding"] = hidden_states.detach().clone() + #debug_tensors["hidden_states_after_patch_embedding_shape"] = list(hidden_states.shape) grid_sizes = torch.tensor(hidden_states.shape[2:], dtype=torch.long) - debug_tensors["grid_sizes"] = grid_sizes.detach().clone() - debug_tensors["grid_sizes_shape"] = list(grid_sizes.shape) + #debug_tensors["grid_sizes"] = grid_sizes.detach().clone() + #debug_tensors["grid_sizes_shape"] = list(grid_sizes.shape) if self.config.flag_causal_attention: frame_num, height, width = grid_sizes @@ -514,18 +514,18 @@ def forward( causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) hidden_states = hidden_states.flatten(2).transpose(1, 2) - debug_tensors["hidden_states_after_flatten_transpose"] = hidden_states.detach().clone() - debug_tensors["hidden_states_after_flatten_transpose_shape"] = list(hidden_states.shape) + #debug_tensors["hidden_states_after_flatten_transpose"] = hidden_states.detach().clone() + 
#debug_tensors["hidden_states_after_flatten_transpose_shape"] = list(hidden_states.shape) temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( timestep, encoder_hidden_states, encoder_hidden_states_image ) - debug_tensors["temb_after_time_embedding"] = temb.detach().clone() - debug_tensors["temb_after_time_embedding_shape"] = list(temb.shape) - debug_tensors["timestep_proj_after_time_projection"] = timestep_proj.detach().clone() - debug_tensors["timestep_proj_after_time_projection_shape"] = list(timestep_proj.shape) - debug_tensors["encoder_hidden_states_after_text_embedding"] = encoder_hidden_states.detach().clone() - debug_tensors["encoder_hidden_states_after_text_embedding_shape"] = list(encoder_hidden_states.shape) + #debug_tensors["temb_after_time_embedding"] = temb.detach().clone() + #debug_tensors["temb_after_time_embedding_shape"] = list(temb.shape) + #debug_tensors["timestep_proj_after_time_projection"] = timestep_proj.detach().clone() + #debug_tensors["timestep_proj_after_time_projection_shape"] = list(timestep_proj.shape) + #debug_tensors["encoder_hidden_states_after_text_embedding"] = encoder_hidden_states.detach().clone() + #debug_tensors["encoder_hidden_states_after_text_embedding_shape"] = list(encoder_hidden_states.shape) timestep_proj = timestep_proj.unflatten(-1, (6, -1)) @@ -535,9 +535,9 @@ def forward( # 4. Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: for i, block in enumerate(self.blocks): - debug_tensors[f"block_{i}_input_hidden_states"] = hidden_states.detach().clone() - debug_tensors[f"block_{i}_input_hidden_states_shape"] = list(hidden_states.shape) - hidden_states, debug_dict = self._gradient_checkpointing_func( + #debug_tensors[f"block_{i}_input_hidden_states"] = hidden_states.detach().clone() + #debug_tensors[f"block_{i}_input_hidden_states_shape"] = list(hidden_states.shape) + hidden_states = self._gradient_checkpointing_func( block, hidden_states, encoder_hidden_states, @@ -545,15 +545,15 @@ def forward( rotary_emb, causal_mask if self.config.flag_causal_attention else None, ) - debug_tensors[f"block_{i}_debug_dict"] = debug_dict - debug_tensors[f"block_{i}_output_hidden_states"] = hidden_states.detach().clone() - debug_tensors[f"block_{i}_output_hidden_states_shape"] = list(hidden_states.shape) + #debug_tensors[f"block_{i}_debug_dict"] = debug_dict + #debug_tensors[f"block_{i}_output_hidden_states"] = hidden_states.detach().clone() + #debug_tensors[f"block_{i}_output_hidden_states_shape"] = list(hidden_states.shape) if self.config.inject_sample_info: fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device) fps_emb = self.fps_embedding(fps).float() - debug_tensors["fps_emb"] = fps_emb.detach().clone() - debug_tensors["fps_emb_shape"] = list(fps_emb.shape) + #debug_tensors["fps_emb"] = fps_emb.detach().clone() + #debug_tensors["fps_emb_shape"] = list(fps_emb.shape) timestep_proj = timestep_proj.to(fps_emb.dtype) self.fps_projection.to(fps_emb.dtype) if flag_df: @@ -562,8 +562,8 @@ def forward( ) else: timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1)) - debug_tensors["timestep_proj_after_fps_projection"] = timestep_proj.detach().clone() - debug_tensors["timestep_proj_after_fps_projection_shape"] = list(timestep_proj.shape) + #debug_tensors["timestep_proj_after_fps_projection"] = timestep_proj.detach().clone() + #debug_tensors["timestep_proj_after_fps_projection_shape"] = list(timestep_proj.shape) timestep_proj = 
timestep_proj.to(hidden_states.dtype) if flag_df: @@ -575,24 +575,24 @@ def forward( timestep_proj = timestep_proj.transpose(1, 2).contiguous() for i, block in enumerate(self.blocks): - debug_tensors[f"block_{i}_input_hidden_states"] = hidden_states.detach().clone() - debug_tensors[f"block_{i}_input_hidden_states_shape"] = list(hidden_states.shape) - hidden_states, debug_dict = block( + #debug_tensors[f"block_{i}_input_hidden_states"] = hidden_states.detach().clone() + #debug_tensors[f"block_{i}_input_hidden_states_shape"] = list(hidden_states.shape) + hidden_states = block( hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask if self.config.flag_causal_attention else None, ) - debug_tensors[f"block_{i}_debug_dict"] = debug_dict - return hidden_states, debug_dict - debug_tensors[f"block_{i}_output_hidden_states"] = hidden_states.detach().clone() - debug_tensors[f"block_{i}_output_hidden_states_shape"] = list(hidden_states.shape) - - debug_tensors["temb_before_output_modulation"] = temb.detach().clone() - debug_tensors["temb_before_output_modulation_shape"] = list(temb.shape) - debug_tensors["scale_shift_table"] = self.scale_shift_table.detach().clone() - debug_tensors["scale_shift_table_shape"] = list(self.scale_shift_table.shape) + #debug_tensors[f"block_{i}_debug_dict"] = debug_dict + #return hidden_states, debug_dict + #debug_tensors[f"block_{i}_output_hidden_states"] = hidden_states.detach().clone() + #debug_tensors[f"block_{i}_output_hidden_states_shape"] = list(hidden_states.shape) + + #debug_tensors["temb_before_output_modulation"] = temb.detach().clone() + #debug_tensors["temb_before_output_modulation_shape"] = list(temb.shape) + #debug_tensors["scale_shift_table"] = self.scale_shift_table.detach().clone() + #debug_tensors["scale_shift_table_shape"] = list(self.scale_shift_table.shape) if temb.dim() == 2: shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) elif temb.dim() == 3: @@ -607,20 +607,20 @@ def forward( scale = scale.to(hidden_states.device) hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) - debug_tensors["hidden_states_after_output_modulation"] = hidden_states.detach().clone() - debug_tensors["hidden_states_after_output_modulation_shape"] = list(hidden_states.shape) + #debug_tensors["hidden_states_after_output_modulation"] = hidden_states.detach().clone() + #debug_tensors["hidden_states_after_output_modulation_shape"] = list(hidden_states.shape) hidden_states = self.proj_out(hidden_states) - debug_tensors["hidden_states_after_proj_out"] = hidden_states.detach().clone() - debug_tensors["hidden_states_after_proj_out_shape"] = list(hidden_states.shape) + #debug_tensors["hidden_states_after_proj_out"] = hidden_states.detach().clone() + #debug_tensors["hidden_states_after_proj_out_shape"] = list(hidden_states.shape) hidden_states = hidden_states.reshape( batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 ) hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) - debug_tensors["output_after_unpatchify"] = output.detach().clone() - debug_tensors["output_after_unpatchify_shape"] = list(output.shape) + #debug_tensors["output_after_unpatchify"] = output.detach().clone() + #debug_tensors["output_after_unpatchify_shape"] = list(output.shape) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer @@ -629,7 +629,7 @@ def forward( if not return_dict: return (output,) - 
return Transformer2DModelOutput(sample=output, debug_tensors=debug_tensors) + return Transformer2DModelOutput(sample=output) def set_ar_attention(self, causal_block_size): self.register_to_config(num_frame_per_block=causal_block_size, flag_causal_attention=True) From 19dae163b05933bb23883b715718217c87c9e042 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 5 Jun 2025 14:03:45 +0300 Subject: [PATCH 196/264] Add functionality to save processed video latents as a Safetensors file in `SkyReelsV2DiffusionForcingPipeline`. --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index a29359f89eff..edccf9bd590e 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -991,6 +991,8 @@ def __call__( ).to(latents.device, latents.dtype) latents = latents / latents_std + latents_mean video = self.vae.decode(latents, return_dict=False)[0] + from safetensors.torch import save_file + save_file(video, "diffusers_latents.safetensors") video = self.video_processor.postprocess_video(video, output_type=output_type) else: video = latents From 2947e52b857833f5a905f89b904294e41303686a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 5 Jun 2025 14:25:55 +0300 Subject: [PATCH 197/264] up --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index edccf9bd590e..ec3d1d71bca2 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -992,7 +992,7 @@ def __call__( latents = latents / latents_std + latents_mean video = self.vae.decode(latents, return_dict=False)[0] from safetensors.torch import save_file - save_file(video, "diffusers_latents.safetensors") + save_file({"vae_decode": video}, "diffusers.safetensors") video = self.video_processor.postprocess_video(video, output_type=output_type) else: video = latents From 324e7fefacbab25211777400aaae1a69fd7530e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 5 Jun 2025 14:50:27 +0300 Subject: [PATCH 198/264] Add functionality to save output latents as a Safetensors file in `SkyReelsV2DiffusionForcingPipeline`. 
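The dump is meant for offline numerical comparison only. A minimal sketch of how the saved file can be consumed is below; the reference file name and the tolerances are assumptions and not part of this patch:

    import torch
    from safetensors.torch import load_file

    ours = load_file("diffusers.safetensors")["output_latents"]
    reference = load_file("reference.safetensors")["output_latents"]  # hypothetical dump from a reference run

    print("max abs diff:", (ours.float() - reference.float()).abs().max().item())
    torch.testing.assert_close(ours.float(), reference.float(), rtol=1e-3, atol=1e-3)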
--- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index ec3d1d71bca2..c6ca48c5d3ef 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -981,6 +981,8 @@ def __call__( if not output_type == "latent": if overlap_history is None: latents = latents.to(self.vae.dtype) + from safetensors.torch import save_file + save_file({"output_latents": video}, "diffusers.safetensors") latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) @@ -991,8 +993,6 @@ def __call__( ).to(latents.device, latents.dtype) latents = latents / latents_std + latents_mean video = self.vae.decode(latents, return_dict=False)[0] - from safetensors.torch import save_file - save_file({"vae_decode": video}, "diffusers.safetensors") video = self.video_processor.postprocess_video(video, output_type=output_type) else: video = latents From abf59a532f3e12b09ea63940a43ea5deec4e8b5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 5 Jun 2025 15:07:36 +0300 Subject: [PATCH 199/264] up --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index c6ca48c5d3ef..448d4b54d6ec 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -982,7 +982,7 @@ def __call__( if overlap_history is None: latents = latents.to(self.vae.dtype) from safetensors.torch import save_file - save_file({"output_latents": video}, "diffusers.safetensors") + save_file({"output_latents": latents}, "diffusers.safetensors") latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) From e227b3892f0a1bb668d7b7388afb5a396f1cb612 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 5 Jun 2025 15:15:23 +0300 Subject: [PATCH 200/264] Remove additional commented-out debug tensor tracking from `SkyReelsV2TransformerBlock` and `SkyReelsV2Transformer3DModel` for cleaner code. --- .../transformers/transformer_skyreels_v2.py | 60 +------------------ 1 file changed, 3 insertions(+), 57 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index d7d53e120804..cf3ba76b3853 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -304,19 +304,12 @@ def forward( elif temb.dim() == 4: e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e] - #debug_dict = {} - #debug_dict['x'] = hidden_states - # 1. Self-attention - norm_x = self.norm1(hidden_states.float()) - #debug_dict['norm_x'] = norm_x - norm_hidden_states = (norm_x * (1 + scale_msa) + shift_msa).type_as(hidden_states) - #debug_dict['mul_add_add'] = norm_hidden_states + # 1. 
Self-attention + norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) attn_output = self.attn1( hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask ) - #debug_dict['self_attn'] = attn_output hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) - #debug_dict['mul_add_compile'] = hidden_states # 2. Cross-attention norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) @@ -328,7 +321,6 @@ def forward( ) ff_output = self.ffn(norm_hidden_states) hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) - #debug_dict['cross_attn_ffn'] = hidden_states return hidden_states def set_ar_attention(self): @@ -461,13 +453,6 @@ def forward( return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - #debug_tensors = {} - #debug_tensors["initial_hidden_states"] = hidden_states.detach().clone() - #debug_tensors["initial_hidden_states_shape"] = list(hidden_states.shape) - #debug_tensors["initial_timestep"] = timestep.detach().clone() - #debug_tensors["initial_timestep_shape"] = list(timestep.shape) - #debug_tensors["initial_encoder_hidden_states"] = encoder_hidden_states.detach().clone() - #debug_tensors["initial_encoder_hidden_states_shape"] = list(encoder_hidden_states.shape) if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() @@ -491,15 +476,9 @@ def forward( post_patch_width = width // p_w rotary_emb = self.rope(hidden_states) - #debug_tensors["rotary_emb"] = rotary_emb.detach().clone() - #debug_tensors["rotary_emb_shape"] = list(rotary_emb.shape) hidden_states = self.patch_embedding(hidden_states) - #debug_tensors["hidden_states_after_patch_embedding"] = hidden_states.detach().clone() - #debug_tensors["hidden_states_after_patch_embedding_shape"] = list(hidden_states.shape) grid_sizes = torch.tensor(hidden_states.shape[2:], dtype=torch.long) - #debug_tensors["grid_sizes"] = grid_sizes.detach().clone() - #debug_tensors["grid_sizes_shape"] = list(grid_sizes.shape) if self.config.flag_causal_attention: frame_num, height, width = grid_sizes @@ -514,18 +493,10 @@ def forward( causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) hidden_states = hidden_states.flatten(2).transpose(1, 2) - #debug_tensors["hidden_states_after_flatten_transpose"] = hidden_states.detach().clone() - #debug_tensors["hidden_states_after_flatten_transpose_shape"] = list(hidden_states.shape) temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( timestep, encoder_hidden_states, encoder_hidden_states_image ) - #debug_tensors["temb_after_time_embedding"] = temb.detach().clone() - #debug_tensors["temb_after_time_embedding_shape"] = list(temb.shape) - #debug_tensors["timestep_proj_after_time_projection"] = timestep_proj.detach().clone() - #debug_tensors["timestep_proj_after_time_projection_shape"] = list(timestep_proj.shape) - #debug_tensors["encoder_hidden_states_after_text_embedding"] = encoder_hidden_states.detach().clone() - #debug_tensors["encoder_hidden_states_after_text_embedding_shape"] = list(encoder_hidden_states.shape) timestep_proj = timestep_proj.unflatten(-1, (6, -1)) @@ -535,8 +506,6 @@ def forward( # 4. 
Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: for i, block in enumerate(self.blocks): - #debug_tensors[f"block_{i}_input_hidden_states"] = hidden_states.detach().clone() - #debug_tensors[f"block_{i}_input_hidden_states_shape"] = list(hidden_states.shape) hidden_states = self._gradient_checkpointing_func( block, hidden_states, @@ -545,15 +514,10 @@ def forward( rotary_emb, causal_mask if self.config.flag_causal_attention else None, ) - #debug_tensors[f"block_{i}_debug_dict"] = debug_dict - #debug_tensors[f"block_{i}_output_hidden_states"] = hidden_states.detach().clone() - #debug_tensors[f"block_{i}_output_hidden_states_shape"] = list(hidden_states.shape) if self.config.inject_sample_info: fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device) fps_emb = self.fps_embedding(fps).float() - #debug_tensors["fps_emb"] = fps_emb.detach().clone() - #debug_tensors["fps_emb_shape"] = list(fps_emb.shape) timestep_proj = timestep_proj.to(fps_emb.dtype) self.fps_projection.to(fps_emb.dtype) if flag_df: @@ -562,8 +526,6 @@ def forward( ) else: timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1)) - #debug_tensors["timestep_proj_after_fps_projection"] = timestep_proj.detach().clone() - #debug_tensors["timestep_proj_after_fps_projection_shape"] = list(timestep_proj.shape) timestep_proj = timestep_proj.to(hidden_states.dtype) if flag_df: @@ -575,8 +537,6 @@ def forward( timestep_proj = timestep_proj.transpose(1, 2).contiguous() for i, block in enumerate(self.blocks): - #debug_tensors[f"block_{i}_input_hidden_states"] = hidden_states.detach().clone() - #debug_tensors[f"block_{i}_input_hidden_states_shape"] = list(hidden_states.shape) hidden_states = block( hidden_states, encoder_hidden_states, @@ -584,15 +544,7 @@ def forward( rotary_emb, causal_mask if self.config.flag_causal_attention else None, ) - #debug_tensors[f"block_{i}_debug_dict"] = debug_dict - #return hidden_states, debug_dict - #debug_tensors[f"block_{i}_output_hidden_states"] = hidden_states.detach().clone() - #debug_tensors[f"block_{i}_output_hidden_states_shape"] = list(hidden_states.shape) - - #debug_tensors["temb_before_output_modulation"] = temb.detach().clone() - #debug_tensors["temb_before_output_modulation_shape"] = list(temb.shape) - #debug_tensors["scale_shift_table"] = self.scale_shift_table.detach().clone() - #debug_tensors["scale_shift_table_shape"] = list(self.scale_shift_table.shape) + if temb.dim() == 2: shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) elif temb.dim() == 3: @@ -607,20 +559,14 @@ def forward( scale = scale.to(hidden_states.device) hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) - #debug_tensors["hidden_states_after_output_modulation"] = hidden_states.detach().clone() - #debug_tensors["hidden_states_after_output_modulation_shape"] = list(hidden_states.shape) hidden_states = self.proj_out(hidden_states) - #debug_tensors["hidden_states_after_proj_out"] = hidden_states.detach().clone() - #debug_tensors["hidden_states_after_proj_out_shape"] = list(hidden_states.shape) hidden_states = hidden_states.reshape( batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 ) hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) - #debug_tensors["output_after_unpatchify"] = output.detach().clone() - #debug_tensors["output_after_unpatchify_shape"] = list(output.shape) if 
USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer From c6ef3cf3faa8138ac29cf0cd8cc5294a9a235c82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 5 Jun 2025 15:16:18 +0300 Subject: [PATCH 201/264] style --- src/diffusers/models/modeling_outputs.py | 2 -- src/diffusers/models/transformers/transformer_skyreels_v2.py | 1 - .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 1 + 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/modeling_outputs.py b/src/diffusers/models/modeling_outputs.py index 8555393c33eb..96e86abff491 100644 --- a/src/diffusers/models/modeling_outputs.py +++ b/src/diffusers/models/modeling_outputs.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional, Dict, Any from ..utils import BaseOutput @@ -32,4 +31,3 @@ class Transformer2DModelOutput(BaseOutput): """ sample: "torch.Tensor" # noqa: F821 - #debug_tensors: Optional[Dict[str, Any]] = None diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index cf3ba76b3853..eaa138873ea5 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -453,7 +453,6 @@ def forward( return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: - if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 448d4b54d6ec..3005536afeac 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -982,6 +982,7 @@ def __call__( if overlap_history is None: latents = latents.to(self.vae.dtype) from safetensors.torch import save_file + save_file({"output_latents": latents}, "diffusers.safetensors") latents_mean = ( torch.tensor(self.vae.config.latents_mean) From f359c773650feaec0a63e46985af2d6936bd4b14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 5 Jun 2025 15:20:34 +0300 Subject: [PATCH 202/264] cleansing --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 3005536afeac..a29359f89eff 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -981,9 +981,6 @@ def __call__( if not output_type == "latent": if overlap_history is None: latents = latents.to(self.vae.dtype) - from safetensors.torch import save_file - - save_file({"output_latents": latents}, "diffusers.safetensors") latents_mean = ( torch.tensor(self.vae.config.latents_mean) .view(1, self.vae.config.z_dim, 1, 1, 1) From 2fa1b38c68551e1e702a9a0305a29c553101f753 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 5 Jun 2025 16:50:21 +0300 Subject: [PATCH 203/264] Update example documentation and parameters in `SkyReelsV2Pipeline`. 
Adjusted example code for loading models, modified default values for height, width, num_frames, and guidance_scale, and improved output video quality settings. --- .../skyreels_v2/pipeline_skyreels_v2.py | 69 ++++++++++++------- 1 file changed, 43 insertions(+), 26 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py index 894c3c1f5ab6..1c49e0e86b98 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py @@ -43,34 +43,47 @@ import ftfy -EXAMPLE_DOC_STRING = """ +EXAMPLE_DOC_STRING = """\ Examples: - ```python + ```py >>> import torch + >>> from diffusers import ( + ... SkyReelsV2Pipeline, + ... FlowMatchUniPCMultistepScheduler, + ... AutoencoderKLWan, + ... ) >>> from diffusers.utils import export_to_video - >>> from diffusers import AutoencoderKLWan, SkyReelsV2Pipeline - >>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler - >>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers - >>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" - >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) - >>> pipe = SkyReelsV2Pipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) - >>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P - >>> pipe.scheduler = FlowMatchUniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) - >>> pipe.to("cuda") + >>> # Load the pipeline + >>> vae = AutoencoderKLWan.from_pretrained( + ... "/SkyReels-V2-1.3B-540P-Diffusers", + ... subfolder="vae", + ... torch_dtype=torch.float32, + ... ) + >>> pipe = SkyReelsV2Pipeline.from_pretrained( + ... "/SkyReels-V2-1.3B-540P-Diffusers", + ... vae=vae, + ... torch_dtype=torch.bfloat16, + ... ) + >>> shift = 8.0 # 8.0 for T2V, 3.0 for I2V + >>> pipe.scheduler = FlowMatchUniPCMultistepScheduler.from_config(pipe.scheduler.config, shift=shift) + >>> pipe = pipe.to("cuda") + >>> pipe.transformer.set_ar_attention(causal_block_size=5) >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." - >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" >>> output = pipe( ... prompt=prompt, - ... negative_prompt=negative_prompt, - ... height=720, - ... width=1280, - ... num_frames=81, - ... guidance_scale=5.0, + ... num_inference_steps=30, + ... height=544, + ... width=960, + ... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V + ... num_frames=97, + ... ar_step=5, # Controls asynchronous inference (0 for synchronous mode) + ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos + ... addnoise_condition=20, # Improves consistency in long video generation ... 
).frames[0] - >>> export_to_video(output, "output.mp4", fps=16) + >>> export_to_video(output, "video.mp4", fps=24, quality=8) ``` """ @@ -139,6 +152,7 @@ def __init__( self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, @@ -180,6 +194,7 @@ def _get_t5_prompt_embeds( return prompt_embeds + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], @@ -261,6 +276,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.check_inputs def check_inputs( self, prompt, @@ -302,6 +318,7 @@ def check_inputs( ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents def prepare_latents( self, batch_size: int, @@ -364,11 +381,11 @@ def __call__( self, prompt: Union[str, List[str]] = None, negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, - num_frames: int = 81, + height: int = 544, + width: int = 960, + num_frames: int = 97, num_inference_steps: int = 50, - guidance_scale: float = 5.0, + guidance_scale: float = 6.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -390,16 +407,16 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - height (`int`, defaults to `480`): + height (`int`, defaults to `544`): The height in pixels of the generated image. - width (`int`, defaults to `832`): + width (`int`, defaults to `960`): The width in pixels of the generated image. - num_frames (`int`, defaults to `81`): + num_frames (`int`, defaults to `97`): The number of frames in the generated video. num_inference_steps (`int`, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - guidance_scale (`float`, defaults to `5.0`): + guidance_scale (`float`, defaults to `6.0`): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > From 4b0a775cca6d7bf3b77683722492c0f91c07a84d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 5 Jun 2025 17:06:51 +0300 Subject: [PATCH 204/264] Update shift parameter in example documentation and default values across SkyReels V2 pipelines. Adjusted shift values for I2V from 3.0 to 5.0 and updated related example code for consistency. 
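After this change the flow-matching shift can be applied in two places, mirroring the updated docstring examples. A minimal sketch, assuming `pipe` is an already-loaded SkyReels-V2 pipeline (checkpoint loading is omitted because the repository ids are elided in the examples):

    from diffusers import FlowMatchUniPCMultistepScheduler

    shift = 8.0  # 8.0 for text-to-video, 5.0 for image-to-video

    # 1) Bake the shift into the scheduler configuration up front...
    pipe.scheduler = FlowMatchUniPCMultistepScheduler.from_config(pipe.scheduler.config, shift=shift)

    # 2) ...or pass it per call; the pipeline forwards it to scheduler.set_timesteps(..., shift=shift).
    output = pipe(prompt="A cat and a dog baking a cake together in a kitchen.", shift=shift).frames[0]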
--- .../skyreels_v2/pipeline_skyreels_v2.py | 18 ++--- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 8 +-- ...eline_skyreels_v2_diffusion_forcing_v2v.py | 4 +- .../skyreels_v2/pipeline_skyreels_v2_i2v.py | 71 +++++++++---------- 4 files changed, 48 insertions(+), 53 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py index 1c49e0e86b98..7a00756d0667 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py @@ -65,23 +65,20 @@ ... vae=vae, ... torch_dtype=torch.bfloat16, ... ) - >>> shift = 8.0 # 8.0 for T2V, 3.0 for I2V + >>> shift = 8.0 # 8.0 for T2V, 5.0 for I2V >>> pipe.scheduler = FlowMatchUniPCMultistepScheduler.from_config(pipe.scheduler.config, shift=shift) >>> pipe = pipe.to("cuda") - >>> pipe.transformer.set_ar_attention(causal_block_size=5) >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." >>> output = pipe( ... prompt=prompt, - ... num_inference_steps=30, + ... num_inference_steps=50, ... height=544, ... width=960, ... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V ... num_frames=97, - ... ar_step=5, # Controls asynchronous inference (0 for synchronous mode) - ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos - ... addnoise_condition=20, # Improves consistency in long video generation + ... shift=8.0, ... ).frames[0] >>> export_to_video(output, "video.mp4", fps=24, quality=8) ``` @@ -399,6 +396,7 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + shift: float = 8.0, ): r""" The call function to the pipeline for generation. @@ -451,8 +449,10 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. - autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): - The dtype to use for the torch.amp.autocast. + max_sequence_length (`int`, *optional*, defaults to `512`): + The maximum sequence length for the text encoder. + shift (`float`, *optional*, defaults to `8.0`): + Flow matching scheduler parameter (**8.0 for T2V**, **5.0 for I2V**) Examples: @@ -517,7 +517,7 @@ def __call__( negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) timesteps = self.scheduler.timesteps # 5. Prepare latent variables diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index a33108bcd46c..051e5c8cecb0 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -72,7 +72,7 @@ ... vae=vae, ... torch_dtype=torch.bfloat16, ... 
) - >>> shift = 8.0 # 8.0 for T2V, 3.0 for I2V + >>> shift = 8.0 # 8.0 for T2V, 5.0 for I2V >>> pipe.scheduler = FlowMatchUniPCMultistepScheduler.from_config(pipe.scheduler.config, shift=shift) >>> pipe = pipe.to("cuda") >>> pipe.transformer.set_ar_attention(causal_block_size=5) @@ -595,7 +595,7 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, overlap_history: Optional[int] = None, - shift: float = 3.0, + shift: float = 5.0, addnoise_condition: float = 0, base_num_frames: int = 97, ar_step: int = 0, @@ -670,8 +670,8 @@ def __call__( `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int`, *optional*, defaults to `512`): The maximum sequence length of the prompt. - shift (`float`, *optional*, defaults to `3.0`): - Flow matching scheduler parameter (**3.0 for I2V**, **8.0 for T2V**) + shift (`float`, *optional*, defaults to `5.0`): + Flow matching scheduler parameter (**5.0 for I2V**, **8.0 for T2V**) overlap_history (`int`, *optional*, defaults to `None`): Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes short video generation mode, and no overlap is applied. 17 and 37 are recommended to set. diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index 75f43229c883..79c7d5625acd 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -70,7 +70,7 @@ ... vae=vae, ... torch_dtype=torch.bfloat16, ... ) - >>> shift = 8.0 # 8.0 for T2V, 3.0 for I2V + >>> shift = 8.0 # 8.0 for T2V, 5.0 for I2V >>> pipe.scheduler = FlowMatchUniPCMultistepScheduler.from_config(pipe.scheduler.config, shift=shift) >>> pipe = pipe.to("cuda") >>> pipe.transformer.set_ar_attention(causal_block_size=5) @@ -688,7 +688,7 @@ def __call__( max_sequence_length (`int`, *optional*, defaults to `512`): The maximum sequence length of the prompt. shift (`float`, *optional*, defaults to `8.0`): - Flow matching scheduler parameter (**3.0 for I2V**, **8.0 for T2V**) + Flow matching scheduler parameter (**5.0 for I2V**, **8.0 for T2V**) overlap_history (`int`, *optional*, defaults to `None`): Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes short video generation mode, and no overlap is applied. 17 and 37 are recommended to set. diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py index 832e6a8a73ec..bf46d1c9b25b 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py @@ -45,51 +45,47 @@ import ftfy -EXAMPLE_DOC_STRING = """ +EXAMPLE_DOC_STRING = """\ Examples: - ```python + ```py >>> import torch - >>> import numpy as np - >>> from diffusers import AutoencoderKLWan, SkyReelsV2ImageToVideoPipeline - >>> from diffusers.utils import export_to_video, load_image - >>> from transformers import CLIPVisionModel - - >>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers - >>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers" - >>> image_encoder = CLIPVisionModel.from_pretrained( - ... model_id, subfolder="image_encoder", torch_dtype=torch.float32 + >>> from diffusers import ( + ... 
SkyReelsV2ImageToVideoPipeline, + ... FlowMatchUniPCMultistepScheduler, + ... AutoencoderKLWan, + ... ) + >>> from diffusers.utils import export_to_video + >>> from PIL import Image + + >>> # Load the pipeline + >>> vae = AutoencoderKLWan.from_pretrained( + ... "/SkyReels-V2-1.3B-540P-Diffusers", + ... subfolder="vae", + ... torch_dtype=torch.float32, ... ) - >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) >>> pipe = SkyReelsV2ImageToVideoPipeline.from_pretrained( - ... model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16 + ... "/SkyReels-V2-1.3B-540P-Diffusers", + ... vae=vae, + ... torch_dtype=torch.bfloat16, ... ) - >>> pipe.to("cuda") + >>> shift = 5.0 # 8.0 for T2V, 5.0 for I2V + >>> pipe.scheduler = FlowMatchUniPCMultistepScheduler.from_config(pipe.scheduler.config, shift=shift) + >>> pipe = pipe.to("cuda") - >>> image = load_image( - ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" - ... ) - >>> max_area = 480 * 832 - >>> aspect_ratio = image.height / image.width - >>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] - >>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value - >>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value - >>> image = image.resize((width, height)) - >>> prompt = ( - ... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in " - ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." - ... ) - >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." + >>> image = Image.open("path/to/image.png") >>> output = pipe( ... image=image, ... prompt=prompt, - ... negative_prompt=negative_prompt, - ... height=height, - ... width=width, - ... num_frames=81, - ... guidance_scale=5.0, + ... num_inference_steps=50, + ... height=544, + ... width=960, + ... guidance_scale=5.0, # 6.0 for T2V, 5.0 for I2V + ... num_frames=97, + ... shift=5.0, ... ).frames[0] - >>> export_to_video(output, "output.mp4", fps=16) + >>> export_to_video(output, "video.mp4", fps=24, quality=8) ``` """ @@ -503,6 +499,7 @@ def __call__( ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, + shift: float = 5.0, ): r""" The call function to the pipeline for generation. @@ -570,9 +567,7 @@ def __call__( max_sequence_length (`int`, *optional*, defaults to `512`): The maximum sequence length of the prompt. shift (`float`, *optional*, defaults to `5.0`): - The shift of the flow. - autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): - The dtype to use for the torch.amp.autocast. + Flow matching scheduler parameter (**5.0 for I2V**, **8.0 for T2V**) Examples: Returns: @@ -647,7 +642,7 @@ def __call__( image_embeds = image_embeds.to(transformer_dtype) # 4. 
Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) timesteps = self.scheduler.timesteps # 5. Prepare latent variables From ee56e4b6ab83aaf68a5263fe40fb617cfe322992 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 6 Jun 2025 14:27:01 +0300 Subject: [PATCH 205/264] Update example documentation in SkyReels V2 pipelines to include available model options and update model references for loading. Adjusted model names to reflect the latest versions across I2V, V2V, and T2V pipelines. --- .../pipelines/skyreels_v2/pipeline_skyreels_v2.py | 7 +++++-- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 8 ++++++-- .../pipeline_skyreels_v2_diffusion_forcing_i2v.py | 8 ++++++-- .../pipeline_skyreels_v2_diffusion_forcing_v2v.py | 8 ++++++-- .../pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py | 8 ++++++-- 5 files changed, 29 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py index 7a00756d0667..696775e3e7e9 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py @@ -55,13 +55,16 @@ >>> from diffusers.utils import export_to_video >>> # Load the pipeline + >>> # Available models: + >>> # - /SkyReels-V2-T2V-14B-540P-Diffusers + >>> # - /SkyReels-V2-T2V-14B-720P-Diffusers >>> vae = AutoencoderKLWan.from_pretrained( - ... "/SkyReels-V2-1.3B-540P-Diffusers", + ... "/SkyReels-V2-T2V-14B-720P-Diffusers", ... subfolder="vae", ... torch_dtype=torch.float32, ... ) >>> pipe = SkyReelsV2Pipeline.from_pretrained( - ... "/SkyReels-V2-1.3B-540P-Diffusers", + ... "/SkyReels-V2-T2V-14B-720P-Diffusers", ... vae=vae, ... torch_dtype=torch.bfloat16, ... ) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index a29359f89eff..8649af1e65e5 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -58,13 +58,17 @@ >>> from diffusers.utils import export_to_video >>> # Load the pipeline + >>> # Available models: + >>> # - /SkyReels-V2-DF-1.3B-540P-Diffusers + >>> # - /SkyReels-V2-DF-14B-540P-Diffusers + >>> # - /SkyReels-V2-DF-14B-720P-Diffusers >>> vae = AutoencoderKLWan.from_pretrained( - ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", + ... "/SkyReels-V2-DF-14B-720P-Diffusers", ... subfolder="vae", ... torch_dtype=torch.float32, ... ) >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained( - ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", + ... "/SkyReels-V2-DF-14B-720P-Diffusers", ... vae=vae, ... torch_dtype=torch.bfloat16, ... 
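The organization prefix is elided in these docstrings, so the ids below keep a placeholder. A small sketch of selecting the diffusion-forcing variant by model size and resolution; "<org>" is a placeholder, not a real repository name:

    # Checkpoint ids as listed in the updated docstrings; "<org>" stands for the elided organization.
    DF_CHECKPOINTS = {
        ("1.3B", "540P"): "<org>/SkyReels-V2-DF-1.3B-540P-Diffusers",
        ("14B", "540P"): "<org>/SkyReels-V2-DF-14B-540P-Diffusers",
        ("14B", "720P"): "<org>/SkyReels-V2-DF-14B-720P-Diffusers",
    }

    def df_checkpoint(size: str, resolution: str) -> str:
        # Return the repository id for the requested diffusion-forcing variant.
        return DF_CHECKPOINTS[(size, resolution)]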
) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 051e5c8cecb0..5dce13a66068 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -62,13 +62,17 @@ >>> from PIL import Image >>> # Load the pipeline + >>> # Available models: + >>> # - /SkyReels-V2-DF-1.3B-540P-Diffusers + >>> # - /SkyReels-V2-DF-14B-540P-Diffusers + >>> # - /SkyReels-V2-DF-14B-720P-Diffusers >>> vae = AutoencoderKLWan.from_pretrained( - ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", + ... "/SkyReels-V2-DF-14B-720P-Diffusers", ... subfolder="vae", ... torch_dtype=torch.float32, ... ) >>> pipe = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained( - ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", + ... "/SkyReels-V2-DF-14B-720P-Diffusers", ... vae=vae, ... torch_dtype=torch.bfloat16, ... ) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index 79c7d5625acd..07e44b94b6dc 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -60,13 +60,17 @@ >>> from diffusers.utils import export_to_video >>> # Load the pipeline + >>> # Available models: + >>> # - /SkyReels-V2-DF-1.3B-540P-Diffusers + >>> # - /SkyReels-V2-DF-14B-540P-Diffusers + >>> # - /SkyReels-V2-DF-14B-720P-Diffusers >>> vae = AutoencoderKLWan.from_pretrained( - ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", + ... "/SkyReels-V2-DF-14B-720P-Diffusers", ... subfolder="vae", ... torch_dtype=torch.float32, ... ) >>> pipe = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained( - ... "/SkyReels-V2-DF-1.3B-540P-Diffusers", + ... "/SkyReels-V2-DF-14B-720P-Diffusers", ... vae=vae, ... torch_dtype=torch.bfloat16, ... ) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py index bf46d1c9b25b..aa3a8f2e6132 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py @@ -58,13 +58,17 @@ >>> from PIL import Image >>> # Load the pipeline + >>> # Available models: + >>> # - /SkyReels-V2-I2V-1.3B-540P-Diffusers + >>> # - /SkyReels-V2-I2V-14B-540P-Diffusers + >>> # - /SkyReels-V2-I2V-14B-720P-Diffusers >>> vae = AutoencoderKLWan.from_pretrained( - ... "/SkyReels-V2-1.3B-540P-Diffusers", + ... "/SkyReels-V2-I2V-14B-720P-Diffusers", ... subfolder="vae", ... torch_dtype=torch.float32, ... ) >>> pipe = SkyReelsV2ImageToVideoPipeline.from_pretrained( - ... "/SkyReels-V2-1.3B-540P-Diffusers", + ... "/SkyReels-V2-I2V-14B-720P-Diffusers", ... vae=vae, ... torch_dtype=torch.bfloat16, ... 
) From 1dadfc2e261e4c77fbbddc3bca075e4f39fbf0af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 6 Jun 2025 15:51:07 +0300 Subject: [PATCH 206/264] Add test templates --- .../test_models_transformer_skyreels_v2.py | 84 ++++++ tests/pipelines/skyreels_v2/__init__.py | 0 .../pipelines/skyreels_v2/test_skyreels_v2.py | 156 +++++++++++ .../skyreels_v2/test_skyreels_v2_df.py | 155 +++++++++++ .../test_skyreels_v2_df_image_to_video.py | 210 +++++++++++++++ .../test_skyreels_v2_df_video_to_video.py | 146 +++++++++++ .../test_skyreels_v2_image_to_video.py | 248 ++++++++++++++++++ 7 files changed, 999 insertions(+) create mode 100644 tests/models/transformers/test_models_transformer_skyreels_v2.py create mode 100644 tests/pipelines/skyreels_v2/__init__.py create mode 100644 tests/pipelines/skyreels_v2/test_skyreels_v2.py create mode 100644 tests/pipelines/skyreels_v2/test_skyreels_v2_df.py create mode 100644 tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py create mode 100644 tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py create mode 100644 tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py diff --git a/tests/models/transformers/test_models_transformer_skyreels_v2.py b/tests/models/transformers/test_models_transformer_skyreels_v2.py new file mode 100644 index 000000000000..884f168308cc --- /dev/null +++ b/tests/models/transformers/test_models_transformer_skyreels_v2.py @@ -0,0 +1,84 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import SkyReelsV2Transformer3DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin + + +enable_full_determinism() + + +class SkyReelsV2Transformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): + model_class = SkyReelsV2Transformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 1 + num_channels = 4 + num_frames = 2 + height = 16 + width = 16 + text_encoder_embedding_dim = 16 + sequence_length = 12 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + } + + @property + def input_shape(self): + return (4, 1, 16, 16) + + @property + def output_shape(self): + return (4, 1, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "patch_size": (1, 2, 2), + "num_attention_heads": 2, + "attention_head_dim": 12, + "in_channels": 4, + "out_channels": 4, + "text_dim": 16, + "freq_dim": 256, + "ffn_dim": 32, + "num_layers": 2, + "cross_attn_norm": True, + "qk_norm": "rms_norm_across_heads", + "rope_max_seq_len": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"SkyReelsV2Transformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/skyreels_v2/__init__.py b/tests/pipelines/skyreels_v2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2.py b/tests/pipelines/skyreels_v2/test_skyreels_v2.py new file mode 100644 index 000000000000..a162e6841d2d --- /dev/null +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2.py @@ -0,0 +1,156 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanPipeline, WanTransformer3DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + require_torch_accelerator, + slow, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import ( + PipelineTesterMixin, +) + + +enable_full_determinism() + + +class WanPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = WanPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + # TODO: impl FlowDPMSolverMultistepScheduler + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = WanTransformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "negative", # TODO + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + expected_video = torch.randn(9, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + +@slow +@require_torch_accelerator +class WanPipelineIntegrationTests(unittest.TestCase): + prompt = "A painting of a squirrel eating a burger." 
+ + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @unittest.skip("TODO: test needs to be implemented") + def test_Wanx(self): + pass diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py new file mode 100644 index 000000000000..976037d26a30 --- /dev/null +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py @@ -0,0 +1,155 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLWan, FlowMatchUniPCMultistepScheduler, SkyReelsV2DiffusionForcingPipeline, SkyReelsV2Transformer3DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + require_torch_accelerator, + slow, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import ( + PipelineTesterMixin, +) + + +enable_full_determinism() + + +class SkyReelsV2DiffusionForcingPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = SkyReelsV2DiffusionForcingPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = FlowMatchUniPCMultistepScheduler(shift=7.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = SkyReelsV2Transformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "dance monkey", + "negative_prompt": "negative", # TODO + "generator": generator, + 
"num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + expected_video = torch.randn(9, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + +@slow +@require_torch_accelerator +class SkyReelsV2DiffusionForcingPipelineIntegrationTests(unittest.TestCase): + prompt = "A painting of a squirrel eating a burger." + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @unittest.skip("TODO: test needs to be implemented") + def test_SkyReelsV2DiffusionForcingx(self): + pass diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py new file mode 100644 index 000000000000..7f8b8c2abe61 --- /dev/null +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py @@ -0,0 +1,210 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import ( + AutoTokenizer, + T5EncoderModel, +) + +from diffusers import AutoencoderKLWan, FlowMatchUniPCMultistepScheduler, SkyReelsV2DiffusionForcingImageToVideoPipeline, SkyReelsV2Transformer3DModel +from diffusers.utils.testing_utils import enable_full_determinism + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class SkyReelsV2DiffusionForcingImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = SkyReelsV2DiffusionForcingImageToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = FlowMatchUniPCMultistepScheduler(shift=7.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = SkyReelsV2Transformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=36, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + image_dim=4, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + image_height = 16 + image_width = 16 + image = Image.new("RGB", (image_width, image_height)) + inputs = { + "image": image, + "prompt": "dance monkey", + "negative_prompt": "negative", # TODO + "height": image_height, + "width": image_width, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + expected_video = torch.randn(9, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass") + def test_inference_batch_single_identical(self): + pass + + +class 
SkyReelsV2DiffusionForcingImageToVideoPipelineFastTests(SkyReelsV2DiffusionForcingImageToVideoPipelineFastTests): + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = FlowMatchUniPCMultistepScheduler(shift=7.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = SkyReelsV2Transformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=36, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + image_dim=4, + pos_embed_seq_len=2 * (4 * 4 + 1), + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + image_height = 16 + image_width = 16 + image = Image.new("RGB", (image_width, image_height)) + last_image = Image.new("RGB", (image_width, image_height)) + inputs = { + "image": image, + "last_image": last_image, + "prompt": "dance monkey", + "negative_prompt": "negative", + "height": image_height, + "width": image_width, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py new file mode 100644 index 000000000000..99d85cf2c367 --- /dev/null +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py @@ -0,0 +1,146 @@ +# Copyright 2025 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLWan, FlowMatchUniPCMultistepScheduler, SkyReelsV2Transformer3DModel, SkyReelsV2DiffusionForcingVideoToVideoPipeline +from diffusers.utils.testing_utils import ( + enable_full_determinism, +) + +from ..pipeline_params import TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import ( + PipelineTesterMixin, +) + + +enable_full_determinism() + + +class SkyReelsV2DiffusionForcingVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = SkyReelsV2DiffusionForcingVideoToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = frozenset(["video", "prompt", "negative_prompt"]) + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = FlowMatchUniPCMultistepScheduler(flow_shift=3.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = SkyReelsV2Transformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=16, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + ) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + video = [Image.new("RGB", (16, 16))] * 17 + inputs = { + "video": video, + "prompt": "dance monkey", + "negative_prompt": "negative", # TODO + "generator": generator, + "num_inference_steps": 4, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (17, 3, 16, 16)) + expected_video = torch.randn(17, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + @unittest.skip( + "SkyReelsV2DiffusionForcingVideoToVideoPipeline has to run in mixed precision. Casting the entire pipeline will result in errors" + ) + def test_float16_inference(self): + pass + + @unittest.skip( + "SkyReelsV2DiffusionForcingVideoToVideoPipeline has to run in mixed precision. 
Save/Load the entire pipeline in FP16 will result in errors" + ) + def test_save_load_float16(self): + pass diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py new file mode 100644 index 000000000000..a9dcf9755f72 --- /dev/null +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py @@ -0,0 +1,248 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import ( + AutoTokenizer, + CLIPImageProcessor, + CLIPVisionConfig, + CLIPVisionModelWithProjection, + T5EncoderModel, +) + +from diffusers import AutoencoderKLWan, FlowMatchUniPCMultistepScheduler, SkyReelsV2ImageToVideoPipeline, SkyReelsV2Transformer3DModel +from diffusers.utils.testing_utils import enable_full_determinism + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class SkyReelsV2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = SkyReelsV2ImageToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = FlowMatchUniPCMultistepScheduler(shift=7.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = SkyReelsV2Transformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=36, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + image_dim=4, + ) + + torch.manual_seed(0) + image_encoder_config = CLIPVisionConfig( + hidden_size=4, + projection_dim=4, + num_hidden_layers=2, + num_attention_heads=2, + image_size=32, + intermediate_size=16, + patch_size=1, + ) + image_encoder = CLIPVisionModelWithProjection(image_encoder_config) + + torch.manual_seed(0) + image_processor = CLIPImageProcessor(crop_size=32, size=32) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + 
"image_encoder": image_encoder, + "image_processor": image_processor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + image_height = 16 + image_width = 16 + image = Image.new("RGB", (image_width, image_height)) + inputs = { + "image": image, + "prompt": "dance monkey", + "negative_prompt": "negative", # TODO + "height": image_height, + "width": image_width, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 16, 16)) + expected_video = torch.randn(9, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + @unittest.skip("Test not supported") + def test_attention_slicing_forward_pass(self): + pass + + @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass") + def test_inference_batch_single_identical(self): + pass + +# TODO: Is this FLF2V test necessary, because the original repo doesn't seem to have this functionality for this pipeline? +# or doesn't it have to be implemented? +class SkyReelsV2ImageToVideoPipelineFastTests(SkyReelsV2ImageToVideoPipelineFastTests): + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLWan( + base_dim=3, + z_dim=16, + dim_mult=[1, 1, 1, 1], + num_res_blocks=1, + temperal_downsample=[False, True, True], + ) + + torch.manual_seed(0) + scheduler = FlowMatchUniPCMultistepScheduler(shift=7.0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + transformer = SkyReelsV2Transformer3DModel( + patch_size=(1, 2, 2), + num_attention_heads=2, + attention_head_dim=12, + in_channels=36, + out_channels=16, + text_dim=32, + freq_dim=256, + ffn_dim=32, + num_layers=2, + cross_attn_norm=True, + qk_norm="rms_norm_across_heads", + rope_max_seq_len=32, + image_dim=4, + pos_embed_seq_len=2 * (4 * 4 + 1), + ) + + torch.manual_seed(0) + image_encoder_config = CLIPVisionConfig( + hidden_size=4, + projection_dim=4, + num_hidden_layers=2, + num_attention_heads=2, + image_size=4, + intermediate_size=16, + patch_size=1, + ) + image_encoder = CLIPVisionModelWithProjection(image_encoder_config) + + torch.manual_seed(0) + image_processor = CLIPImageProcessor(crop_size=4, size=4) + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "image_encoder": image_encoder, + "image_processor": image_processor, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + image_height = 16 + image_width = 16 + image = Image.new("RGB", (image_width, image_height)) + last_image = Image.new("RGB", (image_width, image_height)) + inputs = { + "image": image, + "last_image": 
last_image, + "prompt": "dance monkey", + "negative_prompt": "negative", + "height": image_height, + "width": image_width, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs From 619a571f6c84a17c7261ccd7c24d797258ea4cc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 6 Jun 2025 15:52:33 +0300 Subject: [PATCH 207/264] style --- .../skyreels_v2/pipeline_skyreels_v2_i2v.py | 14 +++++++------- tests/pipelines/skyreels_v2/test_skyreels_v2_df.py | 7 ++++++- .../test_skyreels_v2_df_image_to_video.py | 7 ++++++- .../test_skyreels_v2_df_video_to_video.py | 7 ++++++- .../skyreels_v2/test_skyreels_v2_image_to_video.py | 8 +++++++- 5 files changed, 32 insertions(+), 11 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py index aa3a8f2e6132..46b1f9e51dfc 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py @@ -127,7 +127,7 @@ def retrieve_latents( class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): r""" - Pipeline for image-to-video generation using SkyReels-V2. + Pipeline for Image-to-Video (i2v) generation using SkyReels-V2. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). @@ -483,9 +483,9 @@ def __call__( image: PipelineImageInput, prompt: Union[str, List[str]] = None, negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, - num_frames: int = 81, + height: int = 544, + width: int = 960, + num_frames: int = 97, num_inference_steps: int = 50, guidance_scale: float = 5.0, num_videos_per_prompt: Optional[int] = 1, @@ -518,11 +518,11 @@ def __call__( The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - height (`int`, defaults to `480`): + height (`int`, defaults to `544`): The height of the generated video. - width (`int`, defaults to `832`): + width (`int`, defaults to `960`): The width of the generated video. - num_frames (`int`, defaults to `81`): + num_frames (`int`, defaults to `97`): The number of frames in the generated video. num_inference_steps (`int`, defaults to `50`): The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py index 976037d26a30..18fe0196b30f 100644 --- a/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py @@ -19,7 +19,12 @@ import torch from transformers import AutoTokenizer, T5EncoderModel -from diffusers import AutoencoderKLWan, FlowMatchUniPCMultistepScheduler, SkyReelsV2DiffusionForcingPipeline, SkyReelsV2Transformer3DModel +from diffusers import ( + AutoencoderKLWan, + FlowMatchUniPCMultistepScheduler, + SkyReelsV2DiffusionForcingPipeline, + SkyReelsV2Transformer3DModel, +) from diffusers.utils.testing_utils import ( enable_full_determinism, require_torch_accelerator, diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py index 7f8b8c2abe61..470a4ed103d9 100644 --- a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py @@ -22,7 +22,12 @@ T5EncoderModel, ) -from diffusers import AutoencoderKLWan, FlowMatchUniPCMultistepScheduler, SkyReelsV2DiffusionForcingImageToVideoPipeline, SkyReelsV2Transformer3DModel +from diffusers import ( + AutoencoderKLWan, + FlowMatchUniPCMultistepScheduler, + SkyReelsV2DiffusionForcingImageToVideoPipeline, + SkyReelsV2Transformer3DModel, +) from diffusers.utils.testing_utils import enable_full_determinism from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py index 99d85cf2c367..37a0ce0adec4 100644 --- a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py @@ -19,7 +19,12 @@ from PIL import Image from transformers import AutoTokenizer, T5EncoderModel -from diffusers import AutoencoderKLWan, FlowMatchUniPCMultistepScheduler, SkyReelsV2Transformer3DModel, SkyReelsV2DiffusionForcingVideoToVideoPipeline +from diffusers import ( + AutoencoderKLWan, + FlowMatchUniPCMultistepScheduler, + SkyReelsV2DiffusionForcingVideoToVideoPipeline, + SkyReelsV2Transformer3DModel, +) from diffusers.utils.testing_utils import ( enable_full_determinism, ) diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py index a9dcf9755f72..e8caf222b555 100644 --- a/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py @@ -25,7 +25,12 @@ T5EncoderModel, ) -from diffusers import AutoencoderKLWan, FlowMatchUniPCMultistepScheduler, SkyReelsV2ImageToVideoPipeline, SkyReelsV2Transformer3DModel +from diffusers import ( + AutoencoderKLWan, + FlowMatchUniPCMultistepScheduler, + SkyReelsV2ImageToVideoPipeline, + SkyReelsV2Transformer3DModel, +) from diffusers.utils.testing_utils import enable_full_determinism from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS @@ -160,6 +165,7 @@ def test_attention_slicing_forward_pass(self): def test_inference_batch_single_identical(self): pass + # TODO: Is this FLF2V test necessary, because the original repo doesn't seem to have this functionality for this pipeline? 
# or doesn't it have to be implemented? class SkyReelsV2ImageToVideoPipelineFastTests(SkyReelsV2ImageToVideoPipelineFastTests): From 974fa001bc3a5f79c339bed66f27c12d29e1f435 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 6 Jun 2025 15:58:45 +0300 Subject: [PATCH 208/264] Add docs template --- docs/source/en/api/pipelines/skyreels_v2.md | 256 ++++++++++++++++++++ 1 file changed, 256 insertions(+) create mode 100644 docs/source/en/api/pipelines/skyreels_v2.md diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md new file mode 100644 index 000000000000..e0be071921d6 --- /dev/null +++ b/docs/source/en/api/pipelines/skyreels_v2.md @@ -0,0 +1,256 @@ + + +
+<div class="flex flex-wrap space-x-1">
+  <img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
+</div>
+
+ +# Wan2.1 + +[Wan2.1](https://files.alicdn.com/tpsservice/5c9de1c74de03972b7aa657e5a54756b.pdf) is a series of large diffusion transformer available in two versions, a high-performance 14B parameter model and a more accessible 1.3B version. Trained on billions of images and videos, it supports tasks like text-to-video (T2V) and image-to-video (I2V) while enabling features such as camera control and stylistic diversity. The Wan-VAE features better image data compression and a feature cache mechanism that encodes and decodes a video in chunks. To maintain continuity, features from previous chunks are cached and reused for processing subsequent chunks. This improves inference efficiency by reducing memory usage. Wan2.1 also uses a multilingual text encoder and the diffusion transformer models space and time relationships and text conditions with each time step to capture more complex video dynamics. + +You can find all the original Wan2.1 checkpoints under the [Wan-AI](https://huggingface.co/Wan-AI) organization. + +> [!TIP] +> Click on the Wan2.1 models in the right sidebar for more examples of video generation. + +The example below demonstrates how to generate a video from text optimized for memory or inference speed. + + + + +Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques. + +The Wan2.1 text-to-video model below requires ~13GB of VRAM. + +```py +# pip install ftfy +import torch +import numpy as np +from diffusers import AutoModel, WanPipeline +from diffusers.quantizers import PipelineQuantizationConfig +from diffusers.hooks.group_offloading import apply_group_offloading +from diffusers.utils import export_to_video, load_image +from transformers import UMT5EncoderModel + +text_encoder = UMT5EncoderModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16) +vae = AutoModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae", torch_dtype=torch.float32) +transformer = AutoModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16) + +# group-offloading +onload_device = torch.device("cuda") +offload_device = torch.device("cpu") +apply_group_offloading(text_encoder, + onload_device=onload_device, + offload_device=offload_device, + offload_type="block_level", + num_blocks_per_group=4 +) +transformer.enable_group_offload( + onload_device=onload_device, + offload_device=offload_device, + offload_type="leaf_level", + use_stream=True +) + +pipeline = WanPipeline.from_pretrained( + "Wan-AI/Wan2.1-T2V-14B-Diffusers", + vae=vae, + transformer=transformer, + text_encoder=text_encoder, + torch_dtype=torch.bfloat16 +) +pipeline.to("cuda") + +prompt = """ +The camera rushes from far to near in a low-angle shot, +revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in +for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. +Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic +shadows and warm highlights. Medium composition, front view, low angle, with depth of field. 
+""" +negative_prompt = """ +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards +""" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=81, + guidance_scale=5.0, +).frames[0] +export_to_video(output, "output.mp4", fps=16) +``` + + + + +[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. + +```py +# pip install ftfy +import torch +import numpy as np +from diffusers import AutoModel, WanPipeline +from diffusers.hooks.group_offloading import apply_group_offloading +from diffusers.utils import export_to_video, load_image +from transformers import UMT5EncoderModel + +text_encoder = UMT5EncoderModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16) +vae = AutoModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae", torch_dtype=torch.float32) +transformer = AutoModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16) + +pipeline = WanPipeline.from_pretrained( + "Wan-AI/Wan2.1-T2V-14B-Diffusers", + vae=vae, + transformer=transformer, + text_encoder=text_encoder, + torch_dtype=torch.bfloat16 +) +pipeline.to("cuda") + +# torch.compile +pipeline.transformer.to(memory_format=torch.channels_last) +pipeline.transformer = torch.compile( + pipeline.transformer, mode="max-autotune", fullgraph=True +) + +prompt = """ +The camera rushes from far to near in a low-angle shot, +revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in +for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. +Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic +shadows and warm highlights. Medium composition, front view, low angle, with depth of field. +""" +negative_prompt = """ +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards +""" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=81, + guidance_scale=5.0, +).frames[0] +export_to_video(output, "output.mp4", fps=16) +``` + + + + +## Notes + +- Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`]. + +
+ Show example code + + ```py + # pip install ftfy + import torch + from diffusers import AutoModel, WanPipeline + from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler + from diffusers.utils import export_to_video + + vae = AutoModel.from_pretrained( + "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32 + ) + pipeline = WanPipeline.from_pretrained( + "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", vae=vae, torch_dtype=torch.bfloat16 + ) + pipeline.scheduler = UniPCMultistepScheduler.from_config( + pipeline.scheduler.config, flow_shift=5.0 + ) + pipeline.to("cuda") + + pipeline.load_lora_weights("benjamin-paine/steamboat-willie-1.3b", adapter_name="steamboat-willie") + pipeline.set_adapters("steamboat-willie") + + pipeline.enable_model_cpu_offload() + + # use "steamboat willie style" to trigger the LoRA + prompt = """ + steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot, + revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in + for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. + Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic + shadows and warm highlights. Medium composition, front view, low angle, with depth of field. + """ + + output = pipeline( + prompt=prompt, + num_frames=81, + guidance_scale=5.0, + ).frames[0] + export_to_video(output, "output.mp4", fps=16) + ``` + +
+
+
+- [`WanTransformer3DModel`] and [`AutoencoderKLWan`] support loading from single files with [`~loaders.FromSingleFileMixin.from_single_file`].
+
+ Show example code + + ```py + # pip install ftfy + import torch + from diffusers import WanPipeline, AutoModel + + vae = AutoModel.from_single_file( + "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors" + ) + transformer = AutoModel.from_single_file( + "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors", + torch_dtype=torch.bfloat16 + ) + pipeline = WanPipeline.from_pretrained( + "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + vae=vae, + transformer=transformer, + torch_dtype=torch.bfloat16 + ) + ``` + +
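A minimal sketch of the scheduler `shift` and frame-count guidance from the notes below, assuming a Wan-style Diffusers checkpoint; the model id and the concrete values are illustrative only:

```py
import torch
from diffusers import WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler

# Assumed checkpoint id for illustration; any Wan/SkyReels-V2 Diffusers-format T2V checkpoint works the same way.
pipeline = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)

# Lower flow_shift (2.0-5.0) tends to suit lower-resolution videos; higher values (7.0-12.0) suit higher resolutions.
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=5.0)
pipeline.to("cuda")

# num_frames should follow 4 * k + 1, e.g. k = 20 gives the commonly used 81 frames.
k = 20
output = pipeline(prompt="A cat walks on the grass", num_frames=4 * k + 1, guidance_scale=5.0).frames[0]
```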
+ +- Set the [`AutoencoderKLWan`] dtype to `torch.float32` for better decoding quality. + +- The number of frames per second (fps) or `k` should be calculated by `4 * k + 1`. + +- Try lower `shift` values (`2.0` to `5.0`) for lower resolution videos and higher `shift` values (`7.0` to `12.0`) for higher resolution images. + +## WanPipeline + +[[autodoc]] WanPipeline + - all + - __call__ + +## WanImageToVideoPipeline + +[[autodoc]] WanImageToVideoPipeline + - all + - __call__ + +## WanPipelineOutput + +[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput \ No newline at end of file From 6e84a82d7f2c28bd823d066b5d3c72399247830e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 6 Jun 2025 16:58:22 +0300 Subject: [PATCH 209/264] Add SkyReels V2 Diffusion Forcing Video-to-Video Pipeline to imports --- src/diffusers/__init__.py | 2 ++ src/diffusers/pipelines/__init__.py | 2 ++ src/diffusers/pipelines/skyreels_v2/__init__.py | 4 ++++ 3 files changed, 8 insertions(+) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1d648faa5548..c7179d04b9db 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -450,6 +450,7 @@ "ShapEImg2ImgPipeline", "ShapEPipeline", "SkyReelsV2DiffusionForcingImageToVideoPipeline", + "SkyReelsV2DiffusionForcingVideoToVideoPipeline", "SkyReelsV2DiffusionForcingPipeline", "SkyReelsV2ImageToVideoPipeline", "SkyReelsV2Pipeline", @@ -1044,6 +1045,7 @@ ShapEImg2ImgPipeline, ShapEPipeline, SkyReelsV2DiffusionForcingImageToVideoPipeline, + SkyReelsV2DiffusionForcingVideoToVideoPipeline, SkyReelsV2DiffusionForcingPipeline, SkyReelsV2ImageToVideoPipeline, SkyReelsV2Pipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 68dd72499249..d3e8b428dcc3 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -375,6 +375,7 @@ _import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", "SkyReelsV2DiffusionForcingImageToVideoPipeline", + "SkyReelsV2DiffusionForcingVideoToVideoPipeline", "SkyReelsV2ImageToVideoPipeline", "SkyReelsV2Pipeline", ] @@ -844,6 +845,7 @@ from .skyreels_v2 import ( SkyReelsV2DiffusionForcingImageToVideoPipeline, SkyReelsV2DiffusionForcingPipeline, + SkyReelsV2DiffusionForcingVideoToVideoPipeline, SkyReelsV2ImageToVideoPipeline, SkyReelsV2Pipeline, ) diff --git a/src/diffusers/pipelines/skyreels_v2/__init__.py b/src/diffusers/pipelines/skyreels_v2/__init__.py index 602f034ca1e2..84d2a2dd3500 100644 --- a/src/diffusers/pipelines/skyreels_v2/__init__.py +++ b/src/diffusers/pipelines/skyreels_v2/__init__.py @@ -27,6 +27,9 @@ _import_structure["pipeline_skyreels_v2_diffusion_forcing_i2v"] = [ "SkyReelsV2DiffusionForcingImageToVideoPipeline" ] + _import_structure["pipeline_skyreels_v2_diffusion_forcing_v2v"] = [ + "SkyReelsV2DiffusionForcingVideoToVideoPipeline" + ] _import_structure["pipeline_skyreels_v2_i2v"] = ["SkyReelsV2ImageToVideoPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -39,6 +42,7 @@ from .pipeline_skyreels_v2 import SkyReelsV2Pipeline from .pipeline_skyreels_v2_diffusion_forcing import SkyReelsV2DiffusionForcingPipeline from .pipeline_skyreels_v2_diffusion_forcing_i2v import SkyReelsV2DiffusionForcingImageToVideoPipeline + from .pipeline_skyreels_v2_diffusion_forcing_v2v import SkyReelsV2DiffusionForcingVideoToVideoPipeline from .pipeline_skyreels_v2_i2v import SkyReelsV2ImageToVideoPipeline else: From 8758da776c693a2205e11182b5c9983f5785a05a Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 6 Jun 2025 17:01:10 +0300 Subject: [PATCH 210/264] style --- src/diffusers/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index c7179d04b9db..6a8c11894c60 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -450,8 +450,8 @@ "ShapEImg2ImgPipeline", "ShapEPipeline", "SkyReelsV2DiffusionForcingImageToVideoPipeline", - "SkyReelsV2DiffusionForcingVideoToVideoPipeline", "SkyReelsV2DiffusionForcingPipeline", + "SkyReelsV2DiffusionForcingVideoToVideoPipeline", "SkyReelsV2ImageToVideoPipeline", "SkyReelsV2Pipeline", "StableAudioPipeline", @@ -1045,8 +1045,8 @@ ShapEImg2ImgPipeline, ShapEPipeline, SkyReelsV2DiffusionForcingImageToVideoPipeline, - SkyReelsV2DiffusionForcingVideoToVideoPipeline, SkyReelsV2DiffusionForcingPipeline, + SkyReelsV2DiffusionForcingVideoToVideoPipeline, SkyReelsV2ImageToVideoPipeline, SkyReelsV2Pipeline, StableAudioPipeline, From 568c59e9a6394d5e7a1effbaf6e33cd10ed1814d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 6 Jun 2025 17:01:54 +0300 Subject: [PATCH 211/264] fix-copies --- .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 6f132d34def3..e4c99215992b 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1727,6 +1727,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class SkyReelsV2DiffusionForcingVideoToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class SkyReelsV2ImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 7759617d32df18ae754ebdebab6ebdc46900dc88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 7 Jun 2025 14:38:06 +0300 Subject: [PATCH 212/264] convert i2v 1.3b --- scripts/convert_skyreelsv2_to_diffusers.py | 113 +++++++++++++++++++-- 1 file changed, 106 insertions(+), 7 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 29591d2da2fc..00611ce6c6b7 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -139,6 +139,106 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]: "text_dim": 4096, }, } + elif model_type == "SkyReels-V2-T2V-14B-720P": + config = { + "model_id": "Skywork/SkyReels-V2-T2V-14B-720P", + "diffusers_config": { + "added_kv_proj_dim": None, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_channels": 16, + "num_attention_heads": 40, + "inject_sample_info": False, + "num_layers": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + }, + } + elif model_type == "SkyReels-V2-T2V-14B-540P": + config = { + "model_id": "Skywork/SkyReels-V2-T2V-14B-540P", + "diffusers_config": { + 
"added_kv_proj_dim": None, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_channels": 16, + "num_attention_heads": 40, + "inject_sample_info": False, + "num_layers": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + }, + } + elif model_type == "SkyReels-V2-I2V-1.3B-540P": + config = { + "model_id": "Skywork/SkyReels-V2-I2V-1.3B-540P", + "diffusers_config": { + "added_kv_proj_dim": None, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 8960, + "freq_dim": 256, + "in_channels": 16, + "num_attention_heads": 12, + "inject_sample_info": False, + "num_layers": 30, + "out_channels": 16, + "patch_size": [1, 2, 2], + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + }, + } + elif model_type == "SkyReels-V2-I2V-14B-540P": + config = { + "model_id": "Skywork/SkyReels-V2-I2V-14B-540P", + "diffusers_config": { + "added_kv_proj_dim": None, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_channels": 16, + "num_attention_heads": 40, + "inject_sample_info": False, + "num_layers": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + }, + } + elif model_type == "SkyReels-V2-I2V-14B-720P": + config = { + "model_id": "Skywork/SkyReels-V2-I2V-14B-720P", + "diffusers_config": { + "added_kv_proj_dim": None, + "attention_head_dim": 128, + "cross_attn_norm": True, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_channels": 16, + "num_attention_heads": 40, + "inject_sample_info": False, + "num_layers": 40, + "out_channels": 16, + "patch_size": [1, 2, 2], + "qk_norm": "rms_norm_across_heads", + "text_dim": 4096, + }, + } return config @@ -147,9 +247,9 @@ def convert_transformer(model_type: str): diffusers_config = config["diffusers_config"] model_id = config["model_id"] - if model_type == "SkyReels-V2-DF-1.3B-540P": + if model_type in ("SkyReels-V2-DF-1.3B-540P", "SkyReels-V2-I2V-1.3B-540P"): original_state_dict = load_file(hf_hub_download(model_id, "model.safetensors")) - elif model_type in ["SkyReels-V2-DF-14B-720P", "SkyReels-V2-DF-14B-540P"]: + elif model_type in ("SkyReels-V2-DF-14B-720P", "SkyReels-V2-DF-14B-540P"): os.makedirs(model_type, exist_ok=True) model_dir = pathlib.Path(model_type) if model_type == "SkyReels-V2-DF-14B-720P": @@ -409,7 +509,7 @@ def get_args(): transformer = convert_transformer(args.model_type).to(dtype=dtype) vae = convert_vae() - text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl") + #text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl") tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") scheduler = FlowMatchUniPCMultistepScheduler( prediction_type="flow_prediction", @@ -418,12 +518,11 @@ def get_args(): if "I2V" in args.model_type: image_encoder = CLIPVisionModelWithProjection.from_pretrained( - "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16 - ) + "laion/CLIP-ViT-H-14-laion2B-s32B-b79K") image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") pipe = SkyReelsV2ImageToVideoPipeline( transformer=transformer, - text_encoder=text_encoder, + text_encoder=None, tokenizer=tokenizer, vae=vae, scheduler=scheduler, @@ -433,7 +532,7 @@ def get_args(): else: pipe = SkyReelsV2DiffusionForcingPipeline( transformer=transformer, - text_encoder=text_encoder, + text_encoder=None, tokenizer=tokenizer, 
vae=vae, scheduler=scheduler, From 943cd3e88fe0a8ce56d40ab52e2cf14a793e499f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 7 Jun 2025 15:01:20 +0300 Subject: [PATCH 213/264] Update transformer configuration to include `image_dim` for SkyReels V2 models and refactor imports to use `SkyReelsV2Transformer3DModel`. --- scripts/convert_skyreelsv2_to_diffusers.py | 3 +++ .../pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py | 6 +++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 00611ce6c6b7..d58164679337 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -197,6 +197,7 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]: "patch_size": [1, 2, 2], "qk_norm": "rms_norm_across_heads", "text_dim": 4096, + "image_dim": 1280, }, } elif model_type == "SkyReels-V2-I2V-14B-540P": @@ -217,6 +218,7 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]: "patch_size": [1, 2, 2], "qk_norm": "rms_norm_across_heads", "text_dim": 4096, + "image_dim": 1280, }, } elif model_type == "SkyReels-V2-I2V-14B-720P": @@ -237,6 +239,7 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]: "patch_size": [1, 2, 2], "qk_norm": "rms_norm_across_heads", "text_dim": 4096, + "image_dim": 1280, }, } return config diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py index 46b1f9e51dfc..48d69df74e20 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py @@ -23,7 +23,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput from ...loaders import WanLoraLoaderMixin -from ...models import AutoencoderKLWan, WanTransformer3DModel +from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel from ...schedulers import FlowMatchUniPCMultistepScheduler from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor @@ -144,7 +144,7 @@ class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): the [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large) variant. - transformer ([`WanTransformer3DModel`]): + transformer ([`SkyReelsV2Transformer3DModel`]): Conditional Transformer to denoise the input latents. scheduler ([`FlowMatchUniPCMultistepScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. @@ -161,7 +161,7 @@ def __init__( text_encoder: UMT5EncoderModel, image_encoder: CLIPVisionModel, image_processor: CLIPImageProcessor, - transformer: WanTransformer3DModel, + transformer: SkyReelsV2Transformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMatchUniPCMultistepScheduler, ): From 993d19d57a8eca13a9636642f548c5fdc0c68138 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 7 Jun 2025 16:46:45 +0300 Subject: [PATCH 214/264] Refactor transformer import in SkyReels V2 pipeline to use `SkyReelsV2Transformer3DModel` for consistency. 
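As context for the component renames above, here is a minimal, hypothetical usage sketch of the converted text-to-video pipeline. The repository id, prompt, and generation settings are assumptions taken from checkpoints and defaults referenced later in this series, not part of this patch.

```python
import torch
from diffusers import SkyReelsV2Pipeline
from diffusers.utils import export_to_video

# Assumed converted-checkpoint repo id; see the conversion script in this series.
pipe = SkyReelsV2Pipeline.from_pretrained(
    "Skywork/SkyReels-V2-T2V-14B-540P-Diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

video = pipe(
    prompt="A white ferret playing on a mossy log by a stream",
    num_frames=97,
    guidance_scale=6.0,
).frames[0]
export_to_video(video, "skyreels_t2v.mp4", fps=24)
```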
--- src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py index 696775e3e7e9..aed899325e32 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py @@ -21,7 +21,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import WanLoraLoaderMixin -from ...models import AutoencoderKLWan, WanTransformer3DModel +from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel from ...schedulers import FlowMatchUniPCMultistepScheduler from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor @@ -119,7 +119,7 @@ class SkyReelsV2Pipeline(DiffusionPipeline, WanLoraLoaderMixin): text_encoder ([`T5EncoderModel`]): [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. - transformer ([`WanTransformer3DModel`]): + transformer ([`SkyReelsV2Transformer3DModel`]): Conditional Transformer to denoise the input latents. scheduler ([`FlowMatchUniPCMultistepScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. @@ -134,7 +134,7 @@ def __init__( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, - transformer: WanTransformer3DModel, + transformer: SkyReelsV2Transformer3DModel, vae: AutoencoderKLWan, scheduler: FlowMatchUniPCMultistepScheduler, ): From 7387e52e9b76ccc3c677f180b103843b3a86c7f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 7 Jun 2025 16:52:32 +0300 Subject: [PATCH 215/264] Update transformer configuration in SkyReels V2 to increase `in_channels` from 16 to 36 for i2v conf. --- scripts/convert_skyreelsv2_to_diffusers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index d58164679337..9d787fe95cef 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -189,7 +189,7 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]: "eps": 1e-06, "ffn_dim": 8960, "freq_dim": 256, - "in_channels": 16, + "in_channels": 36, "num_attention_heads": 12, "inject_sample_info": False, "num_layers": 30, @@ -210,7 +210,7 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]: "eps": 1e-06, "ffn_dim": 13824, "freq_dim": 256, - "in_channels": 16, + "in_channels": 36, "num_attention_heads": 40, "inject_sample_info": False, "num_layers": 40, @@ -231,7 +231,7 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]: "eps": 1e-06, "ffn_dim": 13824, "freq_dim": 256, - "in_channels": 16, + "in_channels": 36, "num_attention_heads": 40, "inject_sample_info": False, "num_layers": 40, From 96af7eb0641e6d837325afb347ff7401415ad38d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 7 Jun 2025 17:18:24 +0300 Subject: [PATCH 216/264] Update transformer configuration in SkyReels V2 to set `added_kv_proj_dim` values for different model types. 
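Together with the `image_dim` and `in_channels` updates in the previous patches, the I2V configs end up with `added_kv_proj_dim` equal to the transformer hidden size (attention heads times head dimension). A minimal sketch of what the 1.3B I2V configuration amounts to when instantiated directly; the keyword names are assumed to match the `diffusers_config` dicts in the conversion script:

```python
import torch
from diffusers import SkyReelsV2Transformer3DModel

# Mirrors the "SkyReels-V2-I2V-1.3B-540P" diffusers_config assembled in this series
# (randomly initialized here; the conversion script loads the original weights into it).
transformer = SkyReelsV2Transformer3DModel(
    patch_size=[1, 2, 2],
    num_attention_heads=12,
    attention_head_dim=128,
    in_channels=36,          # 16 noise-latent channels plus image-conditioning channels for I2V
    out_channels=16,
    ffn_dim=8960,
    num_layers=30,
    freq_dim=256,
    text_dim=4096,
    image_dim=1280,          # hidden size of the CLIP ViT-H image encoder
    added_kv_proj_dim=1536,  # 12 heads * 128 head dim
    cross_attn_norm=True,
    qk_norm="rms_norm_across_heads",
    eps=1e-6,
    inject_sample_info=False,
).to(dtype=torch.bfloat16)
```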
--- scripts/convert_skyreelsv2_to_diffusers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 9d787fe95cef..e388993b426c 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -183,7 +183,7 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]: config = { "model_id": "Skywork/SkyReels-V2-I2V-1.3B-540P", "diffusers_config": { - "added_kv_proj_dim": None, + "added_kv_proj_dim": 1536, "attention_head_dim": 128, "cross_attn_norm": True, "eps": 1e-06, @@ -204,7 +204,7 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]: config = { "model_id": "Skywork/SkyReels-V2-I2V-14B-540P", "diffusers_config": { - "added_kv_proj_dim": None, + "added_kv_proj_dim": 5120, "attention_head_dim": 128, "cross_attn_norm": True, "eps": 1e-06, @@ -225,7 +225,7 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]: config = { "model_id": "Skywork/SkyReels-V2-I2V-14B-720P", "diffusers_config": { - "added_kv_proj_dim": None, + "added_kv_proj_dim": 5120, "attention_head_dim": 128, "cross_attn_norm": True, "eps": 1e-06, From a6a733723725b58c6be71bba2c4f326c84d66ae8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 7 Jun 2025 17:27:51 +0300 Subject: [PATCH 217/264] up --- scripts/convert_skyreelsv2_to_diffusers.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index e388993b426c..8876bd4b68c7 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -510,9 +510,9 @@ def get_args(): transformer = None dtype = DTYPE_MAPPING[args.dtype] - transformer = convert_transformer(args.model_type).to(dtype=dtype) - vae = convert_vae() - #text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl") + #transformer = convert_transformer(args.model_type).to(dtype=dtype) + #vae = convert_vae() + text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl") tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") scheduler = FlowMatchUniPCMultistepScheduler( prediction_type="flow_prediction", @@ -520,17 +520,17 @@ def get_args(): ) if "I2V" in args.model_type: - image_encoder = CLIPVisionModelWithProjection.from_pretrained( - "laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + #image_encoder = CLIPVisionModelWithProjection.from_pretrained( + # "laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + #image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") pipe = SkyReelsV2ImageToVideoPipeline( - transformer=transformer, - text_encoder=None, + #transformer=None, + text_encoder=text_encoder, tokenizer=tokenizer, - vae=vae, + #vae=None, scheduler=scheduler, - image_encoder=image_encoder, - image_processor=image_processor, + #image_encoder=None, + #image_processor=None, ) else: pipe = SkyReelsV2DiffusionForcingPipeline( From 72ad13cf2454011919ff8d9ba982a64a2213cf11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 7 Jun 2025 17:32:24 +0300 Subject: [PATCH 218/264] up --- scripts/convert_skyreelsv2_to_diffusers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 
8876bd4b68c7..0863042a2109 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -524,13 +524,13 @@ def get_args(): # "laion/CLIP-ViT-H-14-laion2B-s32B-b79K") #image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") pipe = SkyReelsV2ImageToVideoPipeline( - #transformer=None, + transformer=None, text_encoder=text_encoder, tokenizer=tokenizer, - #vae=None, + vae=None, scheduler=scheduler, - #image_encoder=None, - #image_processor=None, + image_encoder=None, + image_processor=None, ) else: pipe = SkyReelsV2DiffusionForcingPipeline( From d069905042aed78ab01813bd2b2e21b327406940 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 7 Jun 2025 17:51:48 +0300 Subject: [PATCH 219/264] up --- scripts/convert_skyreelsv2_to_diffusers.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 0863042a2109..92f8c00c6d2f 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -511,7 +511,7 @@ def get_args(): dtype = DTYPE_MAPPING[args.dtype] #transformer = convert_transformer(args.model_type).to(dtype=dtype) - #vae = convert_vae() + vae = convert_vae() text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl") tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") scheduler = FlowMatchUniPCMultistepScheduler( @@ -520,17 +520,17 @@ def get_args(): ) if "I2V" in args.model_type: - #image_encoder = CLIPVisionModelWithProjection.from_pretrained( - # "laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - #image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") pipe = SkyReelsV2ImageToVideoPipeline( transformer=None, text_encoder=text_encoder, tokenizer=tokenizer, - vae=None, + vae=vae, scheduler=scheduler, - image_encoder=None, - image_processor=None, + image_encoder=image_encoder, + image_processor=image_processor, ) else: pipe = SkyReelsV2DiffusionForcingPipeline( From 8142720efcf8583f9838f99f4af7328769ce37be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 7 Jun 2025 18:37:46 +0300 Subject: [PATCH 220/264] Add SkyReelsV2Pipeline support for T2V model type in conversion script --- scripts/convert_skyreelsv2_to_diffusers.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 92f8c00c6d2f..cd72f63a7b77 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -14,6 +14,7 @@ FlowMatchUniPCMultistepScheduler, SkyReelsV2DiffusionForcingPipeline, SkyReelsV2ImageToVideoPipeline, + SkyReelsV2Pipeline, SkyReelsV2Transformer3DModel, ) @@ -532,7 +533,15 @@ def get_args(): image_encoder=image_encoder, image_processor=image_processor, ) - else: + elif "T2V" in args.model_type: + pipe = SkyReelsV2Pipeline( + transformer=None, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + ) + elif "DF" in args.model_type: pipe = SkyReelsV2DiffusionForcingPipeline( transformer=transformer, text_encoder=None, From 326b6ed40f7252046ad37113e4c423b2178de26d Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 7 Jun 2025 19:03:43 +0300 Subject: [PATCH 221/264] upp --- scripts/convert_skyreelsv2_to_diffusers.py | 48 +++++++++++----------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index cd72f63a7b77..45df56acf4eb 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -253,7 +253,7 @@ def convert_transformer(model_type: str): if model_type in ("SkyReels-V2-DF-1.3B-540P", "SkyReels-V2-I2V-1.3B-540P"): original_state_dict = load_file(hf_hub_download(model_id, "model.safetensors")) - elif model_type in ("SkyReels-V2-DF-14B-720P", "SkyReels-V2-DF-14B-540P"): + else: os.makedirs(model_type, exist_ok=True) model_dir = pathlib.Path(model_type) if model_type == "SkyReels-V2-DF-14B-720P": @@ -511,35 +511,35 @@ def get_args(): transformer = None dtype = DTYPE_MAPPING[args.dtype] - #transformer = convert_transformer(args.model_type).to(dtype=dtype) - vae = convert_vae() - text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl") - tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") - scheduler = FlowMatchUniPCMultistepScheduler( - prediction_type="flow_prediction", - num_train_timesteps=1000, - ) + transformer = convert_transformer(args.model_type).to(dtype=dtype) + #vae = convert_vae() + #text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl") + #tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") + #scheduler = FlowMatchUniPCMultistepScheduler( + # prediction_type="flow_prediction", + # num_train_timesteps=1000, + #) if "I2V" in args.model_type: - image_encoder = CLIPVisionModelWithProjection.from_pretrained( - "laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + #image_encoder = CLIPVisionModelWithProjection.from_pretrained( + # "laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + #image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") pipe = SkyReelsV2ImageToVideoPipeline( - transformer=None, - text_encoder=text_encoder, - tokenizer=tokenizer, - vae=vae, - scheduler=scheduler, - image_encoder=image_encoder, - image_processor=image_processor, + transformer=transformer, + text_encoder=None, + tokenizer=None, + vae=None, + scheduler=None, + image_encoder=None, + image_processor=None, ) elif "T2V" in args.model_type: pipe = SkyReelsV2Pipeline( - transformer=None, - text_encoder=text_encoder, - tokenizer=tokenizer, - vae=vae, - scheduler=scheduler, + transformer=transformer, + text_encoder=None, + tokenizer=None, + vae=None, + scheduler=None, ) elif "DF" in args.model_type: pipe = SkyReelsV2DiffusionForcingPipeline( From a4622226fb368378988c3d690ccfbd5bd104f43f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 7 Jun 2025 19:46:18 +0300 Subject: [PATCH 222/264] Refactor model type checks in conversion script to use substring matching for improved flexibility --- scripts/convert_skyreelsv2_to_diffusers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 45df56acf4eb..b74f3a2fe88d 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -251,15 +251,15 @@ def convert_transformer(model_type: str): diffusers_config = config["diffusers_config"] model_id = config["model_id"] - if 
model_type in ("SkyReels-V2-DF-1.3B-540P", "SkyReels-V2-I2V-1.3B-540P"): + if "1.3B" in model_type: original_state_dict = load_file(hf_hub_download(model_id, "model.safetensors")) else: os.makedirs(model_type, exist_ok=True) model_dir = pathlib.Path(model_type) - if model_type == "SkyReels-V2-DF-14B-720P": + if "720P" in model_type: top_shard = 6 model_name = "diffusion_pytorch_model" - elif model_type == "SkyReels-V2-DF-14B-540P": + elif "540P" in model_type: top_shard = 12 model_name = "model" From a8c057f038e7f2e7d9b5f1026fedb34d2b77ad47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 7 Jun 2025 20:04:02 +0300 Subject: [PATCH 223/264] upp --- scripts/convert_skyreelsv2_to_diffusers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index b74f3a2fe88d..e780d35fd9a7 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -257,10 +257,10 @@ def convert_transformer(model_type: str): os.makedirs(model_type, exist_ok=True) model_dir = pathlib.Path(model_type) if "720P" in model_type: - top_shard = 6 + top_shard = 7 if "I2V" in model_type else 6 model_name = "diffusion_pytorch_model" elif "540P" in model_type: - top_shard = 12 + top_shard = 14 if "I2V" in model_type else 12 model_name = "model" for i in range(1, top_shard + 1): From 6bdfbcf38c2afdcb4ee15a7534498f00b8e8dace Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 7 Jun 2025 20:10:59 +0300 Subject: [PATCH 224/264] Fix shard path formatting in conversion script to accommodate varying model types by dynamically adjusting zero padding. --- scripts/convert_skyreelsv2_to_diffusers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index e780d35fd9a7..cf843fc3af46 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -258,13 +258,14 @@ def convert_transformer(model_type: str): model_dir = pathlib.Path(model_type) if "720P" in model_type: top_shard = 7 if "I2V" in model_type else 6 + zeros = "0" * (4 if "I2V" or "T2V" in model_type else 3) model_name = "diffusion_pytorch_model" elif "540P" in model_type: top_shard = 14 if "I2V" in model_type else 12 model_name = "model" for i in range(1, top_shard + 1): - shard_path = f"{model_name}-{i:05d}-of-000{top_shard}.safetensors" + shard_path = f"{model_name}-{i:05d}-of-{zeros}{top_shard}.safetensors" hf_hub_download(model_id, shard_path, local_dir=model_dir) original_state_dict = load_sharded_safetensors(model_dir) From db74f87e2d6946053e0f9c0c7f74ca3515bff73c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 7 Jun 2025 21:11:57 +0300 Subject: [PATCH 225/264] Update sharded safetensors loading logic in conversion script to use substring matching for model directory checks --- scripts/convert_skyreelsv2_to_diffusers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index cf843fc3af46..9cdd8d38fb53 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -69,7 +69,7 @@ def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) - def load_sharded_safetensors(dir: pathlib.Path): - if "SkyReels-V2-DF-14B-720P" in str(dir): + if "720P" in 
str(dir): file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors")) else: file_paths = list(dir.glob("model*.safetensors")) From cc698b6195672171349b1bda6fdc88d227e59ba0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 8 Jun 2025 11:41:46 +0300 Subject: [PATCH 226/264] Update scheduler parameters in SkyReels V2 test files for consistency across image and video pipelines --- .../pipelines/skyreels_v2/test_skyreels_v2.py | 20 +++++++++++-------- .../skyreels_v2/test_skyreels_v2_df.py | 2 +- .../test_skyreels_v2_df_image_to_video.py | 4 ++-- .../test_skyreels_v2_df_video_to_video.py | 2 +- .../test_skyreels_v2_image_to_video.py | 4 ++-- 5 files changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2.py b/tests/pipelines/skyreels_v2/test_skyreels_v2.py index a162e6841d2d..c2e49ec98a15 100644 --- a/tests/pipelines/skyreels_v2/test_skyreels_v2.py +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2.py @@ -19,7 +19,12 @@ import torch from transformers import AutoTokenizer, T5EncoderModel -from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanPipeline, WanTransformer3DModel +from diffusers import ( + AutoencoderKLWan, + FlowMatchUniPCMultistepScheduler, + SkyReelsV2Pipeline, + SkyReelsV2Transformer3DModel, +) from diffusers.utils.testing_utils import ( enable_full_determinism, require_torch_accelerator, @@ -35,8 +40,8 @@ enable_full_determinism() -class WanPipelineFastTests(PipelineTesterMixin, unittest.TestCase): - pipeline_class = WanPipeline +class SkyReelsV2PipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = SkyReelsV2Pipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = TEXT_TO_IMAGE_IMAGE_PARAMS @@ -65,13 +70,12 @@ def get_dummy_components(self): ) torch.manual_seed(0) - # TODO: impl FlowDPMSolverMultistepScheduler - scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + scheduler = FlowMatchUniPCMultistepScheduler(shift=8.0) text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") torch.manual_seed(0) - transformer = WanTransformer3DModel( + transformer = SkyReelsV2Transformer3DModel( patch_size=(1, 2, 2), num_attention_heads=2, attention_head_dim=12, @@ -138,7 +142,7 @@ def test_attention_slicing_forward_pass(self): @slow @require_torch_accelerator -class WanPipelineIntegrationTests(unittest.TestCase): +class SkyReelsV2PipelineIntegrationTests(unittest.TestCase): prompt = "A painting of a squirrel eating a burger." 
def setUp(self): @@ -152,5 +156,5 @@ def tearDown(self): torch.cuda.empty_cache() @unittest.skip("TODO: test needs to be implemented") - def test_Wanx(self): + def test_SkyReelsV2x(self): pass diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py index 18fe0196b30f..48aec05eeee6 100644 --- a/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py @@ -70,7 +70,7 @@ def get_dummy_components(self): ) torch.manual_seed(0) - scheduler = FlowMatchUniPCMultistepScheduler(shift=7.0) + scheduler = FlowMatchUniPCMultistepScheduler(shift=8.0) text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py index 470a4ed103d9..2053ad21970a 100644 --- a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py @@ -67,7 +67,7 @@ def get_dummy_components(self): ) torch.manual_seed(0) - scheduler = FlowMatchUniPCMultistepScheduler(shift=7.0) + scheduler = FlowMatchUniPCMultistepScheduler(shift=5.0) text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") @@ -158,7 +158,7 @@ def get_dummy_components(self): ) torch.manual_seed(0) - scheduler = FlowMatchUniPCMultistepScheduler(shift=7.0) + scheduler = FlowMatchUniPCMultistepScheduler(shift=5.0) text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py index 37a0ce0adec4..e252ef3e1b48 100644 --- a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py @@ -67,7 +67,7 @@ def get_dummy_components(self): ) torch.manual_seed(0) - scheduler = FlowMatchUniPCMultistepScheduler(flow_shift=3.0) + scheduler = FlowMatchUniPCMultistepScheduler(shift=5.0) text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py index e8caf222b555..c3b42f1c2472 100644 --- a/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py @@ -70,7 +70,7 @@ def get_dummy_components(self): ) torch.manual_seed(0) - scheduler = FlowMatchUniPCMultistepScheduler(shift=7.0) + scheduler = FlowMatchUniPCMultistepScheduler(shift=5.0) text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") @@ -180,7 +180,7 @@ def get_dummy_components(self): ) torch.manual_seed(0) - scheduler = FlowMatchUniPCMultistepScheduler(shift=7.0) + scheduler = FlowMatchUniPCMultistepScheduler(shift=5.0) text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") tokenizer = 
AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") From 9a269a2e55380a48bb205bed4b357a57c12868fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 8 Jun 2025 11:45:52 +0300 Subject: [PATCH 227/264] Refactor conversion script to initialize text encoder, tokenizer, and scheduler for SkyReels pipelines, enhancing model integration --- scripts/convert_skyreelsv2_to_diffusers.py | 50 +++++++++++----------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 9cdd8d38fb53..0a437e316081 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -7,15 +7,15 @@ from accelerate import init_empty_weights from huggingface_hub import hf_hub_download from safetensors.torch import load_file -from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel +from transformers import AutoTokenizer, CLIPVisionModelWithProjection, AutoProcessor, UMT5EncoderModel from diffusers import ( AutoencoderKLWan, - FlowMatchUniPCMultistepScheduler, SkyReelsV2DiffusionForcingPipeline, SkyReelsV2ImageToVideoPipeline, SkyReelsV2Pipeline, SkyReelsV2Transformer3DModel, + FlowMatchUniPCMultistepScheduler, ) @@ -513,39 +513,39 @@ def get_args(): dtype = DTYPE_MAPPING[args.dtype] transformer = convert_transformer(args.model_type).to(dtype=dtype) - #vae = convert_vae() - #text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl") - #tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") - #scheduler = FlowMatchUniPCMultistepScheduler( - # prediction_type="flow_prediction", - # num_train_timesteps=1000, - #) + vae = convert_vae() + text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl") + tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl") + scheduler = FlowMatchUniPCMultistepScheduler( + prediction_type="flow_prediction", + num_train_timesteps=1000, + ) if "I2V" in args.model_type: - #image_encoder = CLIPVisionModelWithProjection.from_pretrained( - # "laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - #image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") pipe = SkyReelsV2ImageToVideoPipeline( transformer=transformer, - text_encoder=None, - tokenizer=None, - vae=None, - scheduler=None, - image_encoder=None, - image_processor=None, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, + image_encoder=image_encoder, + image_processor=image_processor, ) elif "T2V" in args.model_type: pipe = SkyReelsV2Pipeline( transformer=transformer, - text_encoder=None, - tokenizer=None, - vae=None, - scheduler=None, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + scheduler=scheduler, ) elif "DF" in args.model_type: pipe = SkyReelsV2DiffusionForcingPipeline( transformer=transformer, - text_encoder=None, + text_encoder=text_encoder, tokenizer=tokenizer, vae=vae, scheduler=scheduler, @@ -555,6 +555,6 @@ def get_args(): args.output_path, safe_serialization=True, max_shard_size="5GB", - push_to_hub=True, - repo_id=f"tolgacangoz/{args.model_type}-Diffusers", + #push_to_hub=True, + #repo_id=f"/{args.model_type}-Diffusers", ) From 9fd9dba89950b906098edc6b34831b3b1a6a7e40 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 8 Jun 2025 11:46:28 +0300 Subject: [PATCH 228/264] style --- scripts/convert_skyreelsv2_to_diffusers.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index 0a437e316081..ef9f57b8e5b8 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -7,15 +7,15 @@ from accelerate import init_empty_weights from huggingface_hub import hf_hub_download from safetensors.torch import load_file -from transformers import AutoTokenizer, CLIPVisionModelWithProjection, AutoProcessor, UMT5EncoderModel +from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel from diffusers import ( AutoencoderKLWan, + FlowMatchUniPCMultistepScheduler, SkyReelsV2DiffusionForcingPipeline, SkyReelsV2ImageToVideoPipeline, SkyReelsV2Pipeline, SkyReelsV2Transformer3DModel, - FlowMatchUniPCMultistepScheduler, ) @@ -522,8 +522,7 @@ def get_args(): ) if "I2V" in args.model_type: - image_encoder = CLIPVisionModelWithProjection.from_pretrained( - "laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") pipe = SkyReelsV2ImageToVideoPipeline( transformer=transformer, @@ -555,6 +554,6 @@ def get_args(): args.output_path, safe_serialization=True, max_shard_size="5GB", - #push_to_hub=True, - #repo_id=f"/{args.model_type}-Diffusers", + # push_to_hub=True, + # repo_id=f"/{args.model_type}-Diffusers", ) From bc9eb42c5c71126b578cb717eb3a2f4072193630 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 8 Jun 2025 12:30:21 +0300 Subject: [PATCH 229/264] Update documentation for SkyReels-V2, introducing the Infinite-length Film Generative model, enhancing text-to-video generation examples, and updating model references throughout the API documentation. --- docs/source/en/api/pipelines/skyreels_v2.md | 233 ++++++++++++++------ 1 file changed, 170 insertions(+), 63 deletions(-) diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md index e0be071921d6..da7b4b1dd4fe 100644 --- a/docs/source/en/api/pipelines/skyreels_v2.md +++ b/docs/source/en/api/pipelines/skyreels_v2.md @@ -20,37 +20,51 @@ -# Wan2.1 +# SkyReels-V2: Infinite-length Film Generative model -[Wan2.1](https://files.alicdn.com/tpsservice/5c9de1c74de03972b7aa657e5a54756b.pdf) is a series of large diffusion transformer available in two versions, a high-performance 14B parameter model and a more accessible 1.3B version. Trained on billions of images and videos, it supports tasks like text-to-video (T2V) and image-to-video (I2V) while enabling features such as camera control and stylistic diversity. The Wan-VAE features better image data compression and a feature cache mechanism that encodes and decodes a video in chunks. To maintain continuity, features from previous chunks are cached and reused for processing subsequent chunks. This improves inference efficiency by reducing memory usage. Wan2.1 also uses a multilingual text encoder and the diffusion transformer models space and time relationships and text conditions with each time step to capture more complex video dynamics. +[SkyReels-V2](https://huggingface.co/papers/2504.13074) by the SkyReels Team. 
-You can find all the original Wan2.1 checkpoints under the [Wan-AI](https://huggingface.co/Wan-AI) organization. +*Recent advances in video generation have been driven by diffusion models and autoregressive frameworks, yet critical challenges persist in harmonizing prompt adherence, visual quality, motion dynamics, and duration: compromises in motion dynamics to enhance temporal visual quality, constrained video duration (5-10 seconds) to prioritize resolution, and inadequate shot-aware generation stemming from general-purpose MLLMs' inability to interpret cinematic grammar, such as shot composition, actor expressions, and camera motions. These intertwined limitations hinder realistic long-form synthesis and professional film-style generation. To address these limitations, we propose SkyReels-V2, an Infinite-length Film Generative Model, that synergizes Multi-modal Large Language Model (MLLM), Multi-stage Pretraining, Reinforcement Learning, and Diffusion Forcing Framework. Firstly, we design a comprehensive structural representation of video that combines the general descriptions by the Multi-modal LLM and the detailed shot language by sub-expert models. Aided with human annotation, we then train a unified Video Captioner, named SkyCaptioner-V1, to efficiently label the video data. Secondly, we establish progressive-resolution pretraining for the fundamental video generation, followed by a four-stage post-training enhancement: Initial concept-balanced Supervised Fine-Tuning (SFT) improves baseline quality; Motion-specific Reinforcement Learning (RL) training with human-annotated and synthetic distortion data addresses dynamic artifacts; Our diffusion forcing framework with non-decreasing noise schedules enables long-video synthesis in an efficient search space; Final high-quality SFT refines visual fidelity. All the code and models are available at [this https URL](https://github.com/SkyworkAI/SkyReels-V2).* + +You can find all the original SkyReels-V2 checkpoints under the [Skywork](https://huggingface.co/collections/Skywork/skyreels-v2-6801b1b93df627d441d0d0d9) organization. + +The following SkyReels-V2 models are supported in Diffusers: +- [SkyReels-V2 DF 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers) +- [SkyReels-V2 DF 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-540P-Diffusers) +- [SkyReels-V2 DF 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-720P-Diffusers) +- [SkyReels-V2 T2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-540P-Diffusers) +- [SkyReels-V2 T2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-720P-Diffusers) +- [SkyReels-V2 I2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers) +- [SkyReels-V2 I2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-540P-Diffusers) +- [SkyReels-V2 I2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P-Diffusers) > [!TIP] -> Click on the Wan2.1 models in the right sidebar for more examples of video generation. +> Click on the SkyReels-V2 models in the right sidebar for more examples of video generation. + +### Text-to-Video Generation The example below demonstrates how to generate a video from text optimized for memory or inference speed. - - + + Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques. -The Wan2.1 text-to-video model below requires ~13GB of VRAM. 
+The SkyReels-V2 text-to-video model below requires ~13GB of VRAM. ```py # pip install ftfy import torch import numpy as np -from diffusers import AutoModel, WanPipeline +from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline from diffusers.quantizers import PipelineQuantizationConfig from diffusers.hooks.group_offloading import apply_group_offloading from diffusers.utils import export_to_video, load_image from transformers import UMT5EncoderModel -text_encoder = UMT5EncoderModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16) -vae = AutoModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae", torch_dtype=torch.float32) -transformer = AutoModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16) +text_encoder = UMT5EncoderModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16) +vae = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="vae", torch_dtype=torch.float32) +transformer = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16) # group-offloading onload_device = torch.device("cuda") @@ -68,8 +82,8 @@ transformer.enable_group_offload( use_stream=True ) -pipeline = WanPipeline.from_pretrained( - "Wan-AI/Wan2.1-T2V-14B-Diffusers", +pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained( + "Skywork/SkyReels-V2-DF-14B-540P-Diffusers", vae=vae, transformer=transformer, text_encoder=text_encoder, @@ -78,29 +92,29 @@ pipeline = WanPipeline.from_pretrained( pipeline.to("cuda") prompt = """ -The camera rushes from far to near in a low-angle shot, -revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in -for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. -Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic +The camera rushes from far to near in a low-angle shot, +revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in +for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. +Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic shadows and warm highlights. Medium composition, front view, low angle, with depth of field. 
""" negative_prompt = """ -Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, -low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards """ output = pipeline( prompt=prompt, negative_prompt=negative_prompt, - num_frames=81, - guidance_scale=5.0, + num_frames=97, + guidance_scale=6.0, ).frames[0] -export_to_video(output, "output.mp4", fps=16) +export_to_video(output, "output.mp4", fps=24) ``` - + [Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. @@ -108,17 +122,17 @@ export_to_video(output, "output.mp4", fps=16) # pip install ftfy import torch import numpy as np -from diffusers import AutoModel, WanPipeline +from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline from diffusers.hooks.group_offloading import apply_group_offloading from diffusers.utils import export_to_video, load_image from transformers import UMT5EncoderModel -text_encoder = UMT5EncoderModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16) -vae = AutoModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="vae", torch_dtype=torch.float32) -transformer = AutoModel.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16) +text_encoder = UMT5EncoderModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16) +vae = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="vae", torch_dtype=torch.float32) +transformer = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16) -pipeline = WanPipeline.from_pretrained( - "Wan-AI/Wan2.1-T2V-14B-Diffusers", +pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained( + "Skywork/SkyReels-V2-DF-14B-540P-Diffusers", vae=vae, transformer=transformer, text_encoder=text_encoder, @@ -133,33 +147,108 @@ pipeline.transformer = torch.compile( ) prompt = """ -The camera rushes from far to near in a low-angle shot, -revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in -for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. -Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic +The camera rushes from far to near in a low-angle shot, +revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in +for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. +Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic shadows and warm highlights. Medium composition, front view, low angle, with depth of field. 
""" negative_prompt = """ -Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, -low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, +Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, +low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards """ output = pipeline( prompt=prompt, negative_prompt=negative_prompt, - num_frames=81, - guidance_scale=5.0, + num_frames=97, + guidance_scale=6.0, +).frames[0] +export_to_video(output, "output.mp4", fps=24) +``` + + + + +### First-Last-Frame-to-Video Generation + +The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description, a starting frame, and an ending frame. + + + + +```python +import numpy as np +import torch +import torchvision.transforms.functional as TF +from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingImageToVideoPipeline +from diffusers.utils import export_to_video, load_image +from transformers import CLIPVisionModel + + +model_id = "Skywork/SkyReels-V2-DF-14B-720P-Diffusers" +image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32) +vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) +pipe = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained( + model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16 +) +pipe.to("cuda") + +first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png") +last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png") + +def aspect_ratio_resize(image, pipe, max_area=720 * 1280): + aspect_ratio = image.height / image.width + mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + image = image.resize((width, height)) + return image, height, width + +def center_crop_resize(image, height, width): + # Calculate resize ratio to match first frame dimensions + resize_ratio = max(width / image.width, height / image.height) + + # Resize the image + width = round(image.width * resize_ratio) + height = round(image.height * resize_ratio) + size = [width, height] + image = TF.center_crop(image, size) + + return image, height, width + +first_frame, height, width = aspect_ratio_resize(first_frame, pipe) +if last_frame.size != first_frame.size: + last_frame, _, _ = center_crop_resize(last_frame, height, width) + +prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." 
+ +output = pipe( + image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.0 ).frames[0] -export_to_video(output, "output.mp4", fps=16) +export_to_video(output, "output.mp4", fps=24) ``` +### Any-to-Video Controllable Generation + +SkyReels-V2 supports various generation techniques which achieve controllable video generation. Some of the capabilities include: +- Control to Video (Depth, Pose, Sketch, Flow, Grayscale, Scribble, Layout, Boundary Box, etc.). Recommended library for preprocessing videos to obtain control videos: [huggingface/controlnet_aux]() +- Image/Video to Video (first frame, last frame, starting clip, ending clip, random clips) +- Inpainting and Outpainting +- Subject to Video (faces, object, characters, etc.) +- Composition to Video (reference anything, animate anything, swap anything, expand anything, move anything, etc.) + +The general rule of thumb to keep in mind when preparing inputs for the SkyReels-V2 pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color. + +The code snippets available in [this](https://github.com/huggingface/diffusers/pull/11582) pull request demonstrate some examples of how videos can be generated with controllability signals. + ## Notes -- Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`]. +- SkyReels-V2 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`].
Show example code @@ -167,17 +256,17 @@ export_to_video(output, "output.mp4", fps=16) ```py # pip install ftfy import torch - from diffusers import AutoModel, WanPipeline - from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler + from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline + from diffusers import FlowMatchUniPCMultistepScheduler from diffusers.utils import export_to_video vae = AutoModel.from_pretrained( - "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32 + "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", subfolder="vae", torch_dtype=torch.float32 ) - pipeline = WanPipeline.from_pretrained( - "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", vae=vae, torch_dtype=torch.bfloat16 + pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained( + "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", vae=vae, torch_dtype=torch.bfloat16 ) - pipeline.scheduler = UniPCMultistepScheduler.from_config( + pipeline.scheduler = FlowMatchUniPCMultistepScheduler.from_config( pipeline.scheduler.config, flow_shift=5.0 ) pipeline.to("cuda") @@ -189,19 +278,19 @@ export_to_video(output, "output.mp4", fps=16) # use "steamboat willie style" to trigger the LoRA prompt = """ - steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot, - revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in - for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. - Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic + steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot, + revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in + for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. + Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic shadows and warm highlights. Medium composition, front view, low angle, with depth of field. """ output = pipeline( prompt=prompt, - num_frames=81, - guidance_scale=5.0, + num_frames=97, + guidance_scale=6.0, ).frames[0] - export_to_video(output, "output.mp4", fps=16) + export_to_video(output, "output.mp4", fps=24) ```
@@ -214,17 +303,17 @@ export_to_video(output, "output.mp4", fps=16) ```py # pip install ftfy import torch - from diffusers import WanPipeline, AutoModel + from diffusers import SkyReelsV2DiffusionForcingPipeline, AutoModel vae = AutoModel.from_single_file( - "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors" + "https://huggingface.co/Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers/blob/main/split_files/vae/skyreels_v2_vae.safetensors" ) transformer = AutoModel.from_single_file( - "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors", + "https://huggingface.co/Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers/blob/main/split_files/diffusion_models/skyreels_v2_df_1.3b_bf16.safetensors", torch_dtype=torch.bfloat16 ) - pipeline = WanPipeline.from_pretrained( - "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained( + "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", vae=vae, transformer=transformer, torch_dtype=torch.bfloat16 @@ -239,18 +328,36 @@ export_to_video(output, "output.mp4", fps=16) - Try lower `shift` values (`2.0` to `5.0`) for lower resolution videos and higher `shift` values (`7.0` to `12.0`) for higher resolution images. -## WanPipeline +## SkyReelsV2DiffusionForcingPipeline + +[[autodoc]] SkyReelsV2DiffusionForcingPipeline + - all + - __call__ + +## SkyReelsV2DiffusionForcingImageToVideoPipeline + +[[autodoc]] SkyReelsV2DiffusionForcingImageToVideoPipeline + - all + - __call__ + +## SkyReelsV2DiffusionForcingVideoToVideoPipeline + +[[autodoc]] SkyReelsV2DiffusionForcingVideoToVideoPipeline + - all + - __call__ + +## SkyReelsV2Pipeline -[[autodoc]] WanPipeline +[[autodoc]] SkyReelsV2Pipeline - all - __call__ -## WanImageToVideoPipeline +## SkyReelsV2ImageToVideoPipeline -[[autodoc]] WanImageToVideoPipeline +[[autodoc]] SkyReelsV2ImageToVideoPipeline - all - __call__ -## WanPipelineOutput +## SkyReelsV2PipelineOutput -[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput \ No newline at end of file +[[autodoc]] pipelines.skyreels_v2.pipeline_output.SkyReelsV2PipelineOutput \ No newline at end of file From de446ad117f4679c363906266b41a5faabeb7db4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 8 Jun 2025 12:35:04 +0300 Subject: [PATCH 230/264] Add SkyReelsV2Transformer3DModel and FlowMatchUniPCMultistepScheduler documentation, updating TOC and introducing new model and scheduler files. 
--- docs/source/en/_toctree.yml | 6 ++++ .../en/api/models/wan_transformer_3d copy.md | 30 +++++++++++++++++++ .../en/api/schedulers/flow_match_unipc.md | 18 +++++++++++ 3 files changed, 54 insertions(+) create mode 100644 docs/source/en/api/models/wan_transformer_3d copy.md create mode 100644 docs/source/en/api/schedulers/flow_match_unipc.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index f13b7d54aec4..8686fd3565b3 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -333,6 +333,8 @@ title: TransformerTemporalModel - local: api/models/wan_transformer_3d title: WanTransformer3DModel + - local: api/models/skyreels_v2_transformer_3d + title: SkyReelsV2Transformer3DModel title: Transformers - sections: - local: api/models/stable_cascade_unet @@ -575,6 +577,8 @@ title: VisualCloze - local: api/pipelines/wan title: Wan + - local: api/pipelines/skyreels_v2 + title: SkyReels-V2 - local: api/pipelines/wuerstchen title: Wuerstchen title: Pipelines @@ -620,6 +624,8 @@ title: FlowMatchEulerDiscreteScheduler - local: api/schedulers/flow_match_heun_discrete title: FlowMatchHeunDiscreteScheduler + - local: api/schedulers/flow_match_unipc + title: FlowMatchUniPCMultistepScheduler - local: api/schedulers/heun title: HeunDiscreteScheduler - local: api/schedulers/ipndm diff --git a/docs/source/en/api/models/wan_transformer_3d copy.md b/docs/source/en/api/models/wan_transformer_3d copy.md new file mode 100644 index 000000000000..6d2ff4baffae --- /dev/null +++ b/docs/source/en/api/models/wan_transformer_3d copy.md @@ -0,0 +1,30 @@ + + +# SkyReelsV2Transformer3DModel + +A Diffusion Transformer model for 3D video-like data was introduced in [SkyReels-V2](https://github.com/Skywork-AI/SkyReels-V2) by the Skywork AI. + +The model can be loaded with the following code snippet. + +```python +from diffusers import SkyReelsV2Transformer3DModel + +transformer = SkyReelsV2Transformer3DModel.from_pretrained("Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16) +``` + +## SkyReelsV2Transformer3DModel + +[[autodoc]] SkyReelsV2Transformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/schedulers/flow_match_unipc.md b/docs/source/en/api/schedulers/flow_match_unipc.md new file mode 100644 index 000000000000..72f1b400b102 --- /dev/null +++ b/docs/source/en/api/schedulers/flow_match_unipc.md @@ -0,0 +1,18 @@ + + +# FlowMatchUniPCMultistepScheduler + +`FlowMatchUniPCMultistepScheduler` is based on the flow-matching sampling introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206). 
+ +## FlowMatchUniPCMultistepScheduler +[[autodoc]] FlowMatchUniPCMultistepScheduler From f2f66137906f61201b6e501565087299a8dc431a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 8 Jun 2025 12:36:32 +0300 Subject: [PATCH 231/264] style --- docs/source/en/_toctree.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 8686fd3565b3..0bb85d4a0b74 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -30,7 +30,8 @@ - local: using-diffusers/push_to_hub title: Push files to the Hub title: Load pipelines and adapters -- sections: +- isExpanded: false + sections: - local: tutorials/using_peft_for_inference title: LoRA - local: using-diffusers/ip_adapter @@ -44,7 +45,6 @@ - local: using-diffusers/textual_inversion_inference title: Textual inversion title: Adapters - isExpanded: false - sections: - local: using-diffusers/unconditional_image_generation title: Unconditional image generation @@ -325,6 +325,8 @@ title: SanaTransformer2DModel - local: api/models/sd3_transformer2d title: SD3Transformer2DModel + - local: api/models/skyreels_v2_transformer_3d + title: SkyReelsV2Transformer3DModel - local: api/models/stable_audio_transformer title: StableAudioDiTModel - local: api/models/transformer2d @@ -333,8 +335,6 @@ title: TransformerTemporalModel - local: api/models/wan_transformer_3d title: WanTransformer3DModel - - local: api/models/skyreels_v2_transformer_3d - title: SkyReelsV2Transformer3DModel title: Transformers - sections: - local: api/models/stable_cascade_unet @@ -519,6 +519,8 @@ title: Semantic Guidance - local: api/pipelines/shap_e title: Shap-E + - local: api/pipelines/skyreels_v2 + title: SkyReels-V2 - local: api/pipelines/stable_audio title: Stable Audio - local: api/pipelines/stable_cascade @@ -577,8 +579,6 @@ title: VisualCloze - local: api/pipelines/wan title: Wan - - local: api/pipelines/skyreels_v2 - title: SkyReels-V2 - local: api/pipelines/wuerstchen title: Wuerstchen title: Pipelines From b707a6c9ba25a2c1b6c6ed8e262ef092e96bc40d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 8 Jun 2025 14:09:08 +0300 Subject: [PATCH 232/264] Update documentation for SkyReelsV2DiffusionForcingPipeline to correct flow matching scheduler parameter for I2V from 3.0 to 5.0, ensuring clarity in usage examples. --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 8649af1e65e5..3029787b2893 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -72,7 +72,7 @@ ... vae=vae, ... torch_dtype=torch.bfloat16, ... ) - >>> shift = 8.0 # 8.0 for T2V, 3.0 for I2V + >>> shift = 8.0 # 8.0 for T2V, 5.0 for I2V >>> pipe.scheduler = FlowMatchUniPCMultistepScheduler.from_config(pipe.scheduler.config, shift=shift) >>> pipe = pipe.to("cuda") >>> pipe.transformer.set_ar_attention(causal_block_size=5) @@ -619,7 +619,7 @@ def __call__( max_sequence_length (`int`, *optional*, defaults to `512`): The maximum sequence length of the prompt. 
shift (`float`, *optional*, defaults to `8.0`): - Flow matching scheduler parameter (**3.0 for I2V**, **8.0 for T2V**) + Flow matching scheduler parameter (**5.0 for I2V**, **8.0 for T2V**) overlap_history (`int`, *optional*, defaults to `None`): Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes short video generation mode, and no overlap is applied. 17 and 37 are recommended to set. From dc7326765ed928dac7a6ed528d3e0f69b96081af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 8 Jun 2025 20:03:36 +0300 Subject: [PATCH 233/264] Add documentation for causal_block_size parameter in SkyReelsV2DF pipelines, clarifying its role in asynchronous inference. --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 1 + .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py | 1 + .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py | 1 + 3 files changed, 3 insertions(+) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 3029787b2893..8061b30bd3f4 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -636,6 +636,7 @@ def __call__( sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous inference may improve the instruction following and visual consistent performance. causal_block_size (`int`, *optional*, defaults to `None`): + The number of frames in each block/chunk. Recommended when using asynchronous inference (when ar_step > 0) fps (`int`, *optional*, defaults to `24`): Frame rate of the generated video diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 5dce13a66068..ecdf1abd3273 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -692,6 +692,7 @@ def __call__( sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous inference may improve the instruction following and visual consistent performance. causal_block_size (`int`, *optional*, defaults to `None`): + The number of frames in each block/chunk. Recommended when using asynchronous inference (when ar_step > 0) fps (`int`, *optional*, defaults to `24`): Frame rate of the generated video diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index 07e44b94b6dc..feaaae4ee4ed 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -709,6 +709,7 @@ def __call__( sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous inference may improve the instruction following and visual consistent performance. causal_block_size (`int`, *optional*, defaults to `None`): + The number of frames in each block/chunk. 
Recommended when using asynchronous inference (when ar_step > 0) fps (`int`, *optional*, defaults to `24`): Frame rate of the generated video From c2aab89c89f7e13fcb02c249b81c1c0ee8f31a17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 9 Jun 2025 10:24:43 +0300 Subject: [PATCH 234/264] Simplify min_ar_step calculation in SkyReelsV2DiffusionForcingPipeline to improve clarity. --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 8061b30bd3f4..3ed1c8224d55 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -441,9 +441,7 @@ def generate_timestep_matrix( num_frames_block = num_frames // causal_block_size base_num_frames_block = base_num_frames // causal_block_size if base_num_frames_block < num_frames_block: - infer_step_num = len(step_template) - gen_block = base_num_frames_block - min_ar_step = infer_step_num / gen_block + min_ar_step = len(step_template) / base_num_frames_block if ar_step < min_ar_step: raise ValueError(f"ar_step should be at least {math.ceil(min_ar_step)} in your setting") From 7ce7a96be6961e8301c34ee8ea8938055ac454d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 9 Jun 2025 10:26:34 +0300 Subject: [PATCH 235/264] style and fix-copies --- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 4 ++-- .../pipeline_skyreels_v2_diffusion_forcing_i2v.py | 8 +++----- .../pipeline_skyreels_v2_diffusion_forcing_v2v.py | 8 +++----- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 3ed1c8224d55..795368a27aeb 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -634,8 +634,8 @@ def __call__( sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous inference may improve the instruction following and visual consistent performance. causal_block_size (`int`, *optional*, defaults to `None`): - The number of frames in each block/chunk. - Recommended when using asynchronous inference (when ar_step > 0) + The number of frames in each block/chunk. 
Recommended when using asynchronous inference (when ar_step > + 0) fps (`int`, *optional*, defaults to `24`): Frame rate of the generated video diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index ecdf1abd3273..4d1c1bcc57e9 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -486,9 +486,7 @@ def generate_timestep_matrix( num_frames_block = num_frames // causal_block_size base_num_frames_block = base_num_frames // causal_block_size if base_num_frames_block < num_frames_block: - infer_step_num = len(step_template) - gen_block = base_num_frames_block - min_ar_step = infer_step_num / gen_block + min_ar_step = len(step_template) / base_num_frames_block if ar_step < min_ar_step: raise ValueError(f"ar_step should be at least {math.ceil(min_ar_step)} in your setting") @@ -692,8 +690,8 @@ def __call__( sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous inference may improve the instruction following and visual consistent performance. causal_block_size (`int`, *optional*, defaults to `None`): - The number of frames in each block/chunk. - Recommended when using asynchronous inference (when ar_step > 0) + The number of frames in each block/chunk. Recommended when using asynchronous inference (when ar_step > + 0) fps (`int`, *optional*, defaults to `24`): Frame rate of the generated video diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index feaaae4ee4ed..9d94d2d6d464 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -511,9 +511,7 @@ def generate_timestep_matrix( num_frames_block = num_frames // causal_block_size base_num_frames_block = base_num_frames // causal_block_size if base_num_frames_block < num_frames_block: - infer_step_num = len(step_template) - gen_block = base_num_frames_block - min_ar_step = infer_step_num / gen_block + min_ar_step = len(step_template) / base_num_frames_block if ar_step < min_ar_step: raise ValueError(f"ar_step should be at least {math.ceil(min_ar_step)} in your setting") @@ -709,8 +707,8 @@ def __call__( sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous inference may improve the instruction following and visual consistent performance. causal_block_size (`int`, *optional*, defaults to `None`): - The number of frames in each block/chunk. - Recommended when using asynchronous inference (when ar_step > 0) + The number of frames in each block/chunk. 
Recommended when using asynchronous inference (when ar_step > + 0) fps (`int`, *optional*, defaults to `24`): Frame rate of the generated video From 32a6520f0a1b511160559e312e342557abdd6fd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 9 Jun 2025 10:47:15 +0300 Subject: [PATCH 236/264] style --- .../pipeline_skyreels_v2_diffusion_forcing.py | 24 +++++++++---------- ...eline_skyreels_v2_diffusion_forcing_i2v.py | 24 +++++++++---------- ...eline_skyreels_v2_diffusion_forcing_v2v.py | 24 +++++++++---------- 3 files changed, 36 insertions(+), 36 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 795368a27aeb..f85e258e0f5e 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -438,12 +438,12 @@ def generate_timestep_matrix( step_matrix, step_index = [], [] update_mask, valid_interval = [], [] num_iterations = len(step_template) + 1 - num_frames_block = num_frames // causal_block_size - base_num_frames_block = base_num_frames // causal_block_size - if base_num_frames_block < num_frames_block: - min_ar_step = len(step_template) / base_num_frames_block + num_blocks = num_frames // causal_block_size + base_num_blocks = base_num_frames // causal_block_size + if base_num_blocks < num_blocks: + min_ar_step = len(step_template) / base_num_blocks if ar_step < min_ar_step: - raise ValueError(f"ar_step should be at least {math.ceil(min_ar_step)} in your setting") + raise ValueError(f"`ar_step` should be at least {math.ceil(min_ar_step)} in your setting") step_template = torch.cat( [ @@ -452,13 +452,13 @@ def generate_timestep_matrix( torch.tensor([0], dtype=torch.int64, device=step_template.device), ] ) # to handle the counter in row works starting from 1 - pre_row = torch.zeros(num_frames_block, dtype=torch.long) + pre_row = torch.zeros(num_blocks, dtype=torch.long) if num_pre_ready > 0: pre_row[: num_pre_ready // causal_block_size] = num_iterations while not torch.all(pre_row >= (num_iterations - 1)): - new_row = torch.zeros(num_frames_block, dtype=torch.long) - for i in range(num_frames_block): + new_row = torch.zeros(num_blocks, dtype=torch.long) + for i in range(num_blocks): if i == 0 or pre_row[i - 1] >= ( num_iterations - 1 ): # the first frame or the last frame is completely denoised @@ -475,18 +475,18 @@ def generate_timestep_matrix( pre_row = new_row # for long video we split into several sequences, base_num_frames is set to the model max length (for training) - terminal_flag = base_num_frames_block + terminal_flag = base_num_blocks if shrink_interval_with_mask: - idx_sequence = torch.arange(num_frames_block, dtype=torch.int64) + idx_sequence = torch.arange(num_blocks, dtype=torch.int64) update_mask = update_mask[0] update_mask_idx = idx_sequence[update_mask] last_update_idx = update_mask_idx[-1].item() terminal_flag = last_update_idx + 1 for curr_mask in update_mask: - if terminal_flag < num_frames_block and curr_mask[terminal_flag]: + if terminal_flag < num_blocks and curr_mask[terminal_flag]: terminal_flag += 1 - valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag)) + valid_interval.append((max(terminal_flag - base_num_blocks, 0), terminal_flag)) step_update_mask = torch.stack(update_mask, dim=0) step_index = torch.stack(step_index, dim=0) diff --git 
a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 4d1c1bcc57e9..1c5c29bbc543 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -483,12 +483,12 @@ def generate_timestep_matrix( step_matrix, step_index = [], [] update_mask, valid_interval = [], [] num_iterations = len(step_template) + 1 - num_frames_block = num_frames // causal_block_size - base_num_frames_block = base_num_frames // causal_block_size - if base_num_frames_block < num_frames_block: - min_ar_step = len(step_template) / base_num_frames_block + num_blocks = num_frames // causal_block_size + base_num_blocks = base_num_frames // causal_block_size + if base_num_blocks < num_blocks: + min_ar_step = len(step_template) / base_num_blocks if ar_step < min_ar_step: - raise ValueError(f"ar_step should be at least {math.ceil(min_ar_step)} in your setting") + raise ValueError(f"`ar_step` should be at least {math.ceil(min_ar_step)} in your setting") step_template = torch.cat( [ @@ -497,13 +497,13 @@ def generate_timestep_matrix( torch.tensor([0], dtype=torch.int64, device=step_template.device), ] ) # to handle the counter in row works starting from 1 - pre_row = torch.zeros(num_frames_block, dtype=torch.long) + pre_row = torch.zeros(num_blocks, dtype=torch.long) if num_pre_ready > 0: pre_row[: num_pre_ready // causal_block_size] = num_iterations while not torch.all(pre_row >= (num_iterations - 1)): - new_row = torch.zeros(num_frames_block, dtype=torch.long) - for i in range(num_frames_block): + new_row = torch.zeros(num_blocks, dtype=torch.long) + for i in range(num_blocks): if i == 0 or pre_row[i - 1] >= ( num_iterations - 1 ): # the first frame or the last frame is completely denoised @@ -520,18 +520,18 @@ def generate_timestep_matrix( pre_row = new_row # for long video we split into several sequences, base_num_frames is set to the model max length (for training) - terminal_flag = base_num_frames_block + terminal_flag = base_num_blocks if shrink_interval_with_mask: - idx_sequence = torch.arange(num_frames_block, dtype=torch.int64) + idx_sequence = torch.arange(num_blocks, dtype=torch.int64) update_mask = update_mask[0] update_mask_idx = idx_sequence[update_mask] last_update_idx = update_mask_idx[-1].item() terminal_flag = last_update_idx + 1 for curr_mask in update_mask: - if terminal_flag < num_frames_block and curr_mask[terminal_flag]: + if terminal_flag < num_blocks and curr_mask[terminal_flag]: terminal_flag += 1 - valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag)) + valid_interval.append((max(terminal_flag - base_num_blocks, 0), terminal_flag)) step_update_mask = torch.stack(update_mask, dim=0) step_index = torch.stack(step_index, dim=0) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index 9d94d2d6d464..91713d3a4c2d 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -508,12 +508,12 @@ def generate_timestep_matrix( step_matrix, step_index = [], [] update_mask, valid_interval = [], [] num_iterations = len(step_template) + 1 - num_frames_block = num_frames // causal_block_size - 
base_num_frames_block = base_num_frames // causal_block_size - if base_num_frames_block < num_frames_block: - min_ar_step = len(step_template) / base_num_frames_block + num_blocks = num_frames // causal_block_size + base_num_blocks = base_num_frames // causal_block_size + if base_num_blocks < num_blocks: + min_ar_step = len(step_template) / base_num_blocks if ar_step < min_ar_step: - raise ValueError(f"ar_step should be at least {math.ceil(min_ar_step)} in your setting") + raise ValueError(f"`ar_step` should be at least {math.ceil(min_ar_step)} in your setting") step_template = torch.cat( [ @@ -522,13 +522,13 @@ def generate_timestep_matrix( torch.tensor([0], dtype=torch.int64, device=step_template.device), ] ) # to handle the counter in row works starting from 1 - pre_row = torch.zeros(num_frames_block, dtype=torch.long) + pre_row = torch.zeros(num_blocks, dtype=torch.long) if num_pre_ready > 0: pre_row[: num_pre_ready // causal_block_size] = num_iterations while not torch.all(pre_row >= (num_iterations - 1)): - new_row = torch.zeros(num_frames_block, dtype=torch.long) - for i in range(num_frames_block): + new_row = torch.zeros(num_blocks, dtype=torch.long) + for i in range(num_blocks): if i == 0 or pre_row[i - 1] >= ( num_iterations - 1 ): # the first frame or the last frame is completely denoised @@ -545,18 +545,18 @@ def generate_timestep_matrix( pre_row = new_row # for long video we split into several sequences, base_num_frames is set to the model max length (for training) - terminal_flag = base_num_frames_block + terminal_flag = base_num_blocks if shrink_interval_with_mask: - idx_sequence = torch.arange(num_frames_block, dtype=torch.int64) + idx_sequence = torch.arange(num_blocks, dtype=torch.int64) update_mask = update_mask[0] update_mask_idx = idx_sequence[update_mask] last_update_idx = update_mask_idx[-1].item() terminal_flag = last_update_idx + 1 for curr_mask in update_mask: - if terminal_flag < num_frames_block and curr_mask[terminal_flag]: + if terminal_flag < num_blocks and curr_mask[terminal_flag]: terminal_flag += 1 - valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag)) + valid_interval.append((max(terminal_flag - base_num_blocks, 0), terminal_flag)) step_update_mask = torch.stack(update_mask, dim=0) step_index = torch.stack(step_index, dim=0) From 59c4057049cd3f0aad1afcc1cd0d7f17cebe76e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 12 Jun 2025 08:22:24 +0300 Subject: [PATCH 237/264] Add documentation for SkyReelsV2Transformer3DModel Introduced a new markdown file detailing the SkyReelsV2Transformer3DModel, including usage instructions and model output specifications. --- .../{wan_transformer_3d copy.md => skyreels_v2_transformer_3d.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/source/en/api/models/{wan_transformer_3d copy.md => skyreels_v2_transformer_3d.md} (100%) diff --git a/docs/source/en/api/models/wan_transformer_3d copy.md b/docs/source/en/api/models/skyreels_v2_transformer_3d.md similarity index 100% rename from docs/source/en/api/models/wan_transformer_3d copy.md rename to docs/source/en/api/models/skyreels_v2_transformer_3d.md From 9b026e452c1ff580c13b7ff39b5f7a452d9bfdea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 12 Jun 2025 19:55:17 +0300 Subject: [PATCH 238/264] Update test configurations for SkyReelsV2 pipelines - Adjusted `in_channels` from 36 to 16 in `test_skyreels_v2_df_image_to_video.py`. 
- Added new parameters: `overlap_history`, `num_frames`, and `base_num_frames` in `test_skyreels_v2_df_video_to_video.py`. - Updated expected output shape in video tests from (17, 3, 16, 16) to (41, 3, 16, 16). --- .../skyreels_v2/test_skyreels_v2_df_image_to_video.py | 4 ++-- .../skyreels_v2/test_skyreels_v2_df_video_to_video.py | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py index 2053ad21970a..471532cdf80c 100644 --- a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py @@ -76,7 +76,7 @@ def get_dummy_components(self): patch_size=(1, 2, 2), num_attention_heads=2, attention_head_dim=12, - in_channels=36, + in_channels=16, out_channels=16, text_dim=32, freq_dim=256, @@ -167,7 +167,7 @@ def get_dummy_components(self): patch_size=(1, 2, 2), num_attention_heads=2, attention_head_dim=12, - in_channels=36, + in_channels=16, out_channels=16, text_dim=32, freq_dim=256, diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py index e252ef3e1b48..e5c5101a5a10 100644 --- a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py @@ -114,6 +114,9 @@ def get_dummy_inputs(self, device, seed=0): "width": 16, "max_sequence_length": 16, "output_type": "pt", + "overlap_history": 1, + "num_frames": 20, + "base_num_frames": 10, } return inputs @@ -129,8 +132,8 @@ def test_inference(self): video = pipe(**inputs).frames generated_video = video[0] - self.assertEqual(generated_video.shape, (17, 3, 16, 16)) - expected_video = torch.randn(17, 3, 16, 16) + self.assertEqual(generated_video.shape, (41, 3, 16, 16)) + expected_video = torch.randn(41, 3, 16, 16) max_diff = np.abs(generated_video - expected_video).max() self.assertLessEqual(max_diff, 1e10) From 4c89187b776abb59b17140690d3c0e145d57d65d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Thu, 12 Jun 2025 20:29:58 +0300 Subject: [PATCH 239/264] Refines SkyReelsV2DF test parameters --- .../skyreels_v2/test_skyreels_v2_df_image_to_video.py | 4 ++-- .../skyreels_v2/test_skyreels_v2_df_video_to_video.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py index 471532cdf80c..ab15ee984203 100644 --- a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py @@ -113,7 +113,7 @@ def get_dummy_inputs(self, device, seed=0): "width": image_width, "generator": generator, "num_inference_steps": 2, - "guidance_scale": 6.0, + "guidance_scale": 5.0, "num_frames": 9, "max_sequence_length": 16, "output_type": "pt", @@ -207,7 +207,7 @@ def get_dummy_inputs(self, device, seed=0): "width": image_width, "generator": generator, "num_inference_steps": 2, - "guidance_scale": 6.0, + "guidance_scale": 5.0, "num_frames": 9, "max_sequence_length": 16, "output_type": "pt", diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py index e5c5101a5a10..b4e0d25c1233 100644 --- 
a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py @@ -114,9 +114,9 @@ def get_dummy_inputs(self, device, seed=0): "width": 16, "max_sequence_length": 16, "output_type": "pt", - "overlap_history": 1, - "num_frames": 20, - "base_num_frames": 10, + "overlap_history": 3, + "num_frames": 9, + "base_num_frames": 5, } return inputs @@ -132,8 +132,8 @@ def test_inference(self): video = pipe(**inputs).frames generated_video = video[0] - self.assertEqual(generated_video.shape, (41, 3, 16, 16)) - expected_video = torch.randn(41, 3, 16, 16) + self.assertEqual(generated_video.shape, (21, 3, 16, 16)) + expected_video = torch.randn(21, 3, 16, 16) max_diff = np.abs(generated_video - expected_video).max() self.assertLessEqual(max_diff, 1e10) From 6aec002ead95662f5023c9face472659738c10b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= <46008593+tolgacangoz@users.noreply.github.com> Date: Fri, 13 Jun 2025 15:35:55 +0300 Subject: [PATCH 240/264] Update src/diffusers/models/modeling_outputs.py Co-authored-by: Aryan --- src/diffusers/models/modeling_outputs.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/modeling_outputs.py b/src/diffusers/models/modeling_outputs.py index 96e86abff491..0120a34d9052 100644 --- a/src/diffusers/models/modeling_outputs.py +++ b/src/diffusers/models/modeling_outputs.py @@ -26,8 +26,6 @@ class Transformer2DModelOutput(BaseOutput): sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability distributions for the unnoised latent pixels. - debug_tensors (`Optional[Dict[str, Any]]`, *optional*): - A dictionary containing intermediate tensors and their shapes for debugging purposes. 
""" sample: "torch.Tensor" # noqa: F821 From 8fcc7f0131c7a208e1b87d2323315c42c5d34c84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 13 Jun 2025 16:28:42 +0300 Subject: [PATCH 241/264] Refactor `grid_sizes` processing by using already-calculated post-patch parameters to simplify --- .../transformers/transformer_skyreels_v2.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index eaa138873ea5..0471189495c5 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -477,22 +477,20 @@ def forward( rotary_emb = self.rope(hidden_states) hidden_states = self.patch_embedding(hidden_states) - grid_sizes = torch.tensor(hidden_states.shape[2:], dtype=torch.long) + hidden_states = hidden_states.flatten(2).transpose(1, 2) if self.config.flag_causal_attention: - frame_num, height, width = grid_sizes - block_num = frame_num // self.config.num_frame_per_block + block_num = post_patch_num_frames // self.config.num_frame_per_block range_tensor = torch.arange(block_num, device=hidden_states.device).repeat_interleave( self.config.num_frame_per_block ) causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f - causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1) - causal_mask = causal_mask.repeat(1, height, width, 1, height, width) - causal_mask = causal_mask.reshape(frame_num * height * width, frame_num * height * width) + causal_mask = causal_mask.view(post_patch_num_frames, 1, 1, post_patch_num_frames, 1, 1) + causal_mask = causal_mask.repeat(1, post_patch_height, post_patch_width, 1, post_patch_height, post_patch_width) + causal_mask = causal_mask.reshape(post_patch_num_frames * post_patch_height * post_patch_width, + post_patch_num_frames * post_patch_height * post_patch_width) causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) - hidden_states = hidden_states.flatten(2).transpose(1, 2) - temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( timestep, encoder_hidden_states, encoder_hidden_states_image ) @@ -531,8 +529,8 @@ def forward( b, f = timestep.shape temb = temb.view(b, f, 1, 1, -1) timestep_proj = timestep_proj.view(b, f, 1, 1, 6, -1) - temb = temb.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3) - timestep_proj = timestep_proj.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3) + temb = temb.repeat(1, 1, post_patch_height, post_patch_width, 1).flatten(1, 3) + timestep_proj = timestep_proj.repeat(1, 1, post_patch_height, post_patch_width, 1, 1).flatten(1, 3) timestep_proj = timestep_proj.transpose(1, 2).contiguous() for i, block in enumerate(self.blocks): From b5df1758f6a05ef2ce16d5d40cf0cb42aa1bedf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= <46008593+tolgacangoz@users.noreply.github.com> Date: Fri, 13 Jun 2025 16:32:10 +0300 Subject: [PATCH 242/264] Update docs/source/en/api/pipelines/skyreels_v2.md Co-authored-by: Aryan --- docs/source/en/api/pipelines/skyreels_v2.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md index da7b4b1dd4fe..191e939a9822 100644 --- a/docs/source/en/api/pipelines/skyreels_v2.md +++ b/docs/source/en/api/pipelines/skyreels_v2.md @@ -295,7 +295,7 @@ The code snippets available in 
[this](https://github.com/huggingface/diffusers/p -- [`WanTransformer3DModel`] and [`AutoencoderKLWan`] supports loading from single files with [`~loaders.FromSingleFileMixin.from_single_file`]. +- [`SkyReelsV2Transformer3DModel`] and [`AutoencoderKLWan`] supports loading from single files with [`~loaders.FromSingleFileMixin.from_single_file`].
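A rough sketch of that single-file loading path, assuming a `.safetensors` checkpoint is available (the URL below is a placeholder, not a real file); `AutoencoderKLWan` would follow the same pattern:

```python
import torch
from diffusers import SkyReelsV2Transformer3DModel

# Placeholder URL; point this at an actual SkyReels-V2 single-file checkpoint.
ckpt_path = "https://huggingface.co/<org>/<repo>/blob/main/skyreels_v2_transformer.safetensors"
transformer = SkyReelsV2Transformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
```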
Show example code From c446fe5fc73c3681c4e404b5b233fbd435879268 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Fri, 13 Jun 2025 16:53:20 +0300 Subject: [PATCH 243/264] Refactor parameter naming for diffusion forcing in SkyReelsV2 pipelines - Changed `flag_df` to `enable_diffusion_forcing` for clarity in the SkyReelsV2Transformer3DModel and associated pipelines. - Updated all relevant method calls to reflect the new parameter name. --- .../models/transformers/transformer_skyreels_v2.py | 6 +++--- .../skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py | 8 ++++---- .../pipeline_skyreels_v2_diffusion_forcing_i2v.py | 8 ++++---- .../pipeline_skyreels_v2_diffusion_forcing_v2v.py | 4 ++-- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 0471189495c5..c78ec091210c 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -448,7 +448,7 @@ def forward( timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, - flag_df: bool = False, + enable_diffusion_forcing: bool = False, fps: Optional[torch.Tensor] = None, return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, @@ -517,7 +517,7 @@ def forward( fps_emb = self.fps_embedding(fps).float() timestep_proj = timestep_proj.to(fps_emb.dtype) self.fps_projection.to(fps_emb.dtype) - if flag_df: + if enable_diffusion_forcing: timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1)).repeat( timestep.shape[1], 1, 1 ) @@ -525,7 +525,7 @@ def forward( timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1)) timestep_proj = timestep_proj.to(hidden_states.dtype) - if flag_df: + if enable_diffusion_forcing: b, f = timestep.shape temb = temb.view(b, f, 1, 1, -1) timestep_proj = timestep_proj.view(b, f, 1, 1, 6, -1) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index f85e258e0f5e..a850bb249dd0 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -782,7 +782,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, - flag_df=True, + enable_diffusion_forcing=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, @@ -792,7 +792,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, - flag_df=True, + enable_diffusion_forcing=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, @@ -911,7 +911,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, - flag_df=True, + enable_diffusion_forcing=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, @@ -921,7 +921,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, - flag_df=True, + enable_diffusion_forcing=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py 
b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 1c5c29bbc543..60539899b46e 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -858,7 +858,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, - flag_df=True, + enable_diffusion_forcing=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, @@ -868,7 +868,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, - flag_df=True, + enable_diffusion_forcing=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, @@ -1008,7 +1008,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, - flag_df=True, + enable_diffusion_forcing=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, @@ -1018,7 +1018,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, - flag_df=True, + enable_diffusion_forcing=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index 91713d3a4c2d..0cd16aafbbda 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -883,7 +883,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, - flag_df=True, + enable_diffusion_forcing=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, @@ -893,7 +893,7 @@ def __call__( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, - flag_df=True, + enable_diffusion_forcing=True, fps=fps_embeds, attention_kwargs=attention_kwargs, return_dict=False, From 7f13e1d89000d08f4b1cab6408ee8e677bd680b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 14 Jun 2025 12:07:51 +0300 Subject: [PATCH 244/264] Revert _toctree.yml to adjust section expansion states --- docs/source/en/_toctree.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 0bb85d4a0b74..223871ea7e38 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -30,8 +30,7 @@ - local: using-diffusers/push_to_hub title: Push files to the Hub title: Load pipelines and adapters -- isExpanded: false - sections: +- sections: - local: tutorials/using_peft_for_inference title: LoRA - local: using-diffusers/ip_adapter @@ -45,6 +44,7 @@ - local: using-diffusers/textual_inversion_inference title: Textual inversion title: Adapters + isExpanded: false - sections: - local: using-diffusers/unconditional_image_generation title: Unconditional image generation From 6931366db5ea004ed2a55cac3e54b662b7c8ae0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 14 Jun 2025 12:08:40 +0300 Subject: [PATCH 245/264] style --- .../models/transformers/transformer_skyreels_v2.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py 
b/src/diffusers/models/transformers/transformer_skyreels_v2.py index c78ec091210c..e3068c53b38e 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -486,9 +486,13 @@ def forward( ) causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f causal_mask = causal_mask.view(post_patch_num_frames, 1, 1, post_patch_num_frames, 1, 1) - causal_mask = causal_mask.repeat(1, post_patch_height, post_patch_width, 1, post_patch_height, post_patch_width) - causal_mask = causal_mask.reshape(post_patch_num_frames * post_patch_height * post_patch_width, - post_patch_num_frames * post_patch_height * post_patch_width) + causal_mask = causal_mask.repeat( + 1, post_patch_height, post_patch_width, 1, post_patch_height, post_patch_width + ) + causal_mask = causal_mask.reshape( + post_patch_num_frames * post_patch_height * post_patch_width, + post_patch_num_frames * post_patch_height * post_patch_width, + ) causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( From c42f98fea27b7f24201ae0e24ae2faf3272c1cae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= <46008593+tolgacangoz@users.noreply.github.com> Date: Sat, 14 Jun 2025 12:13:21 +0300 Subject: [PATCH 246/264] Update docs/source/en/api/models/skyreels_v2_transformer_3d.md Co-authored-by: YiYi Xu --- docs/source/en/api/models/skyreels_v2_transformer_3d.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/api/models/skyreels_v2_transformer_3d.md b/docs/source/en/api/models/skyreels_v2_transformer_3d.md index 6d2ff4baffae..c1c8c2c7bcce 100644 --- a/docs/source/en/api/models/skyreels_v2_transformer_3d.md +++ b/docs/source/en/api/models/skyreels_v2_transformer_3d.md @@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License. --> # SkyReelsV2Transformer3DModel -A Diffusion Transformer model for 3D video-like data was introduced in [SkyReels-V2](https://github.com/Skywork-AI/SkyReels-V2) by the Skywork AI. +A Diffusion Transformer model for 3D video-like data was introduced in [SkyReels-V2](https://github.com/SkyworkAI/SkyReels-V2) by the Skywork AI. The model can be loaded with the following code snippet. From 9a5b93d0e780f394e2a2c732d8dd6f085c45c78c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 14 Jun 2025 12:37:55 +0300 Subject: [PATCH 247/264] Add copying label to SkyReelsV2ImageEmbedding from WanImageEmbedding. 
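For intuition, here is a standalone toy reproduction of the block-causal mask built in the transformer above; the sizes are made up and the variable names simply mirror the diff (with `num_frame_per_block` standing in for `config.num_frame_per_block`). Each latent frame may only attend to frames in its own block or in earlier blocks, and the frame-level mask is then broadcast over the spatial tokens:

```python
import torch

# Toy sizes: 4 latent frames grouped into blocks of 2, on a 2x2 latent grid.
num_frames, height, width, num_frame_per_block = 4, 2, 2, 2

block_num = num_frames // num_frame_per_block
# Frame i belongs to block i // num_frame_per_block -> tensor([0, 0, 1, 1])
range_tensor = torch.arange(block_num).repeat_interleave(num_frame_per_block)
# mask[q, k] is True only if the key frame's block index is <= the query frame's block index.
causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1)  # (frames, frames)

# Broadcast the frame-level mask over every spatial token, matching the reshape in the model.
causal_mask = causal_mask.view(num_frames, 1, 1, num_frames, 1, 1)
causal_mask = causal_mask.repeat(1, height, width, 1, height, width)
causal_mask = causal_mask.reshape(num_frames * height * width, num_frames * height * width)
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)  # add batch and head dims

print(causal_mask.shape)  # torch.Size([1, 1, 16, 16])
```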
--- src/diffusers/models/transformers/transformer_skyreels_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index e3068c53b38e..15d617e75aaa 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -127,6 +127,7 @@ def set_ar_attention(self): self._flag_ar_attention = True +# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding with WanImageEmbedding -> SkyReelsV2ImageEmbedding class SkyReelsV2ImageEmbedding(torch.nn.Module): def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None): super().__init__() From fdabf039d1f091605c33d0da444c1846b3b8418d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 14 Jun 2025 14:41:57 +0300 Subject: [PATCH 248/264] Refactor transformer block processing in SkyReelsV2Transformer3DModel - Ensured proper handling of hidden states during both gradient checkpointing and standard processing. --- .../transformers/transformer_skyreels_v2.py | 40 ++++++++++--------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 15d617e75aaa..fcd792d33299 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -505,17 +505,6 @@ def forward( if encoder_hidden_states_image is not None: encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) - # 4. Transformer blocks - if torch.is_grad_enabled() and self.gradient_checkpointing: - for i, block in enumerate(self.blocks): - hidden_states = self._gradient_checkpointing_func( - block, - hidden_states, - encoder_hidden_states, - timestep_proj, - rotary_emb, - causal_mask if self.config.flag_causal_attention else None, - ) if self.config.inject_sample_info: fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device) @@ -538,14 +527,27 @@ def forward( timestep_proj = timestep_proj.repeat(1, 1, post_patch_height, post_patch_width, 1, 1).flatten(1, 3) timestep_proj = timestep_proj.transpose(1, 2).contiguous() - for i, block in enumerate(self.blocks): - hidden_states = block( - hidden_states, - encoder_hidden_states, - timestep_proj, - rotary_emb, - causal_mask if self.config.flag_causal_attention else None, - ) + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.blocks: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + causal_mask if self.config.flag_causal_attention else None, + ) + + else: + for block in self.blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + causal_mask if self.config.flag_causal_attention else None, + ) if temb.dim() == 2: shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) From 46b32ad089886be86e5d530171c17799e5d060f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 14 Jun 2025 15:05:19 +0300 Subject: [PATCH 249/264] Update SkyReels V2 documentation to remove VRAM requirement and streamline imports - Removed the mention of ~13GB VRAM requirement for the SkyReels-V2 model. 
- Simplified import statements by removing unused `load_image` import. --- docs/source/en/api/pipelines/skyreels_v2.md | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md index 191e939a9822..2a438a7ec142 100644 --- a/docs/source/en/api/pipelines/skyreels_v2.md +++ b/docs/source/en/api/pipelines/skyreels_v2.md @@ -50,16 +50,13 @@ The example below demonstrates how to generate a video from text optimized for m Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques. -The SkyReels-V2 text-to-video model below requires ~13GB of VRAM. - ```py # pip install ftfy import torch import numpy as np from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline -from diffusers.quantizers import PipelineQuantizationConfig from diffusers.hooks.group_offloading import apply_group_offloading -from diffusers.utils import export_to_video, load_image +from diffusers.utils import export_to_video from transformers import UMT5EncoderModel text_encoder = UMT5EncoderModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16) @@ -124,7 +121,7 @@ import torch import numpy as np from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline from diffusers.hooks.group_offloading import apply_group_offloading -from diffusers.utils import export_to_video, load_image +from diffusers.utils import export_to_video from transformers import UMT5EncoderModel text_encoder = UMT5EncoderModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16) From 194a9cc37815ac37eca71ba27525c844d4ac8133 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 14 Jun 2025 18:12:17 +0300 Subject: [PATCH 250/264] Add SkyReelsV2LoraLoaderMixin for loading and managing LoRA layers in SkyReelsV2Transformer3DModel - Introduced SkyReelsV2LoraLoaderMixin class to handle loading, saving, and fusing of LoRA weights specific to the SkyReelsV2 model. - Implemented methods for state dict management, including compatibility checks for various LoRA formats. - Enhanced functionality for loading weights with options for low CPU memory usage and hotswapping. - Added detailed docstrings for clarity on parameters and usage. --- src/diffusers/loaders/lora_pipeline.py | 373 +++++++++++++++++++++++++ 1 file changed, 373 insertions(+) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 189a9ceba541..5a8ab2b0022d 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -5071,6 +5071,379 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): """ super().unfuse_lora(components=components, **kwargs) +# Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin with WanTransformer3DModel->SkyReelsV2Transformer3DModel +class SkyReelsV2LoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`SkyReelsV2Transformer3DModel`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. 
+ + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + if any(k.startswith("diffusion_model.") for k in state_dict): + state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict) + elif any(k.startswith("lora_unet_") for k in state_dict): + state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. 
So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + return state_dict + + @classmethod + def _maybe_expand_t2v_lora_for_i2v( + cls, + transformer: torch.nn.Module, + state_dict, + ): + if transformer.config.image_dim is None: + return state_dict + + target_device = transformer.device + + if any(k.startswith("transformer.blocks.") for k in state_dict): + num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k}) + is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict) + has_bias = any(".lora_B.bias" in k for k in state_dict) + + if is_i2v_lora: + return state_dict + + for i in range(num_blocks): + for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]): + # These keys should exist if the block `i` was part of the T2V LoRA. + ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight" + ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight" + + if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict: + continue + + state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like( + state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device + ) + state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like( + state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device + ) + + # If the original LoRA had biases (indicated by has_bias) + # AND the specific reference bias key exists for this block. + + ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias" + if has_bias and ref_key_lora_B_bias in state_dict: + ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias] + state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like( + ref_lora_B_bias_tensor, + device=target_device, + ) + + return state_dict + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. 
+ """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers + state_dict = self._maybe_expand_t2v_lora_for_i2v( + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + state_dict=state_dict, + ) + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`SkyReelsV2Transformer3DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. 
+ logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. 
+ + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + """ + super().unfuse_lora(components=components, **kwargs) + class CogView4LoraLoaderMixin(LoraBaseMixin): r""" From 0e9acfff75f0853365aa3317f7967b50aac53443 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 14 Jun 2025 19:38:25 +0300 Subject: [PATCH 251/264] Update SkyReelsV2 documentation and loader mixin references - Corrected the documentation to reference the new `SkyReelsV2LoraLoaderMixin` for loading LoRA weights. - Updated comments in the `SkyReelsV2LoraLoaderMixin` class to reflect changes in model references from `WanTransformer3DModel` to `SkyReelsV2Transformer3DModel`. --- docs/source/en/api/pipelines/skyreels_v2.md | 2 +- src/diffusers/loaders/lora_pipeline.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md index 2a438a7ec142..f284ea9e391c 100644 --- a/docs/source/en/api/pipelines/skyreels_v2.md +++ b/docs/source/en/api/pipelines/skyreels_v2.md @@ -245,7 +245,7 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p ## Notes -- SkyReels-V2 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`]. +- SkyReels-V2 supports LoRAs with [`~loaders.SkyReelsV2LoraLoaderMixin.load_lora_weights`].
Show example code diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 5a8ab2b0022d..bb96289b3bdc 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -5071,7 +5071,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): """ super().unfuse_lora(components=components, **kwargs) -# Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin with WanTransformer3DModel->SkyReelsV2Transformer3DModel + class SkyReelsV2LoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`SkyReelsV2Transformer3DModel`]. @@ -5291,7 +5291,7 @@ def load_lora_weights( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel def load_lora_into_transformer( cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): From 4dbe83432fda30d9fee5c505a98f6c91ae19eef4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 14 Jun 2025 19:48:27 +0300 Subject: [PATCH 252/264] Enhance SkyReelsV2 integration by adding SkyReelsV2LoraLoaderMixin references - Added `SkyReelsV2LoraLoaderMixin` to the documentation and loader imports for improved LoRA weight management. - Updated multiple pipeline classes to inherit from `SkyReelsV2LoraLoaderMixin` instead of `WanLoraLoaderMixin`. --- docs/source/en/api/loaders/lora.md | 11 ++++++----- docs/source/en/api/pipelines/skyreels_v2.md | 2 +- src/diffusers/loaders/__init__.py | 2 ++ .../pipelines/skyreels_v2/pipeline_skyreels_v2.py | 4 ++-- .../pipeline_skyreels_v2_diffusion_forcing.py | 4 ++-- .../pipeline_skyreels_v2_diffusion_forcing_i2v.py | 4 ++-- .../pipeline_skyreels_v2_diffusion_forcing_v2v.py | 4 ++-- .../pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py | 4 ++-- 8 files changed, 19 insertions(+), 16 deletions(-) diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index 9271999d63a6..ddec413bec8b 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -26,6 +26,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi - [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video). - [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2). - [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan). +- [`SkyReelsV2LoraLoaderMixin`] provides similar functions for [SkyReels-V2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/skyreels_v2). - [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4). - [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`]. 
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream) @@ -88,6 +89,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse [[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin +## SkyReelsV2LoraLoaderMixin + +[[autodoc]] loaders.lora_pipeline.SkyReelsV2LoraLoaderMixin + ## AmusedLoraLoaderMixin [[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin @@ -98,8 +103,4 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse ## LoraBaseMixin -[[autodoc]] loaders.lora_base.LoraBaseMixin - -## WanLoraLoaderMixin - -[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin \ No newline at end of file +[[autodoc]] loaders.lora_base.LoraBaseMixin \ No newline at end of file diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md index f284ea9e391c..de50d9a216dc 100644 --- a/docs/source/en/api/pipelines/skyreels_v2.md +++ b/docs/source/en/api/pipelines/skyreels_v2.md @@ -292,7 +292,7 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p
-- [`SkyReelsV2Transformer3DModel`] and [`AutoencoderKLWan`] supports loading from single files with [`~loaders.FromSingleFileMixin.from_single_file`]. +- [`SkyReelsV2Transformer3DModel`] and [`AutoencoderKLWan`] support loading from single files with [`~loaders.FromSingleFileMixin.from_single_file`].
Show example code diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 84c6d9f32c66..fe45c739a3ca 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -78,6 +78,7 @@ def text_encoder_attn_modules(text_encoder): "Lumina2LoraLoaderMixin", "WanLoraLoaderMixin", "HiDreamImageLoraLoaderMixin", + "SkyReelsV2LoraLoaderMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] _import_structure["ip_adapter"] = [ @@ -120,6 +121,7 @@ def text_encoder_attn_modules(text_encoder): StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, WanLoraLoaderMixin, + SkyReelsV2LoraLoaderMixin, ) from .single_file import FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py index aed899325e32..aa0602e5b417 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py @@ -20,7 +20,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...loaders import WanLoraLoaderMixin +from ...loaders import SkyReelsV2LoraLoaderMixin from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel from ...schedulers import FlowMatchUniPCMultistepScheduler from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring @@ -105,7 +105,7 @@ def prompt_clean(text): return text -class SkyReelsV2Pipeline(DiffusionPipeline, WanLoraLoaderMixin): +class SkyReelsV2Pipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin): r""" Pipeline for Text-to-Video (t2v) generation using SkyReels-V2. diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index a850bb249dd0..618b4ed9fd6d 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -23,7 +23,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...loaders import WanLoraLoaderMixin +from ...loaders import SkyReelsV2LoraLoaderMixin from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel from ...schedulers import FlowMatchUniPCMultistepScheduler from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring @@ -126,7 +126,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, WanLoraLoaderMixin): +class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin): """ Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing. 
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 60539899b46e..577c189c60d8 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -28,7 +28,7 @@ from diffusers.video_processor import VideoProcessor from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...loaders import WanLoraLoaderMixin +from ...loaders import SkyReelsV2LoraLoaderMixin from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel from ...schedulers import FlowMatchUniPCMultistepScheduler from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring @@ -132,7 +132,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class SkyReelsV2DiffusionForcingImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): +class SkyReelsV2DiffusionForcingImageToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin): """ Pipeline for Image-to-Video (i2v) generation using SkyReels-V2 with diffusion forcing. diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index 0cd16aafbbda..1f92e039a282 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -25,7 +25,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...loaders import WanLoraLoaderMixin +from ...loaders import SkyReelsV2LoraLoaderMixin from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel from ...schedulers import FlowMatchUniPCMultistepScheduler from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring @@ -188,7 +188,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class SkyReelsV2DiffusionForcingVideoToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): +class SkyReelsV2DiffusionForcingVideoToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin): """ Pipeline for Video-to-Video (v2v) generation using SkyReels-V2 with diffusion forcing. 
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py index 48d69df74e20..587fdda9c94e 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py @@ -22,7 +22,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput -from ...loaders import WanLoraLoaderMixin +from ...loaders import SkyReelsV2LoraLoaderMixin from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel from ...schedulers import FlowMatchUniPCMultistepScheduler from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring @@ -125,7 +125,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin): +class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin): r""" Pipeline for Image-to-Video (i2v) generation using SkyReels-V2. From b6675c0aef5a55f547308625ed04007f16d23f94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 14 Jun 2025 19:51:22 +0300 Subject: [PATCH 253/264] Update SkyReelsV2 model references in documentation - Replaced placeholder model paths with actual paths for SkyReels-V2 models in multiple pipeline files. - Ensured consistency across the documentation for loading models in the SkyReelsV2 pipelines. --- .../pipelines/skyreels_v2/pipeline_skyreels_v2.py | 8 ++++---- .../pipeline_skyreels_v2_diffusion_forcing.py | 10 +++++----- .../pipeline_skyreels_v2_diffusion_forcing_i2v.py | 10 +++++----- .../pipeline_skyreels_v2_diffusion_forcing_v2v.py | 10 +++++----- .../pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py | 10 +++++----- 5 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py index aa0602e5b417..6f6a5b6c8085 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py @@ -56,15 +56,15 @@ >>> # Load the pipeline >>> # Available models: - >>> # - /SkyReels-V2-T2V-14B-540P-Diffusers - >>> # - /SkyReels-V2-T2V-14B-720P-Diffusers + >>> # - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers >>> vae = AutoencoderKLWan.from_pretrained( - ... "/SkyReels-V2-T2V-14B-720P-Diffusers", + ... "Skywork/SkyReels-V2-T2V-14B-720P-Diffusers", ... subfolder="vae", ... torch_dtype=torch.float32, ... ) >>> pipe = SkyReelsV2Pipeline.from_pretrained( - ... "/SkyReels-V2-T2V-14B-720P-Diffusers", + ... "Skywork/SkyReels-V2-T2V-14B-720P-Diffusers", ... vae=vae, ... torch_dtype=torch.bfloat16, ... 
) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 618b4ed9fd6d..2c5283f6db37 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -59,16 +59,16 @@ >>> # Load the pipeline >>> # Available models: - >>> # - /SkyReels-V2-DF-1.3B-540P-Diffusers - >>> # - /SkyReels-V2-DF-14B-540P-Diffusers - >>> # - /SkyReels-V2-DF-14B-720P-Diffusers + >>> # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers >>> vae = AutoencoderKLWan.from_pretrained( - ... "/SkyReels-V2-DF-14B-720P-Diffusers", + ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers", ... subfolder="vae", ... torch_dtype=torch.float32, ... ) >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained( - ... "/SkyReels-V2-DF-14B-720P-Diffusers", + ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers", ... vae=vae, ... torch_dtype=torch.bfloat16, ... ) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py index 577c189c60d8..2938522b660f 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py @@ -63,16 +63,16 @@ >>> # Load the pipeline >>> # Available models: - >>> # - /SkyReels-V2-DF-1.3B-540P-Diffusers - >>> # - /SkyReels-V2-DF-14B-540P-Diffusers - >>> # - /SkyReels-V2-DF-14B-720P-Diffusers + >>> # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers >>> vae = AutoencoderKLWan.from_pretrained( - ... "/SkyReels-V2-DF-14B-720P-Diffusers", + ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers", ... subfolder="vae", ... torch_dtype=torch.float32, ... ) >>> pipe = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained( - ... "/SkyReels-V2-DF-14B-720P-Diffusers", + ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers", ... vae=vae, ... torch_dtype=torch.bfloat16, ... ) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py index 1f92e039a282..eec3e721ed5c 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py @@ -61,16 +61,16 @@ >>> # Load the pipeline >>> # Available models: - >>> # - /SkyReels-V2-DF-1.3B-540P-Diffusers - >>> # - /SkyReels-V2-DF-14B-540P-Diffusers - >>> # - /SkyReels-V2-DF-14B-720P-Diffusers + >>> # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers >>> vae = AutoencoderKLWan.from_pretrained( - ... "/SkyReels-V2-DF-14B-720P-Diffusers", + ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers", ... subfolder="vae", ... torch_dtype=torch.float32, ... ) >>> pipe = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained( - ... "/SkyReels-V2-DF-14B-720P-Diffusers", + ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers", ... vae=vae, ... torch_dtype=torch.bfloat16, ... 
) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py index 587fdda9c94e..a06eebc4bb97 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py @@ -59,16 +59,16 @@ >>> # Load the pipeline >>> # Available models: - >>> # - /SkyReels-V2-I2V-1.3B-540P-Diffusers - >>> # - /SkyReels-V2-I2V-14B-540P-Diffusers - >>> # - /SkyReels-V2-I2V-14B-720P-Diffusers + >>> # - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers + >>> # - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers >>> vae = AutoencoderKLWan.from_pretrained( - ... "/SkyReels-V2-I2V-14B-720P-Diffusers", + ... "Skywork/SkyReels-V2-I2V-14B-720P-Diffusers", ... subfolder="vae", ... torch_dtype=torch.float32, ... ) >>> pipe = SkyReelsV2ImageToVideoPipeline.from_pretrained( - ... "/SkyReels-V2-I2V-14B-720P-Diffusers", + ... "Skywork/SkyReels-V2-I2V-14B-720P-Diffusers", ... vae=vae, ... torch_dtype=torch.bfloat16, ... ) From 7cb257c0084e1b1ad68bf3cc0f43e3f7edede37c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 14 Jun 2025 19:51:59 +0300 Subject: [PATCH 254/264] style --- src/diffusers/loaders/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index fe45c739a3ca..bca7b8737e31 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -118,10 +118,10 @@ def text_encoder_attn_modules(text_encoder): Mochi1LoraLoaderMixin, SanaLoraLoaderMixin, SD3LoraLoaderMixin, + SkyReelsV2LoraLoaderMixin, StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, WanLoraLoaderMixin, - SkyReelsV2LoraLoaderMixin, ) from .single_file import FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin From 023336f922996c6f853684386e474fe54a91f2ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sat, 14 Jun 2025 19:55:04 +0300 Subject: [PATCH 255/264] fix-copies --- src/diffusers/loaders/lora_pipeline.py | 28 ++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index a9714de77de1..23f487001f4a 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -5635,7 +5635,14 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -5655,6 +5662,9 @@ def load_lora_into_transformer( weights. hotswap (`bool`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. 
""" if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -5667,6 +5677,7 @@ def load_lora_into_transformer( state_dict, network_alphas=None, adapter_name=adapter_name, + metadata=metadata, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, hotswap=hotswap, @@ -5682,9 +5693,10 @@ def save_lora_weights( weight_name: str = None, save_function: Callable = None, safe_serialization: bool = True, + transformer_lora_adapter_metadata: Optional[dict] = None, ): r""" - Save the LoRA parameters corresponding to the UNet and text encoder. + Save the LoRA parameters corresponding to the transformer. Arguments: save_directory (`str` or `os.PathLike`): @@ -5701,14 +5713,21 @@ def save_lora_weights( `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ state_dict = {} + lora_adapter_metadata = {} if not transformer_lora_layers: raise ValueError("You must pass `transformer_lora_layers`.") - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if transformer_lora_adapter_metadata is not None: + lora_adapter_metadata.update( + _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + ) # Save the model cls.write_lora_layers( @@ -5718,6 +5737,7 @@ def save_lora_weights( weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, + lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora From 4fd94bfac552ed02852ae506bfa447d7c3fcf8ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Sun, 15 Jun 2025 17:15:08 +0300 Subject: [PATCH 256/264] Refactor `fps_projection` in `SkyReelsV2Transformer3DModel` - Replaced the sequential linear layers for `fps_projection` with a `FeedForward` layer using `SiLU` activation for better integration. 
--- src/diffusers/models/transformers/transformer_skyreels_v2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index fcd792d33299..ad81f19cdf46 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -439,9 +439,7 @@ def __init__( if inject_sample_info: self.fps_embedding = nn.Embedding(2, inner_dim) - self.fps_projection = nn.Sequential( - nn.Linear(inner_dim, inner_dim), nn.SiLU(), nn.Linear(inner_dim, inner_dim * 6) - ) + self.fps_projection = FeedForward(inner_dim, inner_dim * 6, mult=1, activation_fn="silu") def forward( self, From 74c2209179574aac82a9c846beb015bade1097ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 16 Jun 2025 08:32:39 +0300 Subject: [PATCH 257/264] Update docs --- docs/source/en/api/pipelines/skyreels_v2.md | 135 ++------------------ 1 file changed, 12 insertions(+), 123 deletions(-) diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md index de50d9a216dc..14f37a2916ce 100644 --- a/docs/source/en/api/pipelines/skyreels_v2.md +++ b/docs/source/en/api/pipelines/skyreels_v2.md @@ -86,83 +86,22 @@ pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained( text_encoder=text_encoder, torch_dtype=torch.bfloat16 ) -pipeline.to("cuda") - -prompt = """ -The camera rushes from far to near in a low-angle shot, -revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in -for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. -Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic -shadows and warm highlights. Medium composition, front view, low angle, with depth of field. -""" -negative_prompt = """ -Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, -low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, -misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards -""" - -output = pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - num_frames=97, - guidance_scale=6.0, -).frames[0] -export_to_video(output, "output.mp4", fps=24) -``` - - - - -[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. 
- -```py -# pip install ftfy -import torch -import numpy as np -from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline -from diffusers.hooks.group_offloading import apply_group_offloading -from diffusers.utils import export_to_video -from transformers import UMT5EncoderModel - -text_encoder = UMT5EncoderModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16) -vae = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="vae", torch_dtype=torch.float32) -transformer = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16) - -pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained( - "Skywork/SkyReels-V2-DF-14B-540P-Diffusers", - vae=vae, - transformer=transformer, - text_encoder=text_encoder, - torch_dtype=torch.bfloat16 -) -pipeline.to("cuda") +pipe = pipe.to("cuda") +pipe.transformer.set_ar_attention(causal_block_size=5) -# torch.compile -pipeline.transformer.to(memory_format=torch.channels_last) -pipeline.transformer = torch.compile( - pipeline.transformer, mode="max-autotune", fullgraph=True -) +prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window." -prompt = """ -The camera rushes from far to near in a low-angle shot, -revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in -for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground. -Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic -shadows and warm highlights. Medium composition, front view, low angle, with depth of field. 
-""" -negative_prompt = """ -Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, -low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, -misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards -""" - -output = pipeline( +output = pipe( prompt=prompt, - negative_prompt=negative_prompt, + num_inference_steps=30, + height=544, + width=960, num_frames=97, - guidance_scale=6.0, + ar_step=5, # Controls asynchronous inference (0 for synchronous mode) + overlap_history=None, # Number of frames to overlap for smooth transitions in long videos; 17 for long + addnoise_condition=20, # Improves consistency in long video generation ).frames[0] -export_to_video(output, "output.mp4", fps=24) +export_to_video(output, "T2V.mp4", fps=24, quality=8) ``` @@ -181,14 +120,12 @@ import torch import torchvision.transforms.functional as TF from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingImageToVideoPipeline from diffusers.utils import export_to_video, load_image -from transformers import CLIPVisionModel model_id = "Skywork/SkyReels-V2-DF-14B-720P-Diffusers" -image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32) vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) pipe = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained( - model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16 + model_id, vae=vae, torch_dtype=torch.bfloat16 ) pipe.to("cuda") @@ -230,18 +167,6 @@ export_to_video(output, "output.mp4", fps=24) -### Any-to-Video Controllable Generation - -SkyReels-V2 supports various generation techniques which achieve controllable video generation. Some of the capabilities include: -- Control to Video (Depth, Pose, Sketch, Flow, Grayscale, Scribble, Layout, Boundary Box, etc.). Recommended library for preprocessing videos to obtain control videos: [huggingface/controlnet_aux]() -- Image/Video to Video (first frame, last frame, starting clip, ending clip, random clips) -- Inpainting and Outpainting -- Subject to Video (faces, object, characters, etc.) -- Composition to Video (reference anything, animate anything, swap anything, expand anything, move anything, etc.) - -The general rule of thumb to keep in mind when preparing inputs for the SkyReels-V2 pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color. - -The code snippets available in [this](https://github.com/huggingface/diffusers/pull/11582) pull request demonstrate some examples of how videos can be generated with controllability signals. 
## Notes @@ -254,7 +179,6 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p # pip install ftfy import torch from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline - from diffusers import FlowMatchUniPCMultistepScheduler from diffusers.utils import export_to_video vae = AutoModel.from_pretrained( @@ -263,9 +187,6 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained( "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", vae=vae, torch_dtype=torch.bfloat16 ) - pipeline.scheduler = FlowMatchUniPCMultistepScheduler.from_config( - pipeline.scheduler.config, flow_shift=5.0 - ) pipeline.to("cuda") pipeline.load_lora_weights("benjamin-paine/steamboat-willie-1.3b", adapter_name="steamboat-willie") @@ -292,38 +213,6 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p
-- [`SkyReelsV2Transformer3DModel`] and [`AutoencoderKLWan`] support loading from single files with [`~loaders.FromSingleFileMixin.from_single_file`]. - -
- Show example code - - ```py - # pip install ftfy - import torch - from diffusers import SkyReelsV2DiffusionForcingPipeline, AutoModel - - vae = AutoModel.from_single_file( - "https://huggingface.co/Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers/blob/main/split_files/vae/skyreels_v2_vae.safetensors" - ) - transformer = AutoModel.from_single_file( - "https://huggingface.co/Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers/blob/main/split_files/diffusion_models/skyreels_v2_df_1.3b_bf16.safetensors", - torch_dtype=torch.bfloat16 - ) - pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained( - "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", - vae=vae, - transformer=transformer, - torch_dtype=torch.bfloat16 - ) - ``` - -
- -- Set the [`AutoencoderKLWan`] dtype to `torch.float32` for better decoding quality. - -- The number of frames per second (fps) or `k` should be calculated by `4 * k + 1`. - -- Try lower `shift` values (`2.0` to `5.0`) for lower resolution videos and higher `shift` values (`7.0` to `12.0`) for higher resolution images. ## SkyReelsV2DiffusionForcingPipeline From ebc77147ed7394ad0074c42e2678095f8a2893e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 16 Jun 2025 13:14:04 +0300 Subject: [PATCH 258/264] Refactor video processing in SkyReelsV2DiffusionForcingPipeline - Renamed parameters for clarity: `video` to `video_latents` and `overlap_history` to `overlap_history_latent_frames`. - Updated logic for handling long video generation, including adjustments to latent frame calculations and accumulation. - Consolidated handling of latents for both long and short video generation scenarios. - Final decoding step now consistently converts latents to pixels, ensuring proper output format. --- .../pipeline_skyreels_v2_diffusion_forcing.py | 266 +++++------------- 1 file changed, 74 insertions(+), 192 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index 2c5283f6db37..c52e274e5871 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -359,10 +359,9 @@ def prepare_latents( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, base_num_frames: Optional[int] = None, - video: Optional[torch.Tensor] = None, - overlap_history: Optional[int] = None, + video_latents: Optional[torch.Tensor] = None, causal_block_size: Optional[int] = None, - overlap_history_frames: Optional[int] = None, + overlap_history_latent_frames: Optional[int] = None, long_video_iter: Optional[int] = None, ) -> torch.Tensor: if latents is not None: @@ -375,20 +374,8 @@ def prepare_latents( prefix_video_latents = None prefix_video_latents_length = 0 - if video is not None: # long video generation at the iterations other than the first one - prefix_video_latents = retrieve_latents( - self.vae.encode(video[:, :, -overlap_history:]), sample_mode="argmax" - ) - - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(device, self.vae.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - device, self.vae.dtype - ) - prefix_video_latents = (prefix_video_latents - latents_mean) * latents_std + if video_latents is not None: # long video generation at the iterations other than the first one + prefix_video_latents = video_latents[:, :, -overlap_history_latent_frames:] if prefix_video_latents.shape[2] % causal_block_size != 0: truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size @@ -400,9 +387,9 @@ def prepare_latents( prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] prefix_video_latents_length = prefix_video_latents.shape[2] - finished_frame_num = long_video_iter * (base_num_frames - overlap_history_frames) + overlap_history_frames + finished_frame_num = long_video_iter * (base_num_frames - overlap_history_latent_frames) + overlap_history_latent_frames left_frame_num = num_latent_frames - finished_frame_num - num_latent_frames = 
min(left_frame_num + overlap_history_frames, base_num_frames) + num_latent_frames = min(left_frame_num + overlap_history_latent_frames, base_num_frames) elif base_num_frames is not None: # long video generation at the first iteration num_latent_frames = base_num_frames else: # short video generation @@ -718,11 +705,34 @@ def __call__( fps_embeds = [fps] * prompt_embeds.shape[0] fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] - if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: - # Short video generation + # Determine if we're doing long video generation + is_long_video = overlap_history is not None and base_num_frames is not None and num_frames > base_num_frames + # Initialize accumulated_latents to store all latents in one tensor + accumulated_latents = None + if is_long_video: + # Long video generation setup + overlap_history_latent_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1 + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + base_num_frames = ( + (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 + if base_num_frames is not None + else num_latent_frames + ) + n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_latent_frames) + 1 + + else: + # Short video generation setup + n_iter = 1 + base_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + # Loop through iterations (multiple iterations only for long videos) + for iter_idx in range(n_iter): + if is_long_video: + print(f"long_video_iter:{iter_idx}") + # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_channels - latents, num_latent_frames, _, prefix_video_latents_length = self.prepare_latents( + latents, current_num_latent_frames, prefix_video_latents, prefix_video_latents_length = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, height, @@ -731,26 +741,35 @@ def __call__( torch.float32, device, generator, - latents, + latents if iter_idx == 0 else None, + video_latents=accumulated_latents, # Pass latents directly instead of decoded video + base_num_frames=base_num_frames if is_long_video else None, + causal_block_size=causal_block_size, + overlap_history_latent_frames=overlap_history_latent_frames if is_long_video else None, + long_video_iter=iter_idx if is_long_video else None, ) - base_num_frames = ( - (num_frames - 1) // self.vae_scale_factor_temporal + 1 - if base_num_frames is not None - else num_latent_frames - ) + if prefix_video_latents_length > 0: + latents[:, :, :prefix_video_latents_length, :, :] = prefix_video_latents.to(transformer_dtype) - # 4. Prepare sample schedulers and timestep matrix + # 6. Prepare sample schedulers and timestep matrix sample_schedulers = [] - for _ in range(num_latent_frames): + for _ in range(current_num_latent_frames): sample_scheduler = deepcopy(self.scheduler) sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) sample_schedulers.append(sample_scheduler) + + # Different matrix generation for short vs long video step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - num_latent_frames, timesteps, base_num_frames, ar_step, prefix_video_latents_length, causal_block_size + current_num_latent_frames, + timesteps, + current_num_latent_frames if is_long_video else base_num_frames, + ar_step, + prefix_video_latents_length, + causal_block_size, ) - # 6. Denoising loop + # 7. 
Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(step_matrix) @@ -828,172 +847,35 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - else: - # Long video generation - overlap_history_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1 - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - base_num_frames = ( - (base_num_frames - 1) // self.vae_scale_factor_temporal + 1 - if base_num_frames is not None - else num_latent_frames - ) - n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 - video = None - for long_video_iter in range(n_iter): - print(f"long_video_iter:{long_video_iter}") - # 5. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels - latents, num_latent_frames, prefix_video_latents, prefix_video_latents_length = self.prepare_latents( - batch_size * num_videos_per_prompt, - num_channels_latents, - height, - width, - num_frames, - torch.float32, - device, - generator, - latents if long_video_iter == 0 else None, - video=video, - overlap_history=overlap_history, - base_num_frames=base_num_frames, - causal_block_size=causal_block_size, - overlap_history_frames=overlap_history_frames, - long_video_iter=long_video_iter, - ) - - if prefix_video_latents_length > 0: - latents[:, :, :prefix_video_latents_length, :, :] = prefix_video_latents.to(transformer_dtype) - - # 4. Prepare sample schedulers and timestep matrix - sample_schedulers = [] - for _ in range(num_latent_frames): - sample_scheduler = deepcopy(self.scheduler) - sample_scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) - sample_schedulers.append(sample_scheduler) - step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - num_latent_frames, - timesteps, - num_latent_frames, - ar_step, - prefix_video_latents_length, - causal_block_size, - ) - - # 6. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - self._num_timesteps = len(step_matrix) - with self.progress_bar(total=len(step_matrix)) as progress_bar: - for i, t in enumerate(step_matrix): - if self.interrupt: - continue - - self._current_timestep = t - valid_interval_start, valid_interval_end = valid_interval[i] - latent_model_input = ( - latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone() - ) - timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone() - - if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_length: - noise_factor = 0.001 * addnoise_condition - latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] = ( - latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] - * (1.0 - noise_factor) - + torch.randn_like( - latent_model_input[:, :, valid_interval_start:prefix_video_latents_length, :, :] - ) - * noise_factor - ) - timestep[:, valid_interval_start:prefix_video_latents_length] = addnoise_condition - - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - enable_diffusion_forcing=True, - fps=fps_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - if self.do_classifier_free_guidance: - noise_uncond = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - enable_diffusion_forcing=True, - fps=fps_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - - update_mask_i = step_update_mask[i] - for idx in range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[:, :, idx, :, :] = sample_schedulers[idx].step( - noise_pred[:, :, idx - valid_interval_start, :, :], - t[idx], - latents[:, :, idx, :, :], - return_dict=False, - generator=generator, - )[0] - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop( - "negative_prompt_embeds", negative_prompt_embeds - ) - - # call the callback, if provided - if i == len(step_matrix) - 1 or ( - (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 - ): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - if not output_type == "latent": - latents = latents.to(self.vae.dtype) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( - 1, self.vae.config.z_dim, 1, 1, 1 - ).to(latents.device, latents.dtype) - latents = latents / latents_std + latents_mean - videos = self.vae.decode(latents, return_dict=False)[0] - if video is None: - video = videos - else: - video = torch.cat([video, videos[:, :, overlap_history:]], dim=2) + # Handle latent accumulation for long videos or use the current latents for short videos + if is_long_video: + if accumulated_latents is None: + accumulated_latents = latents else: - video = latents + # Keep overlap frames for 
conditioning but don't include them in final output + accumulated_latents = torch.cat( + [accumulated_latents, latents[:, :, overlap_history_latent_frames:]], dim=2 + ) + + if is_long_video: + latents = accumulated_latents self._current_timestep = None + # Final decoding step - convert latents to pixels if not output_type == "latent": - if overlap_history is None: - latents = latents.to(self.vae.dtype) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( - 1, self.vae.config.z_dim, 1, 1, 1 - ).to(latents.device, latents.dtype) - latents = latents / latents_std + latents_mean - video = self.vae.decode(latents, return_dict=False)[0] + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) else: video = latents From fa715d6510dd16797aac68ba2c85638941bee032 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 16 Jun 2025 13:25:32 +0300 Subject: [PATCH 259/264] Update activation function in `fps_projection` of `SkyReelsV2Transformer3DModel` - Changed activation function from `silu` to `linear-silu` in the `fps_projection` layer for improved performance and integration. --- src/diffusers/models/transformers/transformer_skyreels_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index ad81f19cdf46..53aaab6baa0d 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -439,7 +439,7 @@ def __init__( if inject_sample_info: self.fps_embedding = nn.Embedding(2, inner_dim) - self.fps_projection = FeedForward(inner_dim, inner_dim * 6, mult=1, activation_fn="silu") + self.fps_projection = FeedForward(inner_dim, inner_dim * 6, mult=1, activation_fn="linear-silu") def forward( self, From 56ea4387479fff7d35bde79229cbf6b136fb9609 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 16 Jun 2025 13:39:17 +0300 Subject: [PATCH 260/264] Add fps_projection layer renaming in convert_skyreelsv2_to_diffusers.py - Updated key mappings for the `fps_projection` layer to align with new naming conventions, ensuring consistency in model integration. 
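Before the diff itself, a small hedged sketch of what these mappings do (hypothetical checkpoint keys and a simplified rename loop, not the actual `convert_skyreelsv2_to_diffusers.py` code): the original `fps_projection.{0,2}` parameters are re-keyed onto the new `FeedForward` layout.

```py
# Hypothetical original checkpoint keys for the old nn.Sequential fps_projection.
original_state_dict = {
    "fps_projection.0.weight": "tensor(...)",
    "fps_projection.0.bias": "tensor(...)",
    "fps_projection.2.weight": "tensor(...)",
    "fps_projection.2.bias": "tensor(...)",
}

# Rename rules in the spirit of the conversion script's key map.
RENAME_DICT = {
    "fps_projection.0": "fps_projection.net.0.proj",
    "fps_projection.2": "fps_projection.net.2",
}

converted = {}
for key, value in original_state_dict.items():
    new_key = key
    for old, new in RENAME_DICT.items():
        new_key = new_key.replace(old, new)
    converted[new_key] = value

print(sorted(converted))
# ['fps_projection.net.0.proj.bias', 'fps_projection.net.0.proj.weight',
#  'fps_projection.net.2.bias', 'fps_projection.net.2.weight']
```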
--- scripts/convert_skyreelsv2_to_diffusers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py index ef9f57b8e5b8..8117e768dbfc 100644 --- a/scripts/convert_skyreelsv2_to_diffusers.py +++ b/scripts/convert_skyreelsv2_to_diffusers.py @@ -30,6 +30,8 @@ "modulation": "scale_shift_table", "ffn.0": "ffn.net.0.proj", "ffn.2": "ffn.net.2", + "fps_projection.0": "fps_projection.net.0.proj", + "fps_projection.2": "fps_projection.net.2", # Hack to swap the layer names # The original model calls the norms in following order: norm1, norm3, norm2 # We convert it to: norm1, norm2, norm3 From 829d632cf77792c2125070fa191c34e351693f7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 16 Jun 2025 15:14:35 +0300 Subject: [PATCH 261/264] Fix fps_projection assignment in SkyReelsV2Transformer3DModel - Corrected the assignment of the `fps_projection` layer to ensure it is properly cast to the appropriate data type, enhancing model functionality. --- src/diffusers/models/transformers/transformer_skyreels_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 53aaab6baa0d..abf6a4cd869c 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -508,7 +508,7 @@ def forward( fps_emb = self.fps_embedding(fps).float() timestep_proj = timestep_proj.to(fps_emb.dtype) - self.fps_projection.to(fps_emb.dtype) + self.fps_projection = self.fps_projection.to(fps_emb.dtype) if enable_diffusion_forcing: timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1)).repeat( timestep.shape[1], 1, 1 From 0b7d7ea1592733f834298c821194b506ee902ebd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 16 Jun 2025 15:27:08 +0300 Subject: [PATCH 262/264] Update _keep_in_fp32_modules in SkyReelsV2Transformer3DModel - Added `fps_projection` to the list of modules that should remain in FP32 precision, ensuring proper handling of data types during model operations. 
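The diff follows; as a rough illustration of why the layer has to stay in float32 (toy tensors and an assumed dtype setup, not the real model), recall that `fps_emb` is produced with `.float()` in `forward`, so a projection cast to `bfloat16` alongside the rest of the transformer would hit a dtype mismatch:

```py
import torch
from torch import nn

proj = nn.Linear(8, 48)  # toy stand-in for fps_projection
fps_emb = torch.randn(2, 8).float()  # the fps embedding is computed in float32

# If the projection were cast to bfloat16 with the rest of the transformer,
# applying it to the float32 embedding raises a dtype mismatch:
try:
    proj.to(torch.bfloat16)(fps_emb)
except RuntimeError as err:
    print("dtype mismatch:", err)

# Keeping the projection in float32 (the effect this commit aims for via
# `_keep_in_fp32_modules`) avoids both the error and the in-place `.to()` cast
# that the previous commit performed inside `forward`:
out = proj.to(torch.float32)(fps_emb)
print(out.dtype)  # torch.float32
```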
--- src/diffusers/models/transformers/transformer_skyreels_v2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index abf6a4cd869c..8cef3214b49d 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -374,7 +374,7 @@ class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fr _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"] _no_split_modules = ["SkyReelsV2TransformerBlock"] - _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] + _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3", "fps_projection"] _keys_to_ignore_on_load_unexpected = ["norm_added_q"] @register_to_config @@ -508,7 +508,6 @@ def forward( fps_emb = self.fps_embedding(fps).float() timestep_proj = timestep_proj.to(fps_emb.dtype) - self.fps_projection = self.fps_projection.to(fps_emb.dtype) if enable_diffusion_forcing: timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1)).repeat( timestep.shape[1], 1, 1 From 6a1f857225dd3bf055137fb2e58976377a6fa43c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 16 Jun 2025 18:13:17 +0300 Subject: [PATCH 263/264] Remove integration test classes from SkyReelsV2 test files - Deleted the `SkyReelsV2DiffusionForcingPipelineIntegrationTests` and `SkyReelsV2PipelineIntegrationTests` classes along with their associated setup, teardown, and test methods, as they were not implemented and not needed for current testing. --- .../pipelines/skyreels_v2/test_skyreels_v2.py | 20 ------------------- .../skyreels_v2/test_skyreels_v2_df.py | 20 ------------------- 2 files changed, 40 deletions(-) diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2.py b/tests/pipelines/skyreels_v2/test_skyreels_v2.py index c2e49ec98a15..c74fd342ff15 100644 --- a/tests/pipelines/skyreels_v2/test_skyreels_v2.py +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2.py @@ -138,23 +138,3 @@ def test_inference(self): @unittest.skip("Test not supported") def test_attention_slicing_forward_pass(self): pass - - -@slow -@require_torch_accelerator -class SkyReelsV2PipelineIntegrationTests(unittest.TestCase): - prompt = "A painting of a squirrel eating a burger." - - def setUp(self): - super().setUp() - gc.collect() - torch.cuda.empty_cache() - - def tearDown(self): - super().tearDown() - gc.collect() - torch.cuda.empty_cache() - - @unittest.skip("TODO: test needs to be implemented") - def test_SkyReelsV2x(self): - pass diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py index 48aec05eeee6..d8ff0d439e9d 100644 --- a/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py @@ -138,23 +138,3 @@ def test_inference(self): @unittest.skip("Test not supported") def test_attention_slicing_forward_pass(self): pass - - -@slow -@require_torch_accelerator -class SkyReelsV2DiffusionForcingPipelineIntegrationTests(unittest.TestCase): - prompt = "A painting of a squirrel eating a burger." 
- - def setUp(self): - super().setUp() - gc.collect() - torch.cuda.empty_cache() - - def tearDown(self): - super().tearDown() - gc.collect() - torch.cuda.empty_cache() - - @unittest.skip("TODO: test needs to be implemented") - def test_SkyReelsV2DiffusionForcingx(self): - pass From 2d35933296eba54dcb73f0d911b085964363a68c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tolga=20Cang=C3=B6z?= Date: Mon, 16 Jun 2025 18:15:08 +0300 Subject: [PATCH 264/264] style --- .../pipeline_skyreels_v2_diffusion_forcing.py | 46 +++++++++++-------- .../pipelines/skyreels_v2/test_skyreels_v2.py | 3 -- .../skyreels_v2/test_skyreels_v2_df.py | 3 -- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py index c52e274e5871..81d8f3cf80e8 100644 --- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py @@ -387,7 +387,9 @@ def prepare_latents( prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents] prefix_video_latents_length = prefix_video_latents.shape[2] - finished_frame_num = long_video_iter * (base_num_frames - overlap_history_latent_frames) + overlap_history_latent_frames + finished_frame_num = ( + long_video_iter * (base_num_frames - overlap_history_latent_frames) + overlap_history_latent_frames + ) left_frame_num = num_latent_frames - finished_frame_num num_latent_frames = min(left_frame_num + overlap_history_latent_frames, base_num_frames) elif base_num_frames is not None: # long video generation at the first iteration @@ -718,7 +720,9 @@ def __call__( if base_num_frames is not None else num_latent_frames ) - n_iter = 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_latent_frames) + 1 + n_iter = ( + 1 + (num_latent_frames - base_num_frames - 1) // (base_num_frames - overlap_history_latent_frames) + 1 + ) else: # Short video generation setup @@ -732,21 +736,23 @@ def __call__( # 5. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels - latents, current_num_latent_frames, prefix_video_latents, prefix_video_latents_length = self.prepare_latents( - batch_size * num_videos_per_prompt, - num_channels_latents, - height, - width, - num_frames, - torch.float32, - device, - generator, - latents if iter_idx == 0 else None, - video_latents=accumulated_latents, # Pass latents directly instead of decoded video - base_num_frames=base_num_frames if is_long_video else None, - causal_block_size=causal_block_size, - overlap_history_latent_frames=overlap_history_latent_frames if is_long_video else None, - long_video_iter=iter_idx if is_long_video else None, + latents, current_num_latent_frames, prefix_video_latents, prefix_video_latents_length = ( + self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents if iter_idx == 0 else None, + video_latents=accumulated_latents, # Pass latents directly instead of decoded video + base_num_frames=base_num_frames if is_long_video else None, + causal_block_size=causal_block_size, + overlap_history_latent_frames=overlap_history_latent_frames if is_long_video else None, + long_video_iter=iter_idx if is_long_video else None, + ) ) if prefix_video_latents_length > 0: @@ -871,9 +877,9 @@ def __call__( .view(1, self.vae.config.z_dim, 1, 1, 1) .to(latents.device, latents.dtype) ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( - 1, self.vae.config.z_dim, 1, 1, 1 - ).to(latents.device, latents.dtype) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) latents = latents / latents_std + latents_mean video = self.vae.decode(latents, return_dict=False)[0] video = self.video_processor.postprocess_video(video, output_type=output_type) diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2.py b/tests/pipelines/skyreels_v2/test_skyreels_v2.py index c74fd342ff15..49d61ecb89e3 100644 --- a/tests/pipelines/skyreels_v2/test_skyreels_v2.py +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gc import unittest import numpy as np @@ -27,8 +26,6 @@ ) from diffusers.utils.testing_utils import ( enable_full_determinism, - require_torch_accelerator, - slow, ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py index d8ff0d439e9d..e8f4f74e7da8 100644 --- a/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py +++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gc import unittest import numpy as np @@ -27,8 +26,6 @@ ) from diffusers.utils.testing_utils import ( enable_full_determinism, - require_torch_accelerator, - slow, ) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
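As a closing cross-check of the long-video bookkeeping refactored in this series, here is a small hedged sketch of the latent-frame and iteration arithmetic (the values below, including a temporal VAE scale factor of 4, are assumptions for illustration and are not taken from the patches):

```py
vae_scale_factor_temporal = 4  # assumed temporal compression of the causal VAE

num_frames = 257        # requested output frames (of the form 4 * k + 1)
base_num_frames = 97    # frames generated per iteration window
overlap_history = 17    # frames re-used to condition the next window


def to_latent_frames(frames: int) -> int:
    return (frames - 1) // vae_scale_factor_temporal + 1


num_latent_frames = to_latent_frames(num_frames)           # 65
base_latent_frames = to_latent_frames(base_num_frames)     # 25
overlap_latent_frames = to_latent_frames(overlap_history)  # 5

# Same formula as the pipeline: one base window plus enough overlapping windows for the rest.
n_iter = 1 + (num_latent_frames - base_latent_frames - 1) // (base_latent_frames - overlap_latent_frames) + 1

# Each window after the first contributes (base - overlap) new latent frames.
covered = base_latent_frames + (n_iter - 1) * (base_latent_frames - overlap_latent_frames)
print(n_iter, covered, num_latent_frames)  # 3 65 65
```

With these assumed values the three windows cover 25 + 2 × 20 = 65 latent frames, exactly `num_latent_frames`, which is what the accumulation of `latents[:, :, overlap_history_latent_frames:]` across iterations relies on.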