Skip to content

Commit

Permalink
add: better warning messages when handling multiple conditionings. (h…
Browse files Browse the repository at this point in the history
…uggingface#2804)

* add: better warning messages when handling multiple conditioning.

* fix: handling of controlnet_conditioning_scale
  • Loading branch information
sayakpaul authored and w4ffl35 committed Apr 14, 2023
1 parent b47813f commit 6c48c75
Showing 1 changed file with 18 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -537,15 +537,27 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)

# Check `image`
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
if isinstance(self.controlnet, MultiControlNetModel):
if isinstance(prompt, list):
logger.warning(
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
" prompts. The conditionings will be fixed across the prompts."
)

# Check `image`
if isinstance(self.controlnet, ControlNetModel):
self.check_image(image, prompt, prompt_embeds)
elif isinstance(self.controlnet, MultiControlNetModel):
if not isinstance(image, list):
raise TypeError("For multiple controlnets: `image` must be type `list`")

if len(image) != len(self.controlnet.nets):
# When `image` is a nested list:
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
elif any(isinstance(i, list) for i in image):
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
elif len(image) != len(self.controlnet.nets):
raise ValueError(
"For multiple controlnets: `image` must have the same length as the number of controlnets."
)
Expand All @@ -556,12 +568,14 @@ def check_inputs(
assert False

# Check `controlnet_conditioning_scale`

if isinstance(self.controlnet, ControlNetModel):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
elif isinstance(self.controlnet, MultiControlNetModel):
if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
if isinstance(controlnet_conditioning_scale, list):
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
self.controlnet.nets
):
raise ValueError(
Expand Down

0 comments on commit 6c48c75

Please sign in to comment.