diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
index fd1d90849a66..487ad2d80ac6 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -321,9 +321,19 @@ def check_inputs(
         width,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
     ):
-        if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+        if image is not None and image_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        if image is None and image_embeds is None:
+            raise ValueError(
+                "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined."
+            )
+        if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
             raise ValueError("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" {type(image)}")
         if height % 16 != 0 or width % 16 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -463,6 +473,7 @@ def __call__(
         latents: Optional[torch.Tensor] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
+        image_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "np",
         return_dict: bool = True,
         attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -512,6 +523,12 @@ def __call__(
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                 provided, text embeddings are generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+                provided, text embeddings are generated from the `negative_prompt` input argument.
+            image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
+                image embeddings are generated from the `image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -556,6 +573,7 @@ def __call__(
             width,
             prompt_embeds,
             negative_prompt_embeds,
+            image_embeds,
             callback_on_step_end_tensor_inputs,
         )
 
@@ -599,7 +617,8 @@ def __call__(
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
 
-        image_embeds = self.encode_image(image, device)
+        if image_embeds is None:
+            image_embeds = self.encode_image(image, device)
         image_embeds = image_embeds.repeat(batch_size, 1, 1)
         image_embeds = image_embeds.to(transformer_dtype)
 
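
For context, a minimal usage sketch of the new argument, assuming the obvious workflow: precompute the CLIP image embedding once via the pipeline's own `encode_image` (the same call this diff guards with `if image_embeds is None:`) and reuse it across generations. The checkpoint id and input path below are illustrative, and the sketch assumes the remainder of `__call__` (outside the hunks shown) tolerates `image=None` once `image_embeds` is supplied.

```python
import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import load_image

# Illustrative checkpoint id; substitute any Wan 2.1 I2V Diffusers checkpoint.
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("first_frame.png")  # hypothetical local input frame

# Precompute the CLIP image embedding once, using the same method the
# pipeline calls internally (`self.encode_image(image, device)` in the diff).
image_embeds = pipe.encode_image(image, "cuda")

# Reuse the cached embedding across prompts. The new `check_inputs` makes
# `image` and `image_embeds` mutually exclusive, so only the embedding is
# forwarded here (assumes the rest of `__call__` accepts `image=None`).
for prompt in ["a cat surfing a wave", "a cat snowboarding"]:
    output = pipe(
        image=None,
        prompt=prompt,
        image_embeds=image_embeds,
        num_frames=33,
    )
    frames = output.frames[0]
```

Since `check_inputs` now treats `image` and `image_embeds` as mutually exclusive, the cached embedding replaces the raw image at call time rather than accompanying it.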