add learned value residual mixing + add adopt as default
lucidrains committed Nov 23, 2024
1 parent cedb3c4 commit 626d08f
Showing 4 changed files with 21 additions and 10 deletions.
4 changes: 3 additions & 1 deletion e2_tts_pytorch/e2_tts.py
@@ -560,6 +560,8 @@ def __init__(
)

for ind in range(depth):
is_first_block = ind == 0

is_later_half = ind >= (depth // 2)
has_text = ind < text_depth

@@ -596,7 +598,7 @@ def __init__(
text_conv = DepthwiseConv(dim_text, kernel_size = kernel_size)

text_attn_norm = RMSNorm(dim_text)
text_attn = Attention(dim = dim_text, heads = text_heads, dim_head = text_dim_head, dropout = dropout, **attn_kwargs)
text_attn = Attention(dim = dim_text, heads = text_heads, dim_head = text_dim_head, dropout = dropout, learned_value_residual_mix = not is_first_block, **attn_kwargs)

text_ff_norm = RMSNorm(dim_text)
text_ff = FeedForward(dim = dim_text, glu = True, mult = text_ff_mult, dropout = dropout, **ff_kwargs)
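For readers of the diff, a minimal sketch of the idea behind `learned_value_residual_mix` (the flag comes from x-transformers; the module name and the per-token gate below are illustrative assumptions, not the library's internals): the first block's attention values are kept around, and every later block learns how much of those first-block values to blend back into its own.

import torch
from torch import nn

class LearnedValueResidualMix(nn.Module):
    # illustrative sketch: blend this block's values with the first block's values
    def __init__(self, dim):
        super().__init__()
        # per-token mixing coefficient in [0, 1], predicted from the block input
        self.to_mix = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, x, values, first_block_values):
        mix = self.to_mix(x)  # (batch, seq, 1)
        return values * mix + first_block_values * (1. - mix)

This is consistent with the text attention passing `learned_value_residual_mix = not is_first_block`: block 0 has no earlier values to mix in and only supplies the residual that later blocks consume.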
19 changes: 15 additions & 4 deletions e2_tts_pytorch/trainer.py
@@ -8,6 +8,7 @@

import torch
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import LinearLR, SequentialLR
@@ -19,6 +20,8 @@
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

from adam_atan2_pytorch.adopt import Adopt

from ema_pytorch import EMA

from loguru import logger
@@ -133,9 +136,10 @@ class E2Trainer:
def __init__(
self,
model: E2TTS,
optimizer,
num_warmup_steps=20000,
grad_accumulation_steps=1,
optimizer: Optimizer | None = None,
learning_rate = 7.5e-5,
num_warmup_steps = 20000,
grad_accumulation_steps = 1,
duration_predictor: DurationPredictor | None = None,
checkpoint_path = None,
log_file = "logs.txt",
@@ -172,16 +176,23 @@ def __init__(
self.use_switch_ema = use_switch_ema

self.duration_predictor = duration_predictor

# optimizer

if not exists(optimizer):
optimizer = Adopt(model.parameters(), lr = learning_rate)

self.optimizer = optimizer

self.num_warmup_steps = num_warmup_steps
self.checkpoint_path = default(checkpoint_path, 'model.pth')
self.mel_spectrogram = MelSpec(sampling_rate=self.target_sample_rate)

self.ema_model, self.model, self.optimizer = self.accelerator.prepare(
self.ema_model, self.model, self.optimizer
)
self.max_grad_norm = max_grad_norm

self.checkpoint_path = default(checkpoint_path, 'model.pth')
self.writer = SummaryWriter(log_dir=tensorboard_log_dir)

@property
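A short usage sketch of the new constructor behavior (assuming an `e2tts` model built as in train_example.py; the keyword arguments are taken from the signature above): omit `optimizer` and the trainer falls back to Adopt at `learning_rate`, or pass any torch optimizer to keep the previous behavior.

from torch.optim import AdamW
from e2_tts_pytorch.trainer import E2Trainer

# default path: no optimizer given, the trainer builds Adopt(model.parameters(), lr = learning_rate)
trainer = E2Trainer(
    e2tts,
    learning_rate = 7.5e-5,
    num_warmup_steps = 20000,
)

# explicit path: any torch Optimizer still works
trainer = E2Trainer(
    e2tts,
    optimizer = AdamW(e2tts.parameters(), lr = 7.5e-5),
    num_warmup_steps = 20000,
)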
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "e2-tts-pytorch"
version = "1.5.2"
version = "1.6.0"
description = "E2-TTS in Pytorch"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -25,6 +25,7 @@ classifiers=[

dependencies = [
'accelerate>=0.33.0',
'adam-atan2-pytorch>=0.1.12',
'beartype',
'einops>=0.8.0',
'einx>=0.3.0',
@@ -39,7 +40,7 @@ dependencies = [
'torchaudio>=2.3.1',
'tqdm>=4.65.0',
'vocos',
'x-transformers>=1.42.4',
'x-transformers>=1.42.16',
]

[project.urls]
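The two dependency changes back the code above: adam-atan2-pytorch >= 0.1.12 provides the Adopt optimizer the trainer now defaults to, and the x-transformers bump to >= 1.42.16 presumably accompanies the `learned_value_residual_mix` keyword used in e2_tts.py. A quick sanity-check sketch after upgrading (the toy Linear module is just a stand-in for real model parameters):

import torch
from adam_atan2_pytorch.adopt import Adopt  # needs adam-atan2-pytorch >= 0.1.12

# mirror the trainer's default construction with throwaway parameters
optimizer = Adopt(torch.nn.Linear(8, 8).parameters(), lr = 7.5e-5)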
3 changes: 0 additions & 3 deletions train_example.py
@@ -1,7 +1,6 @@
import torch
from e2_tts_pytorch import E2TTS, DurationPredictor

from torch.optim import Adam
from datasets import load_dataset

from e2_tts_pytorch.trainer import (
@@ -26,8 +25,6 @@

train_dataset = HFDataset(load_dataset("MushanW/GLOBE")["train"])

optimizer = Adam(e2tts.parameters(), lr=7.5e-5)

trainer = E2Trainer(
e2tts,
optimizer,
