From 626d08ff616c773c3be1ab7cd468ab2d93fbb22c Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sat, 23 Nov 2024 07:44:26 -0800
Subject: [PATCH] add learned value residual mixing + add adopt as default

---
 e2_tts_pytorch/e2_tts.py  |  4 +++-
 e2_tts_pytorch/trainer.py | 19 +++++++++++++++----
 pyproject.toml            |  5 +++--
 train_example.py          |  3 ---
 4 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py
index e716cc9..3fee3e2 100644
--- a/e2_tts_pytorch/e2_tts.py
+++ b/e2_tts_pytorch/e2_tts.py
@@ -560,6 +560,8 @@ def __init__(
         )
 
         for ind in range(depth):
+            is_first_block = ind == 0
+            is_later_half = ind >= (depth // 2)
 
             has_text = ind < text_depth
 
@@ -596,7 +598,7 @@ def __init__(
                 text_conv = DepthwiseConv(dim_text, kernel_size = kernel_size)
 
                 text_attn_norm = RMSNorm(dim_text)
-                text_attn = Attention(dim = dim_text, heads = text_heads, dim_head = text_dim_head, dropout = dropout, **attn_kwargs)
+                text_attn = Attention(dim = dim_text, heads = text_heads, dim_head = text_dim_head, dropout = dropout, learned_value_residual_mix = not is_first_block, **attn_kwargs)
 
                 text_ff_norm = RMSNorm(dim_text)
                 text_ff = FeedForward(dim = dim_text, glu = True, mult = text_ff_mult, dropout = dropout, **ff_kwargs)
diff --git a/e2_tts_pytorch/trainer.py b/e2_tts_pytorch/trainer.py
index bad3216..6cb8014 100644
--- a/e2_tts_pytorch/trainer.py
+++ b/e2_tts_pytorch/trainer.py
@@ -8,6 +8,7 @@
 import torch
 import torch.nn.functional as F
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader, Dataset
 from torch.utils.tensorboard import SummaryWriter
 from torch.optim.lr_scheduler import LinearLR, SequentialLR
 
@@ -19,6 +20,8 @@
 from accelerate import Accelerator
 from accelerate.utils import DistributedDataParallelKwargs
 
+from adam_atan2_pytorch.adopt import Adopt
+
 from ema_pytorch import EMA
 
 from loguru import logger
@@ -133,9 +136,10 @@ class E2Trainer:
     def __init__(
         self,
         model: E2TTS,
-        optimizer,
-        num_warmup_steps=20000,
-        grad_accumulation_steps=1,
+        optimizer: Optimizer | None = None,
+        learning_rate = 7.5e-5,
+        num_warmup_steps = 20000,
+        grad_accumulation_steps = 1,
         duration_predictor: DurationPredictor | None = None,
         checkpoint_path = None,
         log_file = "logs.txt",
@@ -172,9 +176,15 @@ def __init__(
         self.use_switch_ema = use_switch_ema
 
         self.duration_predictor = duration_predictor
+
+        # optimizer
+
+        if not exists(optimizer):
+            optimizer = Adopt(model.parameters(), lr = learning_rate)
+
         self.optimizer = optimizer
+
         self.num_warmup_steps = num_warmup_steps
-        self.checkpoint_path = default(checkpoint_path, 'model.pth')
 
         self.mel_spectrogram = MelSpec(sampling_rate=self.target_sample_rate)
         self.ema_model, self.model, self.optimizer = self.accelerator.prepare(
@@ -182,6 +192,7 @@ def __init__(
         )
         self.max_grad_norm = max_grad_norm
+        self.checkpoint_path = default(checkpoint_path, 'model.pth')
 
         self.writer = SummaryWriter(log_dir=tensorboard_log_dir)
 
     @property
diff --git a/pyproject.toml b/pyproject.toml
index 84d4c49..c877154 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "e2-tts-pytorch"
-version = "1.5.2"
+version = "1.6.0"
 description = "E2-TTS in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
 ]
@@ -25,6 +25,7 @@ classifiers=[
 dependencies = [
     'accelerate>=0.33.0',
+    'adam-atan2-pytorch>=0.1.12',
     'beartype',
     'einops>=0.8.0',
     'einx>=0.3.0',
@@ -39,7 +40,7 @@ dependencies = [
     'torchaudio>=2.3.1',
     'tqdm>=4.65.0',
     'vocos',
-    'x-transformers>=1.42.4',
+    'x-transformers>=1.42.16',
 ]
 
 [project.urls]
diff --git a/train_example.py b/train_example.py
index b88e53b..5a59070 100644
--- a/train_example.py
+++ b/train_example.py
@@ -1,7 +1,6 @@
 import torch
 
 from e2_tts_pytorch import E2TTS, DurationPredictor
-from torch.optim import Adam
 from datasets import load_dataset
 
 from e2_tts_pytorch.trainer import (
@@ -26,8 +25,6 @@
 
 train_dataset = HFDataset(load_dataset("MushanW/GLOBE")["train"])
 
-optimizer = Adam(e2tts.parameters(), lr=7.5e-5)
-
 trainer = E2Trainer(
     e2tts,
     optimizer,
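
Background note on the first change: `learned_value_residual_mix = not is_first_block` enables value residual learning in every text attention block after the first, so later blocks can blend their own value projections with the values produced by the first block through a learned gate. Below is a minimal conceptual sketch of that mixing in plain PyTorch; the module name, shapes, and per-token scalar gate are illustrative only, and this is not the x-transformers implementation the patch actually relies on (which differs in details such as how the gate is parameterized).

import torch
from torch import nn

class ValueResidualMix(nn.Module):
    # learned per-token gate that mixes a later block's values with the first block's values
    def __init__(self, dim):
        super().__init__()
        self.to_mix = nn.Linear(dim, 1)

    def forward(self, values, first_block_values):
        # values, first_block_values: (batch, seq, dim)
        mix = self.to_mix(values).sigmoid()              # gate in [0, 1], shape (batch, seq, 1)
        return values * mix + first_block_values * (1. - mix)

# toy usage
mixer = ValueResidualMix(dim = 64)
v_first = torch.randn(2, 100, 64)
v_later = torch.randn(2, 100, 64)
mixed = mixer(v_later, v_first)                          # same shape as v_later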
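
Usage note on the second change: `E2Trainer` no longer requires an optimizer; when none is passed it now constructs an Adopt optimizer at the default `learning_rate = 7.5e-5`, which is why the training example drops its explicit Adam. A short sketch of both ways of constructing the trainer after this patch; the `E2TTS` config values are illustrative and follow the README-style constructor.

import torch
from e2_tts_pytorch import E2TTS
from e2_tts_pytorch.trainer import E2Trainer

# illustrative model config
e2tts = E2TTS(
    transformer = dict(
        dim = 512,
        depth = 8
    )
)

# rely on the new default optimizer (Adopt, lr = 7.5e-5)
trainer = E2Trainer(e2tts)

# or keep passing an explicit torch optimizer, as the old example did
optimizer = torch.optim.Adam(e2tts.parameters(), lr = 7.5e-5)
trainer = E2Trainer(e2tts, optimizer = optimizer)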