@@ -321,9 +321,19 @@ def check_inputs(
321
321
width ,
322
322
prompt_embeds = None ,
323
323
negative_prompt_embeds = None ,
324
+ image_embeds = None ,
324
325
callback_on_step_end_tensor_inputs = None ,
325
326
):
326
- if not isinstance (image , torch .Tensor ) and not isinstance (image , PIL .Image .Image ):
327
+ if image is not None and image_embeds is not None :
328
+ raise ValueError (
329
+ f"Cannot forward both `image`: { image } and `image_embeds`: { image_embeds } . Please make sure to"
330
+ " only forward one of the two."
331
+ )
332
+ if image is None and image_embeds is None :
333
+ raise ValueError (
334
+ "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
335
+ )
336
+ if image is not None and not isinstance (image , torch .Tensor ) and not isinstance (image , PIL .Image .Image ):
327
337
raise ValueError ("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" { type (image )} " )
328
338
if height % 16 != 0 or width % 16 != 0 :
329
339
raise ValueError (f"`height` and `width` have to be divisible by 16 but are { height } and { width } ." )
@@ -463,6 +473,7 @@ def __call__(
463
473
latents : Optional [torch .Tensor ] = None ,
464
474
prompt_embeds : Optional [torch .Tensor ] = None ,
465
475
negative_prompt_embeds : Optional [torch .Tensor ] = None ,
476
+ image_embeds : Optional [torch .Tensor ] = None ,
466
477
output_type : Optional [str ] = "np" ,
467
478
return_dict : bool = True ,
468
479
attention_kwargs : Optional [Dict [str , Any ]] = None ,
@@ -512,6 +523,12 @@ def __call__(
512
523
prompt_embeds (`torch.Tensor`, *optional*):
513
524
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
514
525
provided, text embeddings are generated from the `prompt` input argument.
526
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
527
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
528
+ provided, text embeddings are generated from the `negative_prompt` input argument.
529
+ image_embeds (`torch.Tensor`, *optional*):
530
+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
531
+ image embeddings are generated from the `image` input argument.
515
532
output_type (`str`, *optional*, defaults to `"np"`):
516
533
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
517
534
return_dict (`bool`, *optional*, defaults to `True`):
@@ -556,6 +573,7 @@ def __call__(
556
573
width ,
557
574
prompt_embeds ,
558
575
negative_prompt_embeds ,
576
+ image_embeds ,
559
577
callback_on_step_end_tensor_inputs ,
560
578
)
561
579
@@ -599,7 +617,8 @@ def __call__(
599
617
if negative_prompt_embeds is not None :
600
618
negative_prompt_embeds = negative_prompt_embeds .to (transformer_dtype )
601
619
602
- image_embeds = self .encode_image (image , device )
620
+ if image_embeds is None :
621
+ image_embeds = self .encode_image (image , device )
603
622
image_embeds = image_embeds .repeat (batch_size , 1 , 1 )
604
623
image_embeds = image_embeds .to (transformer_dtype )
605
624
0 commit comments