From 65ba30a9f1c756ed11cb02527d9767975b5bc9a3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jakub=20Piotr=20C=C5=82apa?=
Date: Sat, 23 Mar 2024 15:26:56 +0000
Subject: [PATCH] t2s: fixed batched generation

---
 nbs/5B. Multi-lang text to semantic token modeling.ipynb | 4 ++--
 nbs/A. Neural modules.ipynb                              | 2 ++
 whisperspeech/modules.py                                 | 2 ++
 whisperspeech/t2s_up_wds_mlang_enclm.py                  | 4 ++--
 4 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/nbs/5B. Multi-lang text to semantic token modeling.ipynb b/nbs/5B. Multi-lang text to semantic token modeling.ipynb
index 388a4ca..67b134d 100644
--- a/nbs/5B. Multi-lang text to semantic token modeling.ipynb
+++ b/nbs/5B. Multi-lang text to semantic token modeling.ipynb
@@ -544,7 +544,7 @@
     "\n",
     "        cps_bin = (cpss / 20 * self.tunables.cps_bins).to(torch.long)\n",
     "        cps_bin[cps_bin >= self.tunables.cps_bins] = self.tunables.cps_bins-1\n",
-    "        return self.cps_embeddings(cps_bin).unsqueeze(1)\n",
+    "        return self.cps_embeddings(cps_bin)\n",
     "\n",
     "    def run_encoder(self, in_ttoks, languages, cpss):\n",
     "        if len(languages.shape) != 3: lang_embs = self.lang_embeddings(languages)\n",
@@ -717,7 +717,7 @@
     "        toks_positions = torch.arange(N+1, device=dev)\n",
     "        \n",
     "        with record_function(\"prefill\"):\n",
-    "            toks[:,start+1] = self.generate_one(toks[:,:start+1].contiguous(), toks_positions[:start+1], cps_emb, xenc, xenc_positions, T, top_k)\n",
+    "            toks[:,start+1] = self.generate_one(toks[:,:start+1].contiguous(), toks_positions[:start+1], cps_emb, xenc, xenc_positions, T, top_k)[:,0]\n",
     "        with inference.inference_context():\n",
     "            for i in it:\n",
     "                toks[:,i+1] = self.generate_next(toks[:,i:i+1], toks_positions[i:i+1], cps_emb, xenc, xenc_positions, T, top_k)[:,0]\n",
diff --git a/nbs/A. Neural modules.ipynb b/nbs/A. Neural modules.ipynb
index 5800c6a..c2ec6c2 100644
--- a/nbs/A. Neural modules.ipynb
+++ b/nbs/A. Neural modules.ipynb
@@ -166,6 +166,8 @@
     "        causal = False,\n",
     "        mask=None,\n",
     "    ):\n",
+    "        if self.k_cache is not None:\n",
+    "            assert qx.shape[0] <= self.k_cache.shape[0], \"please pass in a larger max_batch_size to setup_kv_cache\"\n",
     "        if self.qkv:\n",
     "            q,k,v = self.qkv(qx).split(self.odim, dim=-1)\n",
     "        elif self.kv:\n",
diff --git a/whisperspeech/modules.py b/whisperspeech/modules.py
index 698a976..b642519 100644
--- a/whisperspeech/modules.py
+++ b/whisperspeech/modules.py
@@ -117,6 +117,8 @@ def forward(
         causal = False,
         mask=None,
     ):
+        if self.k_cache is not None:
+            assert qx.shape[0] <= self.k_cache.shape[0], "please pass in a larger max_batch_size to setup_kv_cache"
         if self.qkv:
             q,k,v = self.qkv(qx).split(self.odim, dim=-1)
         elif self.kv:
diff --git a/whisperspeech/t2s_up_wds_mlang_enclm.py b/whisperspeech/t2s_up_wds_mlang_enclm.py
index 52b7bb1..93db9d0 100644
--- a/whisperspeech/t2s_up_wds_mlang_enclm.py
+++ b/whisperspeech/t2s_up_wds_mlang_enclm.py
@@ -284,7 +284,7 @@ def _embed_cps(self, cpss):
 
         cps_bin = (cpss / 20 * self.tunables.cps_bins).to(torch.long)
         cps_bin[cps_bin >= self.tunables.cps_bins] = self.tunables.cps_bins-1
-        return self.cps_embeddings(cps_bin).unsqueeze(1)
+        return self.cps_embeddings(cps_bin)
 
     def run_encoder(self, in_ttoks, languages, cpss):
         if len(languages.shape) != 3: lang_embs = self.lang_embeddings(languages)
@@ -457,7 +457,7 @@ def generate(self, txt, cps=15, lang="en", stoks_prompt=None, N=None, bs=1, T=0.
         toks_positions = torch.arange(N+1, device=dev)
 
         with record_function("prefill"):
-            toks[:,start+1] = self.generate_one(toks[:,:start+1].contiguous(), toks_positions[:start+1], cps_emb, xenc, xenc_positions, T, top_k)
+            toks[:,start+1] = self.generate_one(toks[:,:start+1].contiguous(), toks_positions[:start+1], cps_emb, xenc, xenc_positions, T, top_k)[:,0]
         with inference.inference_context():
             for i in it:
                 toks[:,i+1] = self.generate_next(toks[:,i:i+1], toks_positions[i:i+1], cps_emb, xenc, xenc_positions, T, top_k)[:,0]
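
Note (not part of the patch): below is a minimal standalone sketch of the shape
bug fixed above, in plain PyTorch with illustrative names only (this is not the
WhisperSpeech API). generate_one and generate_next return one token per
sequence as a (batch, 1) tensor, which has to be narrowed to (batch,) before
being written into a single column of the token buffer; the new assert guards
the preallocated KV cache in the same fail-early spirit.

import torch

batch, vocab, N = 4, 512, 8
toks = torch.zeros(batch, N + 1, dtype=torch.long)

def sample_next(logits):
    # One token per sequence, shape (batch, 1), mirroring what
    # generate_one/generate_next return.
    return logits.argmax(dim=-1, keepdim=True)

next_tok = sample_next(torch.randn(batch, vocab))  # shape (batch, 1)

# Assigning the (batch, 1) result straight into a column raises a shape
# mismatch for batch > 1, hence the [:,0] added in the prefill step:
toks[:, 1] = next_tok[:, 0]

# The added assert fails early, with a readable message, when the KV cache
# was preallocated for fewer sequences than the incoming batch (here the
# cache is sized for `batch` sequences, as if set up with a matching
# max_batch_size, so the check passes):
k_cache = torch.zeros(batch, N, 64)
qx = torch.randn(batch, N, 64)
assert qx.shape[0] <= k_cache.shape[0], \
    "please pass in a larger max_batch_size to setup_kv_cache"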