From 451763620b5369f5fd250367a0fcd1f3e0a715e6 Mon Sep 17 00:00:00 2001
From: Werner Oswald
Date: Mon, 27 Feb 2023 15:12:46 +0100
Subject: [PATCH] fewer bugs; fix an issue with large prompts when going for
 speed

With large prompts we run into trouble because torch.cat gets tensors of
different sizes. Maybe someday I will find out how to fix that properly; for
now we compare the sizes and, if they differ, fall back to the slow but safe
variant.
---
 backend/deforum/six/model_wrap.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/backend/deforum/six/model_wrap.py b/backend/deforum/six/model_wrap.py
index e6337c1..4f993a6 100644
--- a/backend/deforum/six/model_wrap.py
+++ b/backend/deforum/six/model_wrap.py
@@ -155,14 +155,17 @@ def _cfg_model(x, sigma, cond, **kwargs):
         # No conditioning
         else:
             # calculate cond and uncond simultaneously
-            if self.cond_uncond_sync:
+            # this only works with prompts of the same size,
+            # so we check the sizes and, if they differ, go for the slower variant
+            c_size = cond.size()
+            uc_size = uncond.size()
+            if self.cond_uncond_sync and c_size == uc_size:
                 cond_in = torch.cat([uncond, cond])
                 x0 = _cfg_model(x, sigma, cond=cond_in)
             else:
                 uncond = self.inner_model(x, sigma, cond=uncond)
                 cond = self.inner_model(x, sigma, cond=cond)
                 x0 = uncond + (cond - uncond) * cond_scale
-
         return x0
 
     def make_cond_fn(self, loss_fn, scale):
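
Note (illustration only, not part of the patch and not applied by git): below is
a minimal, self-contained sketch of the size-check-and-fallback idea the commit
describes. The names fake_model and cfg_denoise are hypothetical stand-ins for
self.inner_model and the wrapper's forward path; shapes are illustrative only.

import torch

def fake_model(x, sigma, cond):
    # Dummy denoiser standing in for self.inner_model: shifts each latent by
    # a per-sample summary of its conditioning tensor.
    return x + sigma * cond.mean(dim=(1, 2)).view(-1, 1, 1, 1)

def cfg_denoise(x, sigma, uncond, cond, cond_scale):
    if uncond.size() == cond.size():
        # Fast path: cond and uncond have the same size, so one batched
        # forward pass over the concatenated tensors is safe.
        cond_in = torch.cat([uncond, cond])
        x_in = torch.cat([x, x])
        uncond_out, cond_out = fake_model(x_in, sigma, cond_in).chunk(2)
    else:
        # Slow but safe path: an oversized prompt produced a longer
        # conditioning tensor, so torch.cat would fail; run two calls instead.
        uncond_out = fake_model(x, sigma, uncond)
        cond_out = fake_model(x, sigma, cond)
    return uncond_out + (cond_out - uncond_out) * cond_scale

x = torch.randn(1, 4, 8, 8)
short = torch.randn(1, 77, 768)    # standard-length prompt embedding
long = torch.randn(1, 154, 768)    # oversized prompt -> larger embedding
print(cfg_denoise(x, 1.0, short, short, 7.5).shape)  # batched fast path
print(cfg_denoise(x, 1.0, short, long, 7.5).shape)   # safe fallback path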