
Commit 841504b

goiri, hlky, github-actions[bot], and yiyixuxu authored
Add support to pass image embeddings to the WAN I2V pipeline. (#11175)
* Add support to pass image embeddings to the pipeline.

---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent fc7a867 commit 841504b

File tree

1 file changed: +21 -2 lines changed

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

@@ -321,9 +321,19 @@ def check_inputs(
         width,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
     ):
-        if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+        if image is not None and image_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        if image is None and image_embeds is None:
+            raise ValueError(
+                "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined."
+            )
+        if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
             raise ValueError("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" {type(image)}")
         if height % 16 != 0 or width % 16 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
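The new checks follow the usual diffusers pattern of making a raw input and its precomputed embedding mutually exclusive. A minimal standalone sketch of that pattern, with a hypothetical function name and placeholder values:

```python
def validate_conditioning(image=None, image_embeds=None):
    # Exactly one of the two conditioning inputs may be supplied.
    if image is not None and image_embeds is not None:
        raise ValueError("Cannot forward both `image` and `image_embeds`.")
    if image is None and image_embeds is None:
        raise ValueError("Provide either `image` or `image_embeds`.")


validate_conditioning(image="<pixels>")         # ok: raw image only
validate_conditioning(image_embeds="<tensor>")  # ok: cached embeddings only
try:
    validate_conditioning()                     # neither supplied
except ValueError as err:
    print(err)  # Provide either `image` or `image_embeds`.
```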
@@ -463,6 +473,7 @@ def __call__(
         latents: Optional[torch.Tensor] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
         negative_prompt_embeds: Optional[torch.Tensor] = None,
+        image_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "np",
         return_dict: bool = True,
         attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -512,6 +523,12 @@ def __call__(
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                 provided, text embeddings are generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting).
+                If not provided, negative text embeddings are generated from the `negative_prompt` input argument.
+            image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
+                image embeddings are generated from the `image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -556,6 +573,7 @@ def __call__(
             width,
             prompt_embeds,
             negative_prompt_embeds,
+            image_embeds,
             callback_on_step_end_tensor_inputs,
         )

@@ -599,7 +617,8 @@ def __call__(
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

-        image_embeds = self.encode_image(image, device)
+        if image_embeds is None:
+            image_embeds = self.encode_image(image, device)
         image_embeds = image_embeds.repeat(batch_size, 1, 1)
         image_embeds = image_embeds.to(transformer_dtype)

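Taken together, the change lets callers precompute the image embeddings once and reuse them across calls, skipping the image encoder on every invocation. A minimal usage sketch, assuming the `Wan-AI/Wan2.1-I2V-14B-480P-Diffusers` checkpoint and a placeholder image URL; per the new `check_inputs` validation, the cached `image_embeds` is passed in place of `image`:

```python
import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import load_image

# Checkpoint id and image URL are illustrative placeholders.
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("https://example.com/first_frame.png")

# encode_image is the same helper __call__ falls back to when
# `image_embeds` is None, so the cached tensor matches what the
# pipeline would otherwise compute internally.
image_embeds = pipe.encode_image(image, "cuda")

# Reuse the cached embeddings; `image` and `image_embeds` are
# mutually exclusive under the new validation.
video = pipe(
    prompt="the scene slowly comes to life",
    image_embeds=image_embeds,
    height=480,
    width=832,
).frames[0]
```

Caching is mainly useful when generating several videos from the same conditioning image with different prompts or seeds, since the image encoder then runs only once.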