Fixed S2A training without causal_encoder
jpc committed Mar 2, 2024
1 parent fb8129d commit 478733d
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion (source notebook for whisperspeech/s2a_delar_mup_wds_mlang.py; filename not shown)
@@ -561,7 +561,7 @@
 "        if self.positional_embeddings is not None: semb = semb + self.positional_embeddings\n",
 "        positions = torch.arange(0, semb.shape[1], device=semb.device)\n",
 "        xenc = self._encoder(semb, positions)\n",
-"        if self.training:\n",
+"        if self.training and self.tunables.causal_encoder:\n",
 "            enc_logits = (self.hidden_to_emb(xenc) @ self.semantic_embedding.weight.to(xenc.dtype).T).float()\n",
 "            enc_logits = enc_logits * self.tunables.output_mult / (self.width / self.base_width)\n",
 "        else:\n",
2 changes: 1 addition & 1 deletion whisperspeech/s2a_delar_mup_wds_mlang.py
@@ -348,7 +348,7 @@ def run_encoder(self, Stoks, speakers):
         if self.positional_embeddings is not None: semb = semb + self.positional_embeddings
         positions = torch.arange(0, semb.shape[1], device=semb.device)
         xenc = self._encoder(semb, positions)
-        if self.training:
+        if self.training and self.tunables.causal_encoder:
             enc_logits = (self.hidden_to_emb(xenc) @ self.semantic_embedding.weight.to(xenc.dtype).T).float()
             enc_logits = enc_logits * self.tunables.output_mult / (self.width / self.base_width)
         else:
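
The fix is a one-line guard change, applied in both the notebook source and the exported module: the auxiliary encoder logits are now computed only when the model is training and the causal_encoder tunable is enabled. Below is a minimal, hypothetical Python sketch, not the real model class; the assumption that the hidden_to_emb projection only exists when causal_encoder is enabled is inferred from the commit message rather than taken from the source. It illustrates why the extra condition matters when training with the tunable switched off.

import torch
import torch.nn as nn

class EncoderWithOptionalAuxHead(nn.Module):
    """Hypothetical, stripped-down stand-in for the S2A encoder path."""
    def __init__(self, width=64, vocab=1024, causal_encoder=False):
        super().__init__()
        self.causal_encoder = causal_encoder        # stand-in for self.tunables.causal_encoder
        self.semantic_embedding = nn.Embedding(vocab, width)
        self._encoder = nn.Identity()               # placeholder for the real transformer encoder
        if causal_encoder:
            # assumed: the auxiliary projection is only created when the causal encoder is enabled
            self.hidden_to_emb = nn.Linear(width, width)

    def run_encoder(self, stoks):
        semb = self.semantic_embedding(stoks)
        xenc = self._encoder(semb)
        # The fix: guard on the tunable as well, so training with a non-causal
        # encoder never touches the (nonexistent) hidden_to_emb layer.
        if self.training and self.causal_encoder:
            enc_logits = self.hidden_to_emb(xenc) @ self.semantic_embedding.weight.T
        else:
            enc_logits = None
        return xenc, enc_logits

m = EncoderWithOptionalAuxHead(causal_encoder=False).train()
xenc, enc_logits = m.run_encoder(torch.randint(0, 1024, (2, 10)))
assert enc_logits is None  # with the old `if self.training:` guard this path would fail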
