From 566a9150601d191c493da44a6d32c7ba4733b0b0 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Mon, 17 Mar 2025 11:29:44 +0000 Subject: [PATCH 1/5] fix_wan_i2v_quality --- src/diffusers/pipelines/wan/pipeline_wan_i2v.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index 102f1a5002e1..9bfa46dc5961 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -113,9 +113,11 @@ def retrieve_latents( latents_mean: torch.Tensor, latents_std: torch.Tensor, generator: Optional[torch.Generator] = None, - sample_mode: str = "sample", + sample_mode: str = "none", ): - if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + if hasattr(encoder_output, "latent_dist") and sample_mode == "none": + return (encoder_output.latent_dist.mean - latents_mean) * latents_std + elif hasattr(encoder_output, "latent_dist") and sample_mode == "sample": encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std encoder_output.latent_dist.logvar = torch.clamp( (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0 From 482415c2d02ff8006839c63df39181177172613b Mon Sep 17 00:00:00 2001 From: C Date: Mon, 17 Mar 2025 23:37:52 +0800 Subject: [PATCH 2/5] Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/wan/pipeline_wan_i2v.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index 9bfa46dc5961..c155a6a32994 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -117,7 +117,8 @@ def retrieve_latents( ): if hasattr(encoder_output, "latent_dist") and sample_mode == "none": return (encoder_output.latent_dist.mean - latents_mean) * latents_std - elif hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + else: + raise AttributeError("Could not access latents of provided encoder_output") encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std encoder_output.latent_dist.logvar = torch.clamp( (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0 From c01d2bf78812546df17d9181db24b0774c3b1dd7 Mon Sep 17 00:00:00 2001 From: C Date: Mon, 17 Mar 2025 23:38:11 +0800 Subject: [PATCH 3/5] Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/wan/pipeline_wan_i2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index c155a6a32994..5d28697484c5 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -113,7 +113,7 @@ def retrieve_latents( latents_mean: torch.Tensor, latents_std: torch.Tensor, generator: Optional[torch.Generator] = None, - sample_mode: str = "none", + sample_mode: str = "argmax", ): if hasattr(encoder_output, "latent_dist") and sample_mode == "none": return (encoder_output.latent_dist.mean - latents_mean) * latents_std From ac153827f57995d4f42f718450bc687e72a40d85 Mon Sep 17 00:00:00 2001 From: C Date: Mon, 17 Mar 2025 23:38:17 +0800 Subject: [PATCH 4/5] Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/wan/pipeline_wan_i2v.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index 5d28697484c5..bfc758c30c43 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -115,7 +115,7 @@ def retrieve_latents( generator: Optional[torch.Generator] = None, sample_mode: str = "argmax", ): - if hasattr(encoder_output, "latent_dist") and sample_mode == "none": + if hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": return (encoder_output.latent_dist.mean - latents_mean) * latents_std else: raise AttributeError("Could not access latents of provided encoder_output") From 378e412a993ad7c9566d474fa9c4c794cd5ed63a Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 17 Mar 2025 15:53:01 +0000 Subject: [PATCH 5/5] Update pipeline_wan_i2v.py --- .../pipelines/wan/pipeline_wan_i2v.py | 32 +++++-------------- 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index bfc758c30c43..e5699718ea71 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -108,34 +108,16 @@ def prompt_clean(text): return text +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( - encoder_output: torch.Tensor, - latents_mean: torch.Tensor, - latents_std: torch.Tensor, - generator: Optional[torch.Generator] = None, - sample_mode: str = "argmax", + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" ): - if hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - return (encoder_output.latent_dist.mean - latents_mean) * latents_std - else: - raise AttributeError("Could not access latents of provided encoder_output") - encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std - encoder_output.latent_dist.logvar = torch.clamp( - (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0 - ) - encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar) - encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar) + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": - encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std - encoder_output.latent_dist.logvar = torch.clamp( - (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0 - ) - encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar) - encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar) return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): - return (encoder_output.latents - latents_mean) * latents_std + return encoder_output.latents else: raise AttributeError("Could not access latents of provided encoder_output") @@ -415,13 +397,15 @@ def prepare_latents( if isinstance(generator, list): latent_condition = [ - retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator ] latent_condition = torch.cat(latent_condition) else: - latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator) + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + latent_condition = (latent_condition - latents_mean) * latents_std + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) mask_lat_size[:, :, list(range(1, num_frames))] = 0 first_frame_mask = mask_lat_size[:, :, 0:1]