Skip to content

Commit

Permalink
[TTS] Fix FastPitch energy code (#6511)
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan <rlangman@nvidia.com>
  • Loading branch information
rlangman authored Apr 28, 2023
1 parent bdfb950 commit 92bb5c0
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions nemo/collections/tts/modules/fastpitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def forward(

# Predict energy
if self.energy_predictor is not None:
energy_pred = self.energy_predictor(prosody_input, enc_mask).squeeze(-1)
energy_pred = self.energy_predictor(enc_out, enc_mask, conditioning=spk_emb).squeeze(-1)

if energy is not None:
# Average energy over characters
Expand Down Expand Up @@ -402,7 +402,7 @@ def infer(
assert energy.shape[-1] == text.shape[-1], f"energy.shape[-1]: {energy.shape[-1]} != len(text)"
energy_emb = self.energy_emb(energy)
else:
energy_pred = self.energy_predictor(prosody_input, enc_mask).squeeze(-1)
energy_pred = self.energy_predictor(enc_out, enc_mask, conditioning=spk_emb).squeeze(-1)
energy_emb = self.energy_emb(energy_pred.unsqueeze(1))
enc_out = enc_out + energy_emb.transpose(1, 2)

Expand Down

0 comments on commit 92bb5c0

Please sign in to comment.