diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index 6d1e90ca5f..9e1b1c4097 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -396,6 +396,7 @@ def _forward_encoder( - g: :math:`(B, C)` """ if hasattr(self, "emb_g"): + g = g.type(torch.LongTensor) g = self.emb_g(g) # [B, C, 1] if g is not None: g = g.unsqueeze(-1) @@ -683,9 +684,10 @@ def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # p # encoder pass o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g) # duration predictor pass - o_dr_log = self.duration_predictor(o_en, x_mask) + o_dr_log = self.duration_predictor(o_en.squeeze(), x_mask) o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) y_lengths = o_dr.sum(1) + # pitch predictor pass o_pitch = None if self.args.use_pitch: