From eef788981cbed7c68ffd58b4eb22a2df2e59ae0b Mon Sep 17 00:00:00 2001 From: Olivier Louvignes Date: Tue, 6 Sep 2022 12:41:08 +0200 Subject: [PATCH] feat(txt2img): allow from_file to work with len(lines) < batch_size (#349) --- scripts/orig_scripts/txt2img.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/scripts/orig_scripts/txt2img.py b/scripts/orig_scripts/txt2img.py index 9f01bca021a..6c43e73b93f 100644 --- a/scripts/orig_scripts/txt2img.py +++ b/scripts/orig_scripts/txt2img.py @@ -232,7 +232,12 @@ def forward(self, x, sigma, uncond, cond, cond_scale): print(f"reading prompts from {opt.from_file}") with open(opt.from_file, "r") as f: data = f.read().splitlines() - data = list(chunk(data, batch_size)) + if (len(data) >= batch_size): + data = list(chunk(data, batch_size)) + else: + while (len(data) < batch_size): + data.append(data[-1]) + data = [data] sample_path = os.path.join(outpath, "samples") os.makedirs(sample_path, exist_ok=True) @@ -264,7 +269,7 @@ def forward(self, x, sigma, uncond, cond, cond_scale): prompts = list(prompts) c = model.get_learned_conditioning(prompts) shape = [opt.C, opt.H // opt.f, opt.W // opt.f] - + if not opt.klms: samples_ddim, _ = sampler.sample(S=opt.ddim_steps, conditioning=c, @@ -284,7 +289,7 @@ def forward(self, x, sigma, uncond, cond, cond_scale): model_wrap_cfg = CFGDenoiser(model_wrap) extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale} samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args) - + x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)