From 922f4141f0d431dc4164ac92b0034c70991c715b Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Fri, 12 Jul 2024 17:42:10 +0000
Subject: [PATCH 01/13] add a basic trainer and dataset

---
 e2_collate.py        | 24 +++++++++++++
 e2_dataset.py        | 49 +++++++++++++++++++++++++
 e2_trainer.py        | 86 ++++++++++++++++++++++++++++++++++++++++++++
 train_e2.py          | 47 ++++++++++++++++++++++++
 utils/compute_mel.py | 43 ++++++++++++++++++++++
 5 files changed, 249 insertions(+)
 create mode 100644 e2_collate.py
 create mode 100644 e2_dataset.py
 create mode 100644 e2_trainer.py
 create mode 100644 train_e2.py
 create mode 100644 utils/compute_mel.py

diff --git a/e2_collate.py b/e2_collate.py
new file mode 100644
index 0000000..169812a
--- /dev/null
+++ b/e2_collate.py
@@ -0,0 +1,24 @@
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+def collate_fn(batch):
+    mel_spec = [item['mel_spec'].squeeze(0) for item in batch]
+    mel_lengths = [item['mel_spec'].shape[-1] for item in batch]
+    text = [item['text'] for item in batch]
+    max_mel_length = max(mel_lengths)
+    padded_audio = []
+    for item in mel_spec:
+        padding = (0, max_mel_length - item.size(-1))
+        padded_item = torch.nn.functional.pad(item, padding, mode='constant', value=0)
+        padded_audio.append(padded_item)
+    audio = torch.stack(padded_audio)
+
+    text_lengths = torch.LongTensor([len(item) for item in text])
+    text = pad_sequence([torch.LongTensor(item) for item in text], batch_first=True)
+    batch_dict = {
+        'mel': mel_spec,
+        'mel_lengths': mel_lengths,
+        'text': text,
+        'text_lengths': text_lengths,
+    }
+    return batch_dict
\ No newline at end of file
diff --git a/e2_dataset.py b/e2_dataset.py
new file mode 100644
index 0000000..c581fcf
--- /dev/null
+++ b/e2_dataset.py
@@ -0,0 +1,49 @@
+import os
+import torch
+from torch.utils.data import Dataset
+import pandas as pd
+from pathlib import Path
+import torchaudio
+from utils.compute_mel import TorchMelSpectrogram
+from datasets import load_dataset
+from tokenizers import Tokenizer
+import logging
+logger = logging.getLogger(__name__)
+class E2EDataset(Dataset):
+    def __init__(self, hf_dataset, tokenizer_path):
+        self.data = load_dataset(hf_dataset, split='train')
+        self.tokenizer = Tokenizer.from_file(tokenizer_path)
+        self.target_sample_rate = 22050
+        self.hop_length = 256
+        self.mel_spectrogram = TorchMelSpectrogram(sampling_rate=self.target_sample_rate)
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        row = self.data[index]
+        audio = row['audio']['array']
+        logger.info(f"Audio shape: {audio.shape}")
+        sample_rate = row['audio']['sampling_rate']
+        duration = audio.shape[-1] / sample_rate
+
+        if duration > 20 or duration < 0.3:
+            logger.warning(f"Skipping due to duration out of bound: {duration}")
+            return self.__getitem__((index + 1) % len(self.data))
+
+        audio_tensor = torch.from_numpy(audio).float()
+
+        if sample_rate != self.target_sample_rate:
+            resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
+            audio_tensor = resampler(audio_tensor).unsqueeze(0)
+        mel_spec = self.mel_spectrogram(audio_tensor)
+        text = row['transcript']
+        text = text.replace(" ", "[SPACE]")
+        text_tokens = self.tokenize_text(text)
+
+        return {
+            'mel_spec': mel_spec,
+            'text': text_tokens,
+        }
+    def tokenize_text(self, text):
+        output = self.tokenizer.encode(text)
+        return output.ids
\ No newline at end of file
diff --git a/e2_trainer.py b/e2_trainer.py
new file mode 100644
index 0000000..6eb992d
--- /dev/null
+++ b/e2_trainer.py
@@ -0,0 +1,86 @@
+import torch
+import torchaudio
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from torch.nn import functional as F
+from accelerate import Accelerator
+from e2_collate import collate_fn
+import os
+import logging
+from utils.compute_mel import TorchMelSpectrogram
+
+class E2Trainer:
+    def __init__(self, model, optimizer, duration_predictor=None,
+                 checkpoint_path=None, log_file="logs.txt",
+                 max_grad_norm=1.0,
+                 sample_rate=22050):
+        self.target_sample_rate = sample_rate
+        self.accelerator = Accelerator(log_with="all")
+        self.model = model
+        self.duration_predictor = duration_predictor
+        self.optimizer = optimizer
+        self.checkpoint_path = checkpoint_path
+        self.mel_spectrogram = TorchMelSpectrogram(sampling_rate=self.target_sample_rate)
+        self.model, self.optimizer = self.accelerator.prepare(
+            self.model, self.optimizer
+        )
+        self.logger = logging.getLogger(__name__)
+        self.logger.setLevel(logging.INFO)
+        handler = logging.FileHandler(log_file)
+        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+        handler.setFormatter(formatter)
+        self.logger.addHandler(handler)
+        self.max_grad_norm = max_grad_norm
+
+    def save_checkpoint(self, step, finetune=False):
+        if self.checkpoint_path is None:
+            self.checkpoint_path = "model.pth"
+        checkpoint = {
+            'model_state_dict': self.accelerator.unwrap_model(self.model).state_dict(),
+            'optimizer_state_dict': self.optimizer.state_dict(),
+            'step': step
+        }
+        torch.save(checkpoint, self.checkpoint_path)
+
+    def load_checkpoint(self):
+        if self.checkpoint_path is not None and os.path.exists(self.checkpoint_path):
+            checkpoint = torch.load(self.checkpoint_path)
+            self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
+            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+            return checkpoint['step']
+        return 0
+
+    def train(self, train_dataset, epochs, batch_size, grad_accumulation_steps=1, num_workers=12, save_step=1000):
+        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=num_workers, pin_memory=True)
+        train_dataloader = self.accelerator.prepare(train_dataloader)
+        start_step = 0
+        global_step = start_step
+        for epoch in range(epochs):
+            self.model.train()
+            progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
+            for batch in progress_bar:
+                text_inputs = batch['text']
+                text_lengths = batch['text_lengths']
+                mel = batch['mel']
+                mel_lengths = batch["mel_lengths"]
+                # duration = batch['durations']
+                if self.duration_predictor is not None:
+                    dur_loss = self.duration_predictor(mel, target_duration = duration)
+                masked_mel, masked_mel_hat = self.model(mel, mel_lengths, text_inputs, text_lengths)
+                mel_loss = torch.nn.functional.mse_loss(masked_mel, masked_mel_hat)
+                self.accelerator.backward(mel_loss)
+
+                if self.max_grad_norm > 0:
+                    self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
+
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+                if self.accelerator.is_local_main_process:
+                    self.logger.info(f"Step {global_step+1}: E2E Mel Loss = {mel_loss.item():.4f}")
+                global_step += 1
+                progress_bar.set_postfix(mel_loss=mel_loss.item())
+                if global_step % save_step == 0:
+                    self.save_checkpoint(global_step)
+        mel_loss /= len(train_dataloader)
+        if self.accelerator.is_local_main_process:
+            self.logger.info(f"Epoch {epoch+1}/{epochs} - E2E Mel Loss = {mel_loss.item():.4f}")
\ No newline at end of file
diff --git a/train_e2.py b/train_e2.py
new file mode 100644
index 0000000..85d00c7
--- /dev/null
+++ b/train_e2.py
@@ -0,0 +1,47 @@
+import torch
+import torch.nn.init as init
+from torch.optim import Adam
+from e2_dataset import E2EDataset
+from e2_tts_pytorch.e2_tts import E2TTS, DurationPredictor
+from e2_trainer import E2Trainer
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+tokenizer_path = "vocab.json"
+
+train_dataset = E2EDataset("MushanW/GLOBE", tokenizer_path)
+
+duration_predictor = DurationPredictor(
+    transformer = dict(
+        dim = 512,
+        depth = 2,
+    )
+).to(device)
+
+e2tts = E2TTS(
+    duration_predictor = duration_predictor,
+    transformer = dict(
+        dim = 512,
+        depth = 4,
+        skip_connect_type = 'concat'
+    ),
+).to(device)
+
+
+optimizer = Adam(e2tts.parameters(), lr=1e-4)
+
+checkpoint_path = 'e2e.pt'
+log_file = 'e2e.txt'
+
+trainer = E2Trainer(
+    e2tts,
+    optimizer,
+    checkpoint_path=checkpoint_path,
+    log_file=log_file
+)
+
+epochs = 10
+batch_size = 8
+grad_accumulation_steps = 1
+
+trainer.train(train_dataset, epochs, batch_size, grad_accumulation_steps, save_step=1000)
\ No newline at end of file
diff --git a/utils/compute_mel.py b/utils/compute_mel.py
new file mode 100644
index 0000000..251ff08
--- /dev/null
+++ b/utils/compute_mel.py
@@ -0,0 +1,43 @@
+import torch
+from torch import nn
+import torchaudio
+class TorchMelSpectrogram(nn.Module):
+    def __init__(
+        self,
+        filter_length=1024,
+        hop_length=256,
+        win_length=1024,
+        n_mel_channels=80,
+        mel_fmin=0,
+        mel_fmax=8000,
+        sampling_rate=22050,
+        normalize=False,
+    ):
+        super().__init__()
+        self.filter_length = filter_length
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.n_mel_channels = n_mel_channels
+        self.mel_fmin = mel_fmin
+        self.mel_fmax = mel_fmax
+        self.sampling_rate = sampling_rate
+        self.mel_stft = torchaudio.transforms.MelSpectrogram(
+            n_fft=self.filter_length,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            power=2,
+            normalized=normalize,
+            sample_rate=self.sampling_rate,
+            f_min=self.mel_fmin,
+            f_max=self.mel_fmax,
+            n_mels=self.n_mel_channels,
+            norm="slaney",
+        )
+    def forward(self, inp):
+        if len(inp.shape) == 3:
+            inp = inp.squeeze(1)
+        assert len(inp.shape) == 2
+        self.mel_stft = self.mel_stft.to(inp.device)
+        mel = self.mel_stft(inp)
+        mel = torch.log(torch.clamp(mel, min=1e-5))
+        return mel

From 4bb65f17548d04b9d4e4bb66756cd8e618493141 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Fri, 12 Jul 2024 18:50:36 +0000
Subject: [PATCH 02/13] working trainer without any error

---
 e2_collate.py            | 25 ++++++++++++++-----------
 e2_trainer.py            | 21 ++++++++++++---------
 e2_tts_pytorch/e2_tts.py |  2 +-
 train_e2.py              |  4 ++--
 4 files changed, 29 insertions(+), 23 deletions(-)

diff --git a/e2_collate.py b/e2_collate.py
index 169812a..9cdddde 100644
--- a/e2_collate.py
+++ b/e2_collate.py
@@ -2,21 +2,24 @@
 from torch.nn.utils.rnn import pad_sequence
 
 def collate_fn(batch):
-    mel_spec = [item['mel_spec'].squeeze(0) for item in batch]
-    mel_lengths = [item['mel_spec'].shape[-1] for item in batch]
-    text = [item['text'] for item in batch]
-    max_mel_length = max(mel_lengths)
-    padded_audio = []
-    for item in mel_spec:
-        padding = (0, max_mel_length - item.size(-1))
-        padded_item = torch.nn.functional.pad(item, padding, mode='constant', value=0)
-        padded_audio.append(padded_item)
-    audio = torch.stack(padded_audio)
+    mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
+    mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
+
+    max_mel_length = mel_lengths.max().item()
+    padded_mel_specs = []
+    for spec in mel_specs:
+        padding = (0, max_mel_length - spec.size(-1))
+        padded_spec = torch.nn.functional.pad(spec, padding, mode='constant', value=0)
+        padded_mel_specs.append(padded_spec)
+
+    mel_specs = torch.stack(padded_mel_specs)
 
+    text = [item['text'] for item in batch]
     text_lengths = torch.LongTensor([len(item) for item in text])
     text = pad_sequence([torch.LongTensor(item) for item in text], batch_first=True)
+
     batch_dict = {
-        'mel': mel_spec,
+        'mel': mel_specs,
         'mel_lengths': mel_lengths,
         'text': text,
         'text_lengths': text_lengths,
diff --git a/e2_trainer.py b/e2_trainer.py
index 6eb992d..b0682dc 100644
--- a/e2_trainer.py
+++ b/e2_trainer.py
@@ -8,6 +8,7 @@
 import os
 import logging
 from utils.compute_mel import TorchMelSpectrogram
+from einops import rearrange
 
 class E2Trainer:
     def __init__(self, model, optimizer, duration_predictor=None,
@@ -61,14 +62,16 @@ def train(self, train_dataset, epochs, batch_size, grad_accumulation_steps=1, nu
             for batch in progress_bar:
                 text_inputs = batch['text']
                 text_lengths = batch['text_lengths']
-                mel = batch['mel']
+                mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
                 mel_lengths = batch["mel_lengths"]
+                print(mel_spec.shape)
+                print(text_inputs.shape)
                 # duration = batch['durations']
                 if self.duration_predictor is not None:
-                    dur_loss = self.duration_predictor(mel, target_duration = duration)
-                masked_mel, masked_mel_hat = self.model(mel, mel_lengths, text_inputs, text_lengths)
-                mel_loss = torch.nn.functional.mse_loss(masked_mel, masked_mel_hat)
-                self.accelerator.backward(mel_loss)
+                    dur_loss = self.duration_predictor(mel_spec, target_duration = duration)
+                loss = self.model(mel_spec, text_inputs, lens=mel_lengths)
+                # mel_loss = torch.nn.functional.mse_loss(masked_mel, masked_mel_hat)
+                self.accelerator.backward(loss)
 
                 if self.max_grad_norm > 0:
                     self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
@@ -76,11 +79,11 @@ def train(self, train_dataset, epochs, batch_size, grad_accumulation_steps=1, nu
                 self.optimizer.step()
                 self.optimizer.zero_grad()
                 if self.accelerator.is_local_main_process:
-                    self.logger.info(f"Step {global_step+1}: E2E Mel Loss = {mel_loss.item():.4f}")
+                    self.logger.info(f"Step {global_step+1}: E2E Loss = {loss.item():.4f}")
                 global_step += 1
-                progress_bar.set_postfix(mel_loss=mel_loss.item())
+                progress_bar.set_postfix(loss=loss.item())
                 if global_step % save_step == 0:
                     self.save_checkpoint(global_step)
-        mel_loss /= len(train_dataloader)
+        loss /= len(train_dataloader)
         if self.accelerator.is_local_main_process:
-            self.logger.info(f"Epoch {epoch+1}/{epochs} - E2E Mel Loss = {mel_loss.item():.4f}")
\ No newline at end of file
+            self.logger.info(f"Epoch {epoch+1}/{epochs} - E2E Loss = {loss.item():.4f}")
\ No newline at end of file
diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py
index f45d07b..37aa01e 100644
--- a/e2_tts_pytorch/e2_tts.py
+++ b/e2_tts_pytorch/e2_tts.py
@@ -307,6 +307,7 @@ def fn(t, x):
     def forward(
         self,
         x: Float['b n d'],
+        text: Int['b n'],
         times: Int['b'] | None = None,
         lens: Int['b'] | None = None,
         mask: Bool['b n'] | None = None,
@@ -329,7 +330,6 @@ def forward(
             times = torch.rand((batch,), dtype = dtype, device = self.device)
 
         # transformer and prediction head
-
         x = self.transformer(
             x,
             times = times,
diff --git a/train_e2.py b/train_e2.py
index 85d00c7..994d281 100644
--- a/train_e2.py
+++ b/train_e2.py
@@ -7,7 +7,7 @@
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-tokenizer_path = "vocab.json"
+tokenizer_path = "/home/azureuser/xtts/assets/vocab.json"
 
 train_dataset = E2EDataset("MushanW/GLOBE", tokenizer_path)
 
@@ -21,7 +21,7 @@
 e2tts = E2TTS(
     duration_predictor = duration_predictor,
     transformer = dict(
-        dim = 512,
+        dim = 80,
         depth = 4,
         skip_connect_type = 'concat'
     ),

From 62dff50d0a236e1143ab8b2352db5465db2dd49a Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Fri, 12 Jul 2024 18:51:19 +0000
Subject: [PATCH 03/13] revert path

---
 train_e2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train_e2.py b/train_e2.py
index 994d281..d5d42ed 100644
--- a/train_e2.py
+++ b/train_e2.py
@@ -7,7 +7,7 @@
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-tokenizer_path = "/home/azureuser/xtts/assets/vocab.json"
+tokenizer_path = "vocab.json"
 
 train_dataset = E2EDataset("MushanW/GLOBE", tokenizer_path)
 

From 4da3db6d459e5931f130d2c88e2dc67d1167fef7 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Fri, 12 Jul 2024 18:57:37 +0000
Subject: [PATCH 04/13] clean up and pip

---
 e2_dataset.py  | 3 ---
 e2_trainer.py  | 1 -
 pyproject.toml | 4 ++++
 train_e2.py    | 1 -
 4 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/e2_dataset.py b/e2_dataset.py
index c581fcf..04ff64a 100644
--- a/e2_dataset.py
+++ b/e2_dataset.py
@@ -1,8 +1,5 @@
-import os
 import torch
 from torch.utils.data import Dataset
-import pandas as pd
-from pathlib import Path
 import torchaudio
 from utils.compute_mel import TorchMelSpectrogram
 from datasets import load_dataset
diff --git a/e2_trainer.py b/e2_trainer.py
index b0682dc..e1120bd 100644
--- a/e2_trainer.py
+++ b/e2_trainer.py
@@ -1,5 +1,4 @@
 import torch
-import torchaudio
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from torch.nn import functional as F
diff --git a/pyproject.toml b/pyproject.toml
index 7f4920a..482b48f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,10 @@ dependencies = [
     'torch>=2.0',
     'torchdiffeq',
     'x-transformers>=1.31.12'
+    'accelerate>=0.32.1'
+    'datasets>=2.20.0'
+    'tqdm>=4.65.0',
+    'tokenizers>=0.19.1'
 ]
 
 [project.urls]
diff --git a/train_e2.py b/train_e2.py
index d5d42ed..0adfd0f 100644
--- a/train_e2.py
+++ b/train_e2.py
@@ -1,5 +1,4 @@
 import torch
-import torch.nn.init as init
 from torch.optim import Adam
 from e2_dataset import E2EDataset
 from e2_tts_pytorch.e2_tts import E2TTS, DurationPredictor

From 23849ea6f6bc5c8b05e1394a8baa0cc3f298b44a Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Fri, 12 Jul 2024 19:00:16 +0000
Subject: [PATCH 05/13] use einops

---
 e2_dataset.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/e2_dataset.py b/e2_dataset.py
index 04ff64a..534dc11 100644
--- a/e2_dataset.py
+++ b/e2_dataset.py
@@ -5,7 +5,10 @@ import torchaudio
 from utils.compute_mel import TorchMelSpectrogram
 from datasets import load_dataset
 from tokenizers import Tokenizer
 import logging
+from einops import rearrange, reduce
+
 logger = logging.getLogger(__name__)
+
 class E2EDataset(Dataset):
     def __init__(self, hf_dataset, tokenizer_path):
         self.data = load_dataset(hf_dataset, split='train')
@@ -13,9 +16,10 @@ def __init__(self, hf_dataset, tokenizer_path):
         self.target_sample_rate = 22050
         self.hop_length = 256
         self.mel_spectrogram = TorchMelSpectrogram(sampling_rate=self.target_sample_rate)
-
+
     def __len__(self):
         return len(self.data)
-
+
     def __getitem__(self, index):
         row = self.data[index]
         audio = row['audio']['array']
@@ -31,8 +35,14 @@ def __getitem__(self, index):
 
         if sample_rate != self.target_sample_rate:
             resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
-            audio_tensor = resampler(audio_tensor).unsqueeze(0)
+            audio_tensor = resampler(audio_tensor)
+
+        audio_tensor = rearrange(audio_tensor, 't -> 1 t')
+
         mel_spec = self.mel_spectrogram(audio_tensor)
+
+        mel_spec = rearrange(mel_spec, '1 d t -> d t')
+
         text = row['transcript']
         text = text.replace(" ", "[SPACE]")
         text_tokens = self.tokenize_text(text)
@@ -41,6 +51,7 @@ def __getitem__(self, index):
             'mel_spec': mel_spec,
             'text': text_tokens,
         }
+
     def tokenize_text(self, text):
         output = self.tokenizer.encode(text)
         return output.ids
\ No newline at end of file

From df7c47796cd0ebd5bd5e0784ca4d72d8ce7c79eb Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Fri, 12 Jul 2024 19:02:23 +0000
Subject: [PATCH 06/13] clean up

---
 e2_trainer.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/e2_trainer.py b/e2_trainer.py
index e1120bd..b0f2bd3 100644
--- a/e2_trainer.py
+++ b/e2_trainer.py
@@ -63,8 +63,6 @@ def train(self, train_dataset, epochs, batch_size, grad_accumulation_steps=1, nu
             text_lengths = batch['text_lengths']
             mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
             mel_lengths = batch["mel_lengths"]
-            print(mel_spec.shape)
-            print(text_inputs.shape)
             # duration = batch['durations']
             if self.duration_predictor is not None:
                 dur_loss = self.duration_predictor(mel_spec, target_duration = duration)

From 8d1f133546fa4f61a61a3f9dcac8e146e1230712 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Sun, 14 Jul 2024 02:47:21 +0000
Subject: [PATCH 07/13] fixes

---
 pyproject.toml | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index df207fc..7908ace 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,9 +29,9 @@ dependencies = [
     'jaxtyping',
     'torch>=2.0',
     'torchdiffeq',
-    'x-transformers>=1.31.12'
-    'accelerate>=0.32.1'
-    'datasets>=2.20.0'
+    'x-transformers>=1.31.12',
+    'accelerate>=0.32.1',
+    'datasets>=2.20.0',
     'tqdm>=4.65.0',
     'tokenizers>=0.19.1'
 ]

From 397fb0801bdd9975f03542282e1fbb56c8e402a7 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Sun, 14 Jul 2024 14:07:46 +0000
Subject: [PATCH 08/13] fix dependencies

---
 pyproject.toml | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 7908ace..5378d03 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,7 +33,8 @@ dependencies = [
     'accelerate>=0.32.1',
     'datasets>=2.20.0',
     'tqdm>=4.65.0',
-    'tokenizers>=0.19.1'
+    'tokenizers>=0.19.1',
+    'torchaudio>=2.3.1'
 ]
 
 [project.urls]

From bc7eea8f849cfe44a7f2455be6a90b4c5a0008b4 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Sun, 14 Jul 2024 14:17:15 +0000
Subject: [PATCH 09/13] replace bpe tokenizer with char tokenizer

---
 e2_tts_pytorch/dataset/e2_dataset.py | 29 ++++++++++++++++++++--------
 pyproject.toml                       |  1 -
 train_e2.py                          |  5 +----
 3 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/e2_tts_pytorch/dataset/e2_dataset.py b/e2_tts_pytorch/dataset/e2_dataset.py
index 93c1a90..50a07df 100644
--- a/e2_tts_pytorch/dataset/e2_dataset.py
+++ b/e2_tts_pytorch/dataset/e2_dataset.py
@@ -3,20 +3,24 @@
 import torchaudio
 from e2_tts_pytorch.utils.compute_mel import TorchMelSpectrogram
 from datasets import load_dataset
-from tokenizers import Tokenizer
 import logging
 from einops import rearrange, reduce
 
 logger = logging.getLogger(__name__)
 
 class E2EDataset(Dataset):
-    def __init__(self, hf_dataset, tokenizer_path):
+    def __init__(self, hf_dataset):
         self.data = load_dataset(hf_dataset, split='train')
-        self.tokenizer = Tokenizer.from_file(tokenizer_path)
         self.target_sample_rate = 22050
         self.hop_length = 256
         self.mel_spectrogram = TorchMelSpectrogram(sampling_rate=self.target_sample_rate)
-
+
+        self.char_set = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.,!?-:;'\"()[] ")
+        self.char_to_id = {char: i for i, char in enumerate(sorted(self.char_set))}
+        self.char_to_id[''] = len(self.char_to_id) # Unknown token
+        self.char_to_id[''] = len(self.char_to_id) # Start of sequence token
+        self.char_to_id[''] = len(self.char_to_id) # End of sequence token
+        self.id_to_char = {i: char for char, i in self.char_to_id.items()}
     def __len__(self):
         return len(self.data)
 
@@ -45,13 +49,22 @@ def __getitem__(self, index):
 
         text = row['transcript']
         text = text.replace(" ", "[SPACE]")
-        text_tokens = self.tokenize_text(text)
+        text_tokens = self.encode(text)
 
         return {
            'mel_spec': mel_spec,
            'text': text_tokens,
         }
 
-    def tokenize_text(self, text):
-        output = self.tokenizer.encode(text)
-        return output.ids
\ No newline at end of file
+    def encode(self, text):
+        tokens = [self.char_to_id['']]
+        for char in text:
+            if char in self.char_to_id:
+                tokens.append(self.char_to_id[char])
+            else:
+                tokens.append(self.char_to_id[''])
+        tokens.append(self.char_to_id[''])
+        return torch.tensor(tokens, dtype=torch.long)
+
+    def decode(self, token_ids):
+        return ''.join([self.id_to_char[id.item()] for id in token_ids])
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 5378d03..d21002a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,7 +33,6 @@ dependencies = [
     'accelerate>=0.32.1',
     'datasets>=2.20.0',
     'tqdm>=4.65.0',
-    'tokenizers>=0.19.1',
     'torchaudio>=2.3.1'
 ]
 
diff --git a/train_e2.py b/train_e2.py
index 8a8e71b..f83bba1 100644
--- a/train_e2.py
+++ b/train_e2.py
@@ -6,10 +6,7 @@
 from e2_tts_pytorch.trainer.e2_trainer import E2Trainer
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-tokenizer_path = "/home/azureuser/e2/e2-tts-pytorch/e2_tts_pytorch/assets/vocab.json"
-
-train_dataset = E2EDataset("MushanW/GLOBE", tokenizer_path)
+train_dataset = E2EDataset("MushanW/GLOBE")
 
 duration_predictor = DurationPredictor(
     transformer = dict(

From 34eec6a422adedff563c14cb7649561fbde864bf Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Sun, 14 Jul 2024 14:36:16 +0000
Subject: [PATCH 10/13] remove tokenization

---
 e2_tts_pytorch/assets/vocab.json     |  1 -
 e2_tts_pytorch/dataset/e2_collate.py |  2 --
 e2_tts_pytorch/dataset/e2_dataset.py | 19 ++-----------------
 3 files changed, 2 insertions(+), 20 deletions(-)
 delete mode 100644 e2_tts_pytorch/assets/vocab.json

diff --git a/e2_tts_pytorch/assets/vocab.json b/e2_tts_pytorch/assets/vocab.json
deleted file mode 100644
index a128f27..0000000
--- a/e2_tts_pytorch/assets/vocab.json
+++ /dev/null
@@ -1 +0,0 @@
-{"version":"1.0","truncation":null,"padding":null,"added_tokens":[{"id":0,"special":true,"content":"[STOP]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":1,"special":true,"content":"[UNK]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false},{"id":2,"special":true,"content":"[SPACE]","single_word":false,"lstrip":false,"rstrip":false,"normalized":false}],"normalizer":null,"pre_tokenizer":{"type":"Whitespace"},"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":"[UNK]","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"vocab":{"[STOP]":0,"[UNK]":1,"[SPACE]":2,"!":3,"'":4,"(":5,")":6,",":7,"-":8,".":9,"/":10,":":11,";":12,"?":13,"a":14,"b":15,"c":16,"d":17,"e":18,"f":19,"g":20,"h":21,"i":22,"j":23,"k":24,"l":25,"m":26,"n":27,"o":28,"p":29,"q":30,"r":31,"s":32,"t":33,"u":34,"v":35,"w":36,"x":37,"y":38,"z":39,"th":40,"in":41,"the":42,"an":43,"er":44,"ou":45,"re":46,"on":47,"at":48,"ed":49,"en":50,"to":51,"ing":52,"and":53,"is":54,"as":55,"al":56,"or":57,"of":58,"ar":59,"it":60,"es":61,"he":62,"st":63,"le":64,"om":65,"se":66,"be":67,"ad":68,"ow":69,"ly":70,"ch":71,"wh":72,"that":73,"you":74,"li":75,"ve":76,"ac":77,"ti":78,"ld":79,"me":80,"was":81,"gh":82,"id":83,"ll":84,"wi":85,"ent":86,"for":87,"ay":88,"ro":89,"ver":90,"ic":91,"her":92,"ke":93,"his":94,"no":95,"ut":96,"un":97,"ir":98,"lo":99,"we":100,"ri":101,"ha":102,"with":103,"ght":104,"out":105,"im":106,"ion":107,"all":108,"ab":109,"one":110,"ne":111,"ge":112,"ould":113,"ter":114,"mo":115,"had":116,"ce":117,"she":118,"go":119,"sh":120,"ur":121,"am":122,"so":123,"pe":124,"my":125,"de":126,"are":127,"but":128,"ome":129,"fr":130,"ther":131,"fe":132,"su":133,"do":134,"con":135,"te":136,"ain":137,"ere":138,"po":139,"if":140,"they":141,"us":142,"ag":143,"tr":144,"now":145,"oun":146,"this":147,"have":148,"not":149,"sa":150,"il":151,"up":152,"thing":153,"from":154,"ap":155,"him":156,"ack":157,"ation":158,"ant":159,"our":160,"op":161,"like":162,"ust":163,"ess":164,"bo":165,"ok":166,"ul":167,"ind":168,"ex":169,"com":170,"some":171,"there":172,"ers":173,"co":174,"res":175,"man":176,"ard":177,"pl":178,"wor":179,"way":180,"tion":181,"fo":182,"ca":183,"were":184,"by":185,"ate":186,"pro":187,"ted":188,"ound":189,"own":190,"would":191,"ts":192,"what":193,"qu":194,"ally":195,"ight":196,"ck":197,"gr":198,"when":199,"ven":200,"can":201,"ough":202,"ine":203,"end":204,"per":205,"ous":206,"od":207,"ide":208,"know":209,"ty":210,"very":211,"si":212,"ak":213,"who":214,"about":215,"ill":216,"them":217,"est":218,"red":219,"ye":220,"could":221,"ong":222,"your":223,"their":224,"em":225,"just":226,"other":227,"into":228,"any":229,"whi":230,"um":231,"tw":232,"ast":233,"der":234,"did":235,"ie":236,"been":237,"ace":238,"ink":239,"ity":240,"back":241,"ting":242,"br":243,"more":244,"ake":245,"pp":246,"then":247,"sp":248,"el":249,"use":250,"bl":251,"said":252,"over":253,"get":254},"merges":["t h","i n","th e","a n","e r","o u","r e","o n","a t","e d","e n","t o","in g","an d","i s","a s","a l","o r","o f","a r","i t","e s","h e","s t","l e","o m","s e","b e","a d","o w","l y","c h","w h","th at","y ou","l i","v e","a c","t i","l d","m e","w as","g h","i d","l l","w i","en t","f or","a y","r o","v er","i c","h er","k e","h is","n o","u t","u n","i r","l o","w e","r i","h a","wi th","gh t","ou t","i m","i on","al l","a b","on e","n e","g e","ou ld","t er","m o","h ad","c e","s he","g o","s h","u r","a m","s o","p e","m y","d e","a re","b ut","om e","f r","the r","f e","s 
u","d o","c on","t e","a in","er e","p o","i f","the y","u s","a g","t r","n ow","ou n","th is","ha ve","no t","s a","i l","u p","th ing","fr om","a p","h im","ac k","at ion","an t","ou r","o p","li ke","u st","es s","b o","o k","u l","in d","e x","c om","s ome","the re","er s","c o","re s","m an","ar d","p l","w or","w ay","ti on","f o","c a","w ere","b y","at e","p ro","t ed","oun d","ow n","w ould","t s","wh at","q u","al ly","i ght","c k","g r","wh en","v en","c an","ou gh","in e","en d","p er","ou s","o d","id e","k now","t y","ver y","s i","a k","wh o","ab out","i ll","the m","es t","re d","y e","c ould","on g","you r","the ir","e m","j ust","o ther","in to","an y","wh i","u m","t w","as t","d er","d id","i e","be en","ac e","in k","it y","b ack","t ing","b r","mo re","a ke","p p","the n","s p","e l","u se","b l","sa id","o ver","ge t"]}} \ No newline at end of file diff --git a/e2_tts_pytorch/dataset/e2_collate.py b/e2_tts_pytorch/dataset/e2_collate.py index 9cdddde..ef7c31d 100644 --- a/e2_tts_pytorch/dataset/e2_collate.py +++ b/e2_tts_pytorch/dataset/e2_collate.py @@ -16,8 +16,6 @@ def collate_fn(batch): text = [item['text'] for item in batch] text_lengths = torch.LongTensor([len(item) for item in text]) - text = pad_sequence([torch.LongTensor(item) for item in text], batch_first=True) - batch_dict = { 'mel': mel_specs, 'mel_lengths': mel_lengths, diff --git a/e2_tts_pytorch/dataset/e2_dataset.py b/e2_tts_pytorch/dataset/e2_dataset.py index 50a07df..170de49 100644 --- a/e2_tts_pytorch/dataset/e2_dataset.py +++ b/e2_tts_pytorch/dataset/e2_dataset.py @@ -48,23 +48,8 @@ def __getitem__(self, index): mel_spec = rearrange(mel_spec, '1 d t -> d t') text = row['transcript'] - text = text.replace(" ", "[SPACE]") - text_tokens = self.encode(text) return { 'mel_spec': mel_spec, - 'text': text_tokens, - } - - def encode(self, text): - tokens = [self.char_to_id['']] - for char in text: - if char in self.char_to_id: - tokens.append(self.char_to_id[char]) - else: - tokens.append(self.char_to_id['']) - tokens.append(self.char_to_id['']) - return torch.tensor(tokens, dtype=torch.long) - - def decode(self, token_ids): - return ''.join([self.id_to_char[id.item()] for id in token_ids]) \ No newline at end of file + 'text': text, + } \ No newline at end of file From 34eec6a422adedff563c14cb7649561fbde864bf Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 14 Jul 2024 14:36:36 +0000 Subject: [PATCH 11/13] clean up --- e2_tts_pytorch/dataset/e2_dataset.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/e2_tts_pytorch/dataset/e2_dataset.py b/e2_tts_pytorch/dataset/e2_dataset.py index 170de49..123839f 100644 --- a/e2_tts_pytorch/dataset/e2_dataset.py +++ b/e2_tts_pytorch/dataset/e2_dataset.py @@ -14,13 +14,7 @@ def __init__(self, hf_dataset): self.target_sample_rate = 22050 self.hop_length = 256 self.mel_spectrogram = TorchMelSpectrogram(sampling_rate=self.target_sample_rate) - - self.char_set = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789.,!?-:;'\"()[] ") - self.char_to_id = {char: i for i, char in enumerate(sorted(self.char_set))} - self.char_to_id[''] = len(self.char_to_id) # Unknown token - self.char_to_id[''] = len(self.char_to_id) # Start of sequence token - self.char_to_id[''] = len(self.char_to_id) # End of sequence token - self.id_to_char = {i: char for char, i in self.char_to_id.items()} + def __len__(self): return len(self.data) From ce40cc6bf168927b0719fcb51bad8b6e05328083 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 14 Jul 2024 14:46:24 
+0000 Subject: [PATCH 12/13] remove selfs --- e2_tts_pytorch/utils/compute_mel.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/e2_tts_pytorch/utils/compute_mel.py b/e2_tts_pytorch/utils/compute_mel.py index 251ff08..4074f53 100644 --- a/e2_tts_pytorch/utils/compute_mel.py +++ b/e2_tts_pytorch/utils/compute_mel.py @@ -14,23 +14,16 @@ def __init__( normalize=False, ): super().__init__() - self.filter_length = filter_length - self.hop_length = hop_length - self.win_length = win_length - self.n_mel_channels = n_mel_channels - self.mel_fmin = mel_fmin - self.mel_fmax = mel_fmax - self.sampling_rate = sampling_rate self.mel_stft = torchaudio.transforms.MelSpectrogram( - n_fft=self.filter_length, - hop_length=self.hop_length, - win_length=self.win_length, + n_fft=filter_length, + hop_length=hop_length, + win_length=win_length, power=2, normalized=normalize, - sample_rate=self.sampling_rate, - f_min=self.mel_fmin, - f_max=self.mel_fmax, - n_mels=self.n_mel_channels, + sample_rate=sampling_rate, + f_min=mel_fmin, + f_max=mel_fmax, + n_mels=n_mel_channels, norm="slaney", ) def forward(self, inp): From ca1d757c3fc881628bb134decad626755b17719c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 14 Jul 2024 14:47:30 +0000 Subject: [PATCH 13/13] remove init --- train_e2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/train_e2.py b/train_e2.py index f83bba1..05ee256 100644 --- a/train_e2.py +++ b/train_e2.py @@ -1,5 +1,4 @@ import torch -import torch.nn.init as init from torch.optim import Adam from e2_tts_pytorch.dataset.e2_dataset import E2EDataset from e2_tts_pytorch.e2_tts import E2TTS, DurationPredictor
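
Usage sketch for the end state of this series (after patch 13). It mirrors the final train_e2.py, with one addition: an explicit trainer.load_checkpoint() call, since E2Trainer.save_checkpoint() writes the checkpoint file every save_step steps but train() hardcodes start_step = 0 and never restores state on its own. Hyperparameters and paths are illustrative, not prescriptive.

import torch
from torch.optim import Adam
from e2_tts_pytorch.dataset.e2_dataset import E2EDataset
from e2_tts_pytorch.e2_tts import E2TTS, DurationPredictor
from e2_tts_pytorch.trainer.e2_trainer import E2Trainer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# After patch 10 the dataset yields mel spectrograms plus raw transcript strings.
train_dataset = E2EDataset("MushanW/GLOBE")

duration_predictor = DurationPredictor(
    transformer = dict(
        dim = 512,
        depth = 2,
    )
).to(device)

e2tts = E2TTS(
    duration_predictor = duration_predictor,
    transformer = dict(
        dim = 80,   # matches the 80 mel channels produced by compute_mel.py
        depth = 4,
        skip_connect_type = 'concat'
    ),
).to(device)

optimizer = Adam(e2tts.parameters(), lr=1e-4)

trainer = E2Trainer(e2tts, optimizer, checkpoint_path='e2e.pt', log_file='e2e.txt')

# Restores model/optimizer state when e2e.pt exists and returns the saved step
# (otherwise 0); note that train() still starts its own step counter at 0.
start_step = trainer.load_checkpoint()

trainer.train(train_dataset, epochs=10, batch_size=8, save_step=1000)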