From 59d68e6e1c49fe2df1e1b483dc0fe6621467435a Mon Sep 17 00:00:00 2001 From: hako-mikan <122196982+hako-mikan@users.noreply.github.com> Date: Thu, 8 Feb 2024 01:24:12 +0900 Subject: [PATCH] fix for negpip --- scripts/attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/attention.py b/scripts/attention.py index da89a46..8d92af8 100644 --- a/scripts/attention.py +++ b/scripts/attention.py @@ -20,6 +20,7 @@ def main_forward(module,x,context,mask,divide,isvanilla = False,userpp = False,t if negpip: conds, contokens = negpip context = torch.cat((context,conds),1) + context = context.to(x.dtype) h = module.heads if isvanilla: # SBM Ddim / plms have the context split ahead along with x. @@ -596,4 +597,4 @@ def negpipdealer(i,pn): else: return None else: - return None \ No newline at end of file + return None