From b87a665a40337507ce80974e89e51a398a49bc1d Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Thu, 16 Nov 2023 14:41:12 -0300 Subject: [PATCH] Ensures that only GPT model is in training mode during training --- TTS/tts/layers/xtts/trainer/gpt_trainer.py | 7 ++++--- requirements.txt | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index 005b30bede..4789e1f43f 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -318,9 +318,10 @@ def eval_step(self, batch, criterion): batch["cond_idxs"] = None return self.train_step(batch, criterion) - def on_epoch_start(self, trainer): # pylint: disable=W0613 - # guarante that dvae will be in eval mode after .train() on evaluation end - self.dvae = self.dvae.eval() + def on_train_epoch_start(self, trainer): + trainer.model.eval() # the whole model to eval + # put gpt model in training mode + trainer.model.xtts.gpt.train() def on_init_end(self, trainer): # pylint: disable=W0613 # ignore similarities.pth on clearml save/upload diff --git a/requirements.txt b/requirements.txt index 836de40ab6..e4a816f28c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,7 +27,7 @@ pandas>=1.4,<2.0 # deps for training matplotlib==3.7.* # coqui stack -trainer +trainer>=0.0.32 # config management coqpit>=0.0.16 # chinese g2p deps