t2s: fixed batched generation
jpc committed Mar 23, 2024
1 parent 2909b06 commit 65ba30a
Showing 4 changed files with 8 additions and 4 deletions.
4 changes: 2 additions & 2 deletions nbs/5B. Multi-lang text to semantic token modeling.ipynb
@@ -544,7 +544,7 @@
 "\n",
 "        cps_bin = (cpss / 20 * self.tunables.cps_bins).to(torch.long)\n",
 "        cps_bin[cps_bin >= self.tunables.cps_bins] = self.tunables.cps_bins-1\n",
-"        return self.cps_embeddings(cps_bin).unsqueeze(1)\n",
+"        return self.cps_embeddings(cps_bin)\n",
 "\n",
 "    def run_encoder(self, in_ttoks, languages, cpss):\n",
 "        if len(languages.shape) != 3: lang_embs = self.lang_embeddings(languages)\n",
@@ -717,7 +717,7 @@
 "        toks_positions = torch.arange(N+1, device=dev)\n",
 "        \n",
 "        with record_function(\"prefill\"):\n",
-"            toks[:,start+1] = self.generate_one(toks[:,:start+1].contiguous(), toks_positions[:start+1], cps_emb, xenc, xenc_positions, T, top_k)\n",
+"            toks[:,start+1] = self.generate_one(toks[:,:start+1].contiguous(), toks_positions[:start+1], cps_emb, xenc, xenc_positions, T, top_k)[:,0]\n",
 "        with inference.inference_context():\n",
 "            for i in it:\n",
 "                toks[:,i+1] = self.generate_next(toks[:,i:i+1], toks_positions[i:i+1], cps_emb, xenc, xenc_positions, T, top_k)[:,0]\n",
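Note: the trailing [:,0] in the prefill hunk is the actual batching fix. generate_one appears to return one token per batch row as a (bs, 1) tensor, while the target slice toks[:,start+1] is only (bs,); the extra trailing dimension gets broadcast away when bs == 1 but makes the assignment fail for a real batch. A minimal shape sketch (fake_generate_one is a hypothetical stand-in for the real method):

import torch

def fake_generate_one(bs):
    # stand-in for self.generate_one(...): one sampled token per batch row
    return torch.randint(0, 512, (bs, 1))

bs, N = 4, 10
toks = torch.zeros(bs, N, dtype=torch.long)
toks[:, 1] = fake_generate_one(bs)[:, 0]  # shape (bs,) assigns cleanly
# toks[:, 1] = fake_generate_one(bs)      # RuntimeError once bs > 1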
2 changes: 2 additions & 0 deletions nbs/A. Neural modules.ipynb
@@ -166,6 +166,8 @@
 "        causal = False,\n",
 "        mask=None,\n",
 "    ):\n",
+"        if self.k_cache is not None:\n",
+"            assert qx.shape[0] <= self.k_cache.shape[0], \"please pass in a larger max_batch_size to setup_kv_cache\"\n",
 "        if self.qkv:\n",
 "            q,k,v = self.qkv(qx).split(self.odim, dim=-1)\n",
 "        elif self.kv:\n",
2 changes: 2 additions & 0 deletions whisperspeech/modules.py
@@ -117,6 +117,8 @@ def forward(
         causal = False,
         mask=None,
     ):
+        if self.k_cache is not None:
+            assert qx.shape[0] <= self.k_cache.shape[0], "please pass in a larger max_batch_size to setup_kv_cache"
         if self.qkv:
             q,k,v = self.qkv(qx).split(self.odim, dim=-1)
         elif self.kv:
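Note: the new guard (added identically to the notebook and the exported module) protects the batch dimension of the preallocated KV cache. Below is a minimal sketch of the invariant it enforces; the setup_kv_cache signature and cache layout are illustrative assumptions, not the real whisperspeech/modules.py implementation:

import torch

class TinyAttention:
    k_cache = None

    def setup_kv_cache(self, max_batch_size, max_seq_len, n_heads=8, head_dim=64):
        # the cache is preallocated with room for at most max_batch_size rows
        self.k_cache = torch.zeros(max_batch_size, n_heads, max_seq_len, head_dim)
        self.v_cache = torch.zeros_like(self.k_cache)

    def forward(self, qx):
        if self.k_cache is not None:
            # a larger batch would write past the cache, so fail early with
            # an actionable message instead of an obscure shape error later
            assert qx.shape[0] <= self.k_cache.shape[0], \
                "please pass in a larger max_batch_size to setup_kv_cache"
        # ... attention computation elided ...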
4 changes: 2 additions & 2 deletions whisperspeech/t2s_up_wds_mlang_enclm.py
@@ -284,7 +284,7 @@ def _embed_cps(self, cpss):
 
         cps_bin = (cpss / 20 * self.tunables.cps_bins).to(torch.long)
         cps_bin[cps_bin >= self.tunables.cps_bins] = self.tunables.cps_bins-1
-        return self.cps_embeddings(cps_bin).unsqueeze(1)
+        return self.cps_embeddings(cps_bin)
 
     def run_encoder(self, in_ttoks, languages, cpss):
         if len(languages.shape) != 3: lang_embs = self.lang_embeddings(languages)
@@ -457,7 +457,7 @@ def generate(self, txt, cps=15, lang="en", stoks_prompt=None, N=None, bs=1, T=0.
         toks_positions = torch.arange(N+1, device=dev)
 
         with record_function("prefill"):
-            toks[:,start+1] = self.generate_one(toks[:,:start+1].contiguous(), toks_positions[:start+1], cps_emb, xenc, xenc_positions, T, top_k)
+            toks[:,start+1] = self.generate_one(toks[:,:start+1].contiguous(), toks_positions[:start+1], cps_emb, xenc, xenc_positions, T, top_k)[:,0]
         with inference.inference_context():
             for i in it:
                 toks[:,i+1] = self.generate_next(toks[:,i:i+1], toks_positions[i:i+1], cps_emb, xenc, xenc_positions, T, top_k)[:,0]
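Note: for a batch of characters-per-second values, cps_embeddings(cps_bin) already yields one embedding row per sequence, so the removed .unsqueeze(1) was inserting a dimension the batched path apparently no longer expects. A rough shape sketch (the binning mirrors the diff; cps_bins and the embedding width are illustrative values, not the real tunables):

import torch
import torch.nn as nn

cps_bins, width = 32, 768                  # illustrative sizes
cps_embeddings = nn.Embedding(cps_bins, width)

cpss = torch.tensor([15.0, 12.0])          # one cps value per batch row
cps_bin = (cpss / 20 * cps_bins).to(torch.long)
cps_bin[cps_bin >= cps_bins] = cps_bins - 1

emb = cps_embeddings(cps_bin)              # shape (bs, width)
# the old .unsqueeze(1) forced this to (bs, 1, width)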
