diff --git a/nbs/4B. Multi-language semantic to acoustic token modeling.ipynb b/nbs/4B. Multi-language semantic to acoustic token modeling.ipynb
index 4437403..f90d409 100644
--- a/nbs/4B. Multi-language semantic to acoustic token modeling.ipynb
+++ b/nbs/4B. Multi-language semantic to acoustic token modeling.ipynb
@@ -561,7 +561,7 @@
     "        if self.positional_embeddings is not None: semb = semb + self.positional_embeddings\n",
     "        positions = torch.arange(0, semb.shape[1], device=semb.device)\n",
     "        xenc = self._encoder(semb, positions)\n",
-    "        if self.training:\n",
+    "        if self.training and self.tunables.causal_encoder:\n",
     "            enc_logits = (self.hidden_to_emb(xenc) @ self.semantic_embedding.weight.to(xenc.dtype).T).float()\n",
     "            enc_logits = enc_logits * self.tunables.output_mult / (self.width / self.base_width)\n",
     "        else:\n",
diff --git a/whisperspeech/s2a_delar_mup_wds_mlang.py b/whisperspeech/s2a_delar_mup_wds_mlang.py
index c15afc3..4a8e1ac 100644
--- a/whisperspeech/s2a_delar_mup_wds_mlang.py
+++ b/whisperspeech/s2a_delar_mup_wds_mlang.py
@@ -348,7 +348,7 @@ def run_encoder(self, Stoks, speakers):
         if self.positional_embeddings is not None: semb = semb + self.positional_embeddings
         positions = torch.arange(0, semb.shape[1], device=semb.device)
         xenc = self._encoder(semb, positions)
-        if self.training:
+        if self.training and self.tunables.causal_encoder:
             enc_logits = (self.hidden_to_emb(xenc) @ self.semantic_embedding.weight.to(xenc.dtype).T).float()
             enc_logits = enc_logits * self.tunables.output_mult / (self.width / self.base_width)
         else:
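
The change gates the computation of `enc_logits` on `self.tunables.causal_encoder`, so the extra projection through `hidden_to_emb` and the matmul against the semantic embedding table only run when the causal-encoder auxiliary loss is actually enabled, rather than on every training step. Below is a minimal, self-contained sketch of the pattern, assuming a trimmed-down `Tunables` (the real class in `s2a_delar_mup_wds_mlang.py` carries many more fields) and a toy encoder standing in for `_encoder`; names such as `ToyEncoder` and the `vocab`/`width` defaults are hypothetical.

```python
import torch
import torch.nn as nn
from dataclasses import dataclass

# Hypothetical, trimmed-down Tunables; the real dataclass has many
# more muP-related hyperparameters.
@dataclass
class Tunables:
    causal_encoder: bool = False  # assumed default; gates the auxiliary encoder loss
    output_mult: float = 1.0

class ToyEncoder(nn.Module):
    """Minimal sketch of the gating pattern introduced by this diff."""
    def __init__(self, vocab=1024, width=64, base_width=64, tunables=Tunables()):
        super().__init__()
        self.tunables = tunables
        self.width, self.base_width = width, base_width
        self.semantic_embedding = nn.Embedding(vocab, width)
        self.hidden_to_emb = nn.Linear(width, width)
        self._encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(width, nhead=4, batch_first=True), num_layers=2)

    def run_encoder(self, Stoks):
        semb = self.semantic_embedding(Stoks)
        xenc = self._encoder(semb)
        # Before the fix, enc_logits was computed on *every* training step,
        # even when no causal-encoder loss would consume it. After the fix,
        # the projection only runs when the tunable enables that loss.
        if self.training and self.tunables.causal_encoder:
            enc_logits = (self.hidden_to_emb(xenc) @ self.semantic_embedding.weight.T).float()
            enc_logits = enc_logits * self.tunables.output_mult / (self.width / self.base_width)
        else:
            enc_logits = None
        return xenc, enc_logits

# Usage: with causal_encoder=False, the matmul against the embedding
# table is skipped during training and enc_logits stays None.
model = ToyEncoder(tunables=Tunables(causal_encoder=False)).train()
xenc, enc_logits = model.run_encoder(torch.randint(0, 1024, (2, 10)))
assert enc_logits is None
```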