diff --git a/nemo/collections/tts/modules/fastpitch.py b/nemo/collections/tts/modules/fastpitch.py index e2da672cf9c7..5f2227a999db 100644 --- a/nemo/collections/tts/modules/fastpitch.py +++ b/nemo/collections/tts/modules/fastpitch.py @@ -317,7 +317,7 @@ def forward( # Predict energy if self.energy_predictor is not None: - energy_pred = self.energy_predictor(prosody_input, enc_mask).squeeze(-1) + energy_pred = self.energy_predictor(enc_out, enc_mask, conditioning=spk_emb).squeeze(-1) if energy is not None: # Average energy over characters @@ -402,7 +402,7 @@ def infer( assert energy.shape[-1] == text.shape[-1], f"energy.shape[-1]: {energy.shape[-1]} != len(text)" energy_emb = self.energy_emb(energy) else: - energy_pred = self.energy_predictor(prosody_input, enc_mask).squeeze(-1) + energy_pred = self.energy_predictor(enc_out, enc_mask, conditioning=spk_emb).squeeze(-1) energy_emb = self.energy_emb(energy_pred.unsqueeze(1)) enc_out = enc_out + energy_emb.transpose(1, 2)