Skip to content

Commit 79d4e81

Browse files
committed
fix processing error that happens if batch_size is not a multiple of how many prompts/negative prompts there are #12509
1 parent 7e77a38 commit 79d4e81

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

modules/processing.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -382,13 +382,18 @@ def get_token_merging_ratio(self, for_hr=False):
382382
def setup_prompts(self):
383383
if type(self.prompt) == list:
384384
self.all_prompts = self.prompt
385+
elif type(self.negative_prompt) == list:
386+
self.all_prompts = [self.prompt] * len(self.negative_prompt)
385387
else:
386388
self.all_prompts = self.batch_size * self.n_iter * [self.prompt]
387389

388390
if type(self.negative_prompt) == list:
389391
self.all_negative_prompts = self.negative_prompt
390392
else:
391-
self.all_negative_prompts = self.batch_size * self.n_iter * [self.negative_prompt]
393+
self.all_negative_prompts = [self.negative_prompt] * len(self.all_prompts)
394+
395+
if len(self.all_prompts) != len(self.all_negative_prompts):
396+
raise RuntimeError(f"Received a different number of prompts ({len(self.all_prompts)}) and negative prompts ({len(self.all_negative_prompts)})")
392397

393398
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
394399
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]

0 commit comments

Comments
 (0)