From 7ed71d6fd89ca8bf2c4aefbb280e705b1d7ae6b8 Mon Sep 17 00:00:00 2001
From: Lordmau5
Date: Mon, 10 Apr 2023 00:32:25 +0200
Subject: [PATCH] fix(train): improve quality of training (#274)

* fix(train): initialize `_temp_epoch` variable

* chore(train): move `__init__` of LightningModule to the top

* fix(train): fix order of optimizer as per Lightning.AI documentation

* fix(train): remove `with torch.no_grad():` call for generator loss

During several tests with a fresh model, the `torch.no_grad()` call introduced (or perhaps kept) some metallic noise. Running without it improves the output significantly.

* fix(train): ensure `log_audio_dict` uses correct `total_batch_idx`

* fix(train): only save checkpoints for first `batch_idx`
---
 src/so_vits_svc_fork/train.py | 79 ++++++++++++++++++-----------------
 1 file changed, 40 insertions(+), 39 deletions(-)

diff --git a/src/so_vits_svc_fork/train.py b/src/so_vits_svc_fork/train.py
index 27ca5c8a..d30eebf7 100644
--- a/src/so_vits_svc_fork/train.py
+++ b/src/so_vits_svc_fork/train.py
@@ -82,6 +82,40 @@ def train(
 
 
 class VitsLightning(pl.LightningModule):
+    def __init__(self, reset_optimizer: bool = False, **hparams: Any):
+        super().__init__()
+        self._temp_epoch = 0  # Initialize the _temp_epoch attribute
+        self.save_hyperparameters("reset_optimizer")
+        self.save_hyperparameters(*[k for k in hparams.keys()])
+        torch.manual_seed(self.hparams.train.seed)
+        self.net_g = SynthesizerTrn(
+            self.hparams.data.filter_length // 2 + 1,
+            self.hparams.train.segment_size // self.hparams.data.hop_length,
+            **self.hparams.model,
+        )
+        self.net_d = MultiPeriodDiscriminator(self.hparams.model.use_spectral_norm)
+        self.automatic_optimization = False
+        self.optim_g = torch.optim.AdamW(
+            self.net_g.parameters(),
+            self.hparams.train.learning_rate,
+            betas=self.hparams.train.betas,
+            eps=self.hparams.train.eps,
+        )
+        self.optim_d = torch.optim.AdamW(
+            self.net_d.parameters(),
+            self.hparams.train.learning_rate,
+            betas=self.hparams.train.betas,
+            eps=self.hparams.train.eps,
+        )
+        self.scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
+            self.optim_g, gamma=self.hparams.train.lr_decay
+        )
+        self.scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
+            self.optim_d, gamma=self.hparams.train.lr_decay
+        )
+        self.optimizers_count = 2
+        self.load(reset_optimizer)
+
     def on_train_start(self) -> None:
         self.set_current_epoch(self._temp_epoch)
         total_batch_idx = self._temp_epoch * len(self.trainer.train_dataloader)
@@ -181,39 +215,6 @@ def load(self, reset_optimizer: bool = False):
         else:
             LOG.warning("No checkpoint found. Start from scratch.")
 
-    def __init__(self, reset_optimizer: bool = False, **hparams: Any):
-        super().__init__()
-        self.save_hyperparameters("reset_optimizer")
-        self.save_hyperparameters(*[k for k in hparams.keys()])
-        torch.manual_seed(self.hparams.train.seed)
-        self.net_g = SynthesizerTrn(
-            self.hparams.data.filter_length // 2 + 1,
-            self.hparams.train.segment_size // self.hparams.data.hop_length,
-            **self.hparams.model,
-        )
-        self.net_d = MultiPeriodDiscriminator(self.hparams.model.use_spectral_norm)
-        self.automatic_optimization = False
-        self.optim_g = torch.optim.AdamW(
-            self.net_g.parameters(),
-            self.hparams.train.learning_rate,
-            betas=self.hparams.train.betas,
-            eps=self.hparams.train.eps,
-        )
-        self.optim_d = torch.optim.AdamW(
-            self.net_d.parameters(),
-            self.hparams.train.learning_rate,
-            betas=self.hparams.train.betas,
-            eps=self.hparams.train.eps,
-        )
-        self.scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
-            self.optim_g, gamma=self.hparams.train.lr_decay
-        )
-        self.scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
-            self.optim_d, gamma=self.hparams.train.lr_decay
-        )
-        self.optimizers_count = 2
-        self.load(reset_optimizer)
-
     def configure_optimizers(self):
         return [self.optim_g, self.optim_d], [self.scheduler_g, self.scheduler_d]
 
@@ -239,7 +240,7 @@ def log_audio_dict(self, audio_dict: dict[str, Any]) -> None:
             writer.add_audio(
                 k,
                 v,
-                self.trainer.fit_loop.total_batch_idx,
+                self.total_batch_idx,
                 sample_rate=self.hparams.data.sampling_rate,
             )
 
@@ -291,8 +292,8 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
 
         # generator loss
         LOG.debug("Calculating generator loss")
-        with torch.no_grad():
-            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.net_d(y, y_hat)
+        y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = self.net_d(y, y_hat)
+
         with autocast(enabled=False):
             loss_mel = F.l1_loss(y_mel, y_hat_mel) * self.hparams.train.c_mel
             loss_kl = (
@@ -353,9 +354,9 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
         )
 
         # optimizer
+        optim_g.zero_grad()
         self.manual_backward(loss_gen_all)
         optim_g.step()
-        optim_g.zero_grad()
         self.untoggle_optimizer(optim_g)
 
         # Discriminator
@@ -377,9 +378,9 @@ def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int) -> None:
         )
 
         # optimizer
+        optim_d.zero_grad()
         self.manual_backward(loss_disc_all)
         optim_d.step()
-        optim_d.zero_grad()
         self.untoggle_optimizer(optim_d)
 
     def validation_step(self, batch, batch_idx):
@@ -399,7 +400,7 @@ def validation_step(self, batch, batch_idx):
                 "gt/mel": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy()),
             }
         )
-        if self.current_epoch == 0:
+        if self.current_epoch == 0 or batch_idx != 0:
             return
         utils.save_checkpoint(
             self.net_g,