Commit

add a paper that claims a free lunch, behind a flag
lucidrains committed Oct 16, 2024
1 parent 195ef0d commit cdf121c
Showing 2 changed files with 19 additions and 2 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -135,3 +135,14 @@ sampled = e2tts.sample(mel[:, :5], text = text)
    url = {https://api.semanticscholar.org/CorpusID:270878436}
}
```

```bibtex
@article{Li2024SwitchEA,
    title   = {Switch EMA: A Free Lunch for Better Flatness and Sharpness},
    author  = {Siyuan Li and Zicheng Liu and Juanxi Tian and Ge Wang and Zedong Wang and Weiyang Jin and Di Wu and Cheng Tan and Tao Lin and Yang Liu and Baigui Sun and Stan Z. Li},
    journal = {ArXiv},
    year    = {2024},
    volume  = {abs/2402.09240},
    url     = {https://api.semanticscholar.org/CorpusID:267657558}
}
```
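
For reference, the switch EMA idea is a one-line extension of ordinary EMA: at a fixed interval (every epoch, in the paper), the online weights are overwritten with the EMA weights, so optimization continues from the flatter averaged solution. Below is a minimal sketch of the mechanism in plain PyTorch — the function names and the switching interval are illustrative, not this repo's API; the trainer in this commit delegates the actual copy to its `EMA` wrapper's `update_model_with_ema`.

```python
import torch
from copy import deepcopy

@torch.no_grad()
def ema_update(ema_model, model, decay = 0.999):
    # ordinary EMA: ema <- decay * ema + (1 - decay) * online
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.lerp_(p, 1. - decay)

@torch.no_grad()
def switch_ema(model, ema_model):
    # the "switch": overwrite the online weights with the averaged ones
    for p, ema_p in zip(model.parameters(), ema_model.parameters()):
        p.copy_(ema_p)

# illustrative training loop with per-epoch switching, as proposed in the paper
# ema_model = deepcopy(model)
# for epoch in range(epochs):
#     for batch in loader:
#         ...train step...
#         ema_update(ema_model, model)
#     switch_ema(model, ema_model)
```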
10 changes: 8 additions & 2 deletions e2_tts_pytorch/trainer.py
@@ -143,7 +143,8 @@ def __init__(
        sample_rate = 22050,
        tensorboard_log_dir = 'runs/e2_tts_experiment',
        accelerate_kwargs: dict = dict(),
-       ema_kwargs: dict = dict()
+       ema_kwargs: dict = dict(),
+       use_switch_ema = False
    ):
        logger.add(log_file)

@@ -168,6 +169,8 @@ def __init__(
            **ema_kwargs
        )

+       self.use_switch_ema = use_switch_ema
+
        self.duration_predictor = duration_predictor
        self.optimizer = optimizer
        self.num_warmup_steps = num_warmup_steps
@@ -283,5 +286,8 @@ def train(self, train_dataset, epochs, batch_size, num_workers=12, save_step=100
            if self.accelerator.is_local_main_process:
                logger.info(f"epoch {epoch+1}/{epochs} - average loss = {epoch_loss:.4f}")
                self.writer.add_scalar('epoch average loss', epoch_loss, epoch)

+       if self.use_switch_ema:
+           self.ema_model.update_model_with_ema()

        self.writer.close()
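
Usage with the new flag might look like the following — a sketch, assuming the trainer class is named `E2Trainer` and that the model and optimizer are constructed as elsewhere in the README; only `use_switch_ema` and the `train` signature come from this diff. Note that as added here the switch happens once, at the end of `train` after the final epoch, rather than every epoch as in the paper.

```python
from e2_tts_pytorch.trainer import E2Trainer

# hypothetical setup: e2tts and optimizer built as in the rest of the README
trainer = E2Trainer(
    e2tts,
    optimizer,
    use_switch_ema = True   # on finishing training, copy EMA weights into the model
)

trainer.train(train_dataset, epochs = 10, batch_size = 16)
```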
