fix the mel/prediction in the trainer
lucidrains committed Oct 6, 2024
1 parent f37c210 commit 7969f7a
Showing 3 changed files with 11 additions and 7 deletions.
4 changes: 2 additions & 2 deletions e2_tts_pytorch/e2_tts.py
@@ -61,7 +61,7 @@ def __getitem__(self, shapes: str):
Int = TorchTyping(jaxtyping.Int)
Bool = TorchTyping(jaxtyping.Bool)

- E2TTSReturn = namedtuple('E2TTS', ['loss', 'cond', 'pred'])
+ E2TTSReturn = namedtuple('E2TTS', ['loss', 'cond', 'pred_flow', 'pred_data'])

# helpers

@@ -1168,4 +1168,4 @@ def forward(

loss = loss[rand_span_mask].mean()

- return E2TTSReturn(loss, cond, pred)
+ return E2TTSReturn(loss, cond, pred, x0 + pred)
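
Note: pred is the network's flow (velocity) estimate, so x0 + pred reconstructs an estimate of the clean mel. A minimal sketch of that step, assuming the straight-line conditional flow matching path x_t = (1 - t) * x0 + t * x1 (the helper name below is illustrative, not part of the library):

import torch

def predicted_mel_from_flow(x0: torch.Tensor, pred_flow: torch.Tensor) -> torch.Tensor:
    # along the straight path the target velocity is x1 - x0,
    # so adding the noise sample back to the predicted flow estimates the data x1
    return x0 + pred_flow
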
12 changes: 8 additions & 4 deletions e2_tts_pytorch/trainer.py
@@ -35,8 +35,12 @@ def exists(v):
def default(v, d):
return v if exists(v) else d

+ def to_numpy(t):
+     return t.detach().cpu().numpy()

# plot spectrogram
def plot_spectrogram(spectrogram):
+     spectrogram = to_numpy(spectrogram)
    fig, ax = plt.subplots(figsize=(10, 4))
    im = ax.imshow(spectrogram.T, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
@@ -238,7 +242,7 @@ def train(self, train_dataset, epochs, batch_size, num_workers=12, save_step=100
dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
self.writer.add_scalar('duration loss', dur_loss.item(), global_step)

- loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths)
+ loss, cond, pred, pred_data = self.model(mel_spec, text=text_inputs, lens=mel_lengths)
self.accelerator.backward(loss)

if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
@@ -262,9 +266,9 @@

if global_step % save_step == 0:
self.save_checkpoint(global_step)
self.writer.add_figure("mel/target", plot_spectrogram(mel_spec[0,:,:].detach().cpu().numpy()), global_step)
self.writer.add_figure("mel/mask", plot_spectrogram(cond[0,:,:].detach().cpu().numpy()), global_step)
self.writer.add_figure("mel/prediction", plot_spectrogram(pred[0,:,:].detach().cpu().numpy()), global_step)
self.writer.add_figure("mel/target", plot_spectrogram(mel_spec[0,:,:]), global_step)
self.writer.add_figure("mel/mask", plot_spectrogram(cond[0,:,:]), global_step)
self.writer.add_figure("mel/prediction", plot_spectrogram(pred_data[0,:,:]), global_step)

epoch_loss /= len(train_dataloader)
if self.accelerator.is_local_main_process:
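
With both predictions available, the trainer now logs the reconstructed mel (pred_data) for the mel/prediction figure instead of the raw flow field, and plot_spectrogram converts tensors to numpy itself. A hedged usage sketch outside the training loop (e2tts, mel, text and lens are placeholder names for a model instance and a prepared batch, not defined by this commit; plot_spectrogram is assumed to return the matplotlib figure, as its use with add_figure suggests):

# forward pass mirrors the trainer's call signature shown above
loss, cond, pred_flow, pred_data = e2tts(mel, text = text, lens = lens)

# pred_data is the estimated clean mel, so it can be plotted or logged directly
fig = plot_spectrogram(pred_data[0])
fig.savefig('predicted_mel.png')
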
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "1.0.5"
version = "1.0.6"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
