add core code of valle #4

Merged
merged 11 commits on Dec 2, 2023
25 changes: 12 additions & 13 deletions bins/tts/inference.py
@@ -7,20 +7,18 @@
from argparse import ArgumentParser
import os

# from models.base.base_inference import BaseInference, base_parser
from models.tts.fastspeech2.fs2_inference import FastSpeech2Inference
from models.tts.vits.vits_inference import VitsInference
from utils.util import save_config, load_model_config, load_config
from utils.io import save_audio
from processors.acoustic_extractor import denorm_for_pred_mels
import numpy as np
from models.tts.valle.valle_inference import VALLEInference
from utils.util import load_config
import torch


def build_inference(args, cfg):
supported_inference = {
"FastSpeech2": FastSpeech2Inference,
"VITS": VitsInference,
"VALLE": VALLEInference,
}

inference_class = supported_inference[cfg.model_type]
@@ -53,7 +51,7 @@ def build_parser():
"--dataset",
type=str,
help="convert from the source data",
default="LJSpeech",
default=None,
)
parser.add_argument(
"--testing_set",
@@ -75,9 +73,9 @@
)
parser.add_argument(
"--text",
help="Text to be synthesized",
help="Text",
Collaborator: 'Text to be synthesized' is more informative than 'Text'

type=str,
default="Text to be synthesized.",
default="",
)
parser.add_argument(
"--vocoder_dir",
@@ -89,13 +87,15 @@
parser.add_argument(
"--acoustics_dir",
type=str,
default=None,
help="Acoustic model checkpoint directory. If a directory is given, "
"search for the latest checkpoint dir in the directory. If a specific "
"checkpoint dir is given, directly load the checkpoint.",
)
parser.add_argument(
"--checkpoint_path",
type=str,
default=None,
help="Acoustic model checkpoint directory. If a directory is given, "
"search for the latest checkpoint dir in the directory. If a specific "
"checkpoint dir is given, directly load the checkpoint.",
@@ -139,13 +139,12 @@ def build_parser():
)
return parser


def main():
# Parse arguments
args = build_parser().parse_args()
# args, infer_type = formulate_parser(args)
print("args: ", args)

parser = build_parser()
VALLEInference.add_arguments(parser)
Collaborator: it looks like VALLEInference is used no matter what the model type is. (A conditional-registration sketch follows this file's diff.)

args = parser.parse_args()
# Parse config
cfg = load_config(args.config)

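Regarding the collaborator's comment above, here is a hedged sketch of how the model-specific argument registration in `main()` could be made conditional on the configured model type. The `add_arguments` hook and the class names come from the diff; the two-pass parsing via `parse_known_args` and the `hasattr` guard are assumptions for illustration, not the actual implementation, and the snippet assumes it lives in `bins/tts/inference.py` where `build_parser` is defined.

```python
from models.tts.fastspeech2.fs2_inference import FastSpeech2Inference
from models.tts.vits.vits_inference import VitsInference
from models.tts.valle.valle_inference import VALLEInference
from utils.util import load_config

# Sketch only: pick the inference class from the config first, then register
# that class's extra CLI arguments instead of always calling VALLEInference.
parser = build_parser()
known_args, _ = parser.parse_known_args()   # first pass: read --config
cfg = load_config(known_args.config)

supported_inference = {
    "FastSpeech2": FastSpeech2Inference,
    "VITS": VitsInference,
    "VALLE": VALLEInference,
}
inference_class = supported_inference[cfg.model_type]

if hasattr(inference_class, "add_arguments"):
    inference_class.add_arguments(parser)   # only the selected model's arguments
args = parser.parse_args()                  # second pass: full argument set
```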
23 changes: 21 additions & 2 deletions bins/tts/preprocess.py
@@ -29,7 +29,15 @@ def extract_acoustic_features(dataset, output_path, cfg, n_workers=1):
cfg (dict): dictionary that stores configurations
n_workers (int, optional): num of processes to extract features in parallel. Defaults to 1.
"""
types = ["train", "test"] if "eval" not in dataset else ["test"]
# types = ["train", "test"] if "eval" not in dataset else ["test"]
types = list()
types.append((cfg.preprocess.train_file).split('.')[0])
types.append((cfg.preprocess.valid_file).split('.')[0])
if 'test' not in types:
types.append('test')
if "eval" in dataset:
types = ["test"]

metadata = []
for dataset_type in types:
dataset_output = os.path.join(output_path, dataset)
@@ -53,10 +61,20 @@ def extract_content_features(dataset, output_path, cfg, num_workers=1):
output_path (str): directory that stores train, test and feature files of datasets
cfg (dict): dictionary that stores configurations
"""
types = ["train", "test"] if "eval" not in dataset else ["test"]
# types = ["train", "test"] if "eval" not in dataset else ["test"]

types = list()
types.append((cfg.preprocess.train_file).split('.')[0])
types.append((cfg.preprocess.valid_file).split('.')[0])
if 'test' not in types:
types.append('test')
if "eval" in dataset:
types = ["test"]
Collaborator: repeating lines 32–39 (the same `types` construction as in `extract_acoustic_features`; see the sketch after this file's diff).


metadata = []
for dataset_type in types:
dataset_output = os.path.join(output_path, dataset)
# dataset_file = os.path.join(dataset_output, "{}.json".format(dataset_type))
dataset_file = os.path.join(dataset_output, "{}.json".format(dataset_type))
Collaborator: duplicating line 78?

with open(dataset_file, "r") as f:
metadata.extend(json.load(f))
@@ -87,6 +105,7 @@ def preprocess(cfg, args):
prepare_align(
dataset, cfg.dataset_path[dataset], cfg.preprocess, output_path
)

preprocess_dataset(
dataset,
cfg.dataset_path[dataset],
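Regarding the collaborator's note about the repeated block, a hedged sketch of factoring the `types` construction into a shared helper; the name `get_dataset_types` is hypothetical, and the logic simply mirrors the duplicated lines in `extract_acoustic_features` and `extract_content_features`.

```python
def get_dataset_types(dataset, cfg):
    """Hypothetical helper mirroring the duplicated block in the diff above."""
    if "eval" in dataset:
        return ["test"]
    types = [
        cfg.preprocess.train_file.split(".")[0],
        cfg.preprocess.valid_file.split(".")[0],
    ]
    if "test" not in types:
        types.append("test")
    return types
```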
21 changes: 19 additions & 2 deletions bins/tts/train.py
@@ -9,13 +9,15 @@

from models.tts.fastspeech2.fs2_trainer import FastSpeech2Trainer
from models.tts.vits.vits_trainer import VITSTrainer
from models.tts.valle.valle_trainer import VALLETrainer
from utils.util import load_config


def build_trainer(args, cfg):
supported_trainer = {
"FastSpeech2": FastSpeech2Trainer,
"VITS": VITSTrainer,
"VALLE": VALLETrainer,
}

trainer_class = supported_trainer[cfg.model_type]
@@ -56,6 +58,20 @@ def main():
parser.add_argument(
"--log_level", default="warning", help="logging level (debug, info, warning)"
)
parser.add_argument(
"--resume_type",
type=str,
default="resume",
help="Resume training or finetuning.",
)
parser.add_argument(
"--checkpoint_path",
type=str,
default=None,
help="Checkpoint for resume training or finetuning.",
)

VALLETrainer.add_arguments(parser)
args = parser.parse_args()
cfg = load_config(args.config)

@@ -77,12 +93,13 @@ def main():
new_datasets_list.extend(filter(None, new_datasets))
cfg.dataset.extend(new_datasets_list)

# CUDA settings
# # CUDA settings
Collaborator: No need to add one more '#'.

cuda_relevant()

# Build trainer
trainer = build_trainer(args, cfg)

torch.set_num_threads(1)
torch.set_num_interop_threads(1)
trainer.train_loop()


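For context on the new `--resume_type` and `--checkpoint_path` flags added above, a hypothetical sketch of how a trainer might consume them; the checkpoint dictionary keys (`"model"`, `"optimizer"`) and the `"finetune"` value are assumptions for illustration, not the actual VALLETrainer behavior.

```python
import torch

def restore_from_checkpoint(model, optimizer, args):
    """Hypothetical use of --resume_type / --checkpoint_path (illustrative only)."""
    if args.checkpoint_path is None:
        return  # fresh run, nothing to restore
    state = torch.load(args.checkpoint_path, map_location="cpu")
    model.load_state_dict(state["model"])               # assumed checkpoint key
    if args.resume_type == "resume":
        optimizer.load_state_dict(state["optimizer"])   # continue training as-is
    # "finetune": keep the restored weights but start the optimizer fresh
```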
52 changes: 52 additions & 0 deletions config/valle.json
@@ -0,0 +1,52 @@
{
"base_config": "config/base.json",
"model_type": "VALLE",
"task_type": "tts",
"dataset": [
"libritts"
],
"preprocess": {
"extract_phoneme": true,
"text_extractor": "espeak", // phoneme extractor: espeak, pypinyin, pypinyin_initials_finals or lexicon
"extract_acoustic_token": true,
"acoustic_token_extractor": "Encodec", // acoustic token extractor: encodec, dac(todo)
"acoustic_token_dir": "acoutic_tokens",
"use_text": false,
"use_phone": true,
"use_acoustic_token": true,
"symbols_dict": "symbols.dict",
"min_duration": 0.5, // the duration lowerbound to filter the audio with duration < min_duration
"max_duration": 14, // the duration uperbound to filter the audio with duration > max_duration.
"sampling_rate": 24000,
},
"model": {
"text_token_num": 512,
"audio_token_num": 1024,
"decoder_dim": 1024, // embedding dimension of the decoder model
"nhead": 16, // number of attention heads in the decoder layers
"num_decoder_layers": 12, // number of decoder layers
"norm_first": true, // pre or post Normalization.
"add_prenet": false, // whether add PreNet after Inputs
"prefix_mode": 0, // mode for how to prefix VALL-E NAR Decoder, 0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance
"share_embedding": true, // share the parameters of the output projection layer with the parameters of the acoustic embedding
"nar_scale_factor": 1, // model scale factor which will be assigned different meanings in different models
"prepend_bos": false, // whether prepend <BOS> to the acoustic tokens -> AR Decoder inputs
"num_quantizers": 8, // numbert of the audio quantization layers
// "scaling_xformers": false, // Apply Reworked Conformer scaling on Transformers
},
"train": {
"ddp": false,
"train_stage": 1, // 0: train all modules, For VALL_E, support 1: AR Decoder 2: NAR Decoder(s)
"max_epoch": 20,
"optimizer": "ScaledAdam",
"scheduler": "Eden",
"warmup_steps": 200, // number of steps that affects how rapidly the learning rate decreases
"base_lr": 0.05, // base learning rate."
"valid_interval": 1000,
"log_epoch_step": 1000,
"save_checkpoint_stride": [
1,
1
]
}
}
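For reference, a minimal sketch of loading this config with the `load_config` helper used by the entry points above and reading a few of its fields; the attribute-style access is assumed from the `cfg.model_type` and `cfg.preprocess.train_file` usages in the diffs.

```python
from utils.util import load_config

# Minimal sketch: inspect a few settings from config/valle.json.
cfg = load_config("config/valle.json")
print(cfg.model_type)          # "VALLE"
print(cfg.model.decoder_dim)   # 1024
print(cfg.train.train_stage)   # 1 -> AR decoder training stage
```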
139 changes: 139 additions & 0 deletions egs/tts/VALLE/README.md
@@ -0,0 +1,139 @@
# VALL-E Recipe

In this recipe, we will show how to train [VALL-E](https://arxiv.org/abs/2301.02111) using Amphion's infrastructure. VALL-E is a zero-shot TTS architecture that uses a neural codec language model with discrete codes.

There are four stages in total:

1. Data preparation
2. Features extraction
3. Training
4. Inference

> **NOTE:** You need to run every command of this recipe in the `Amphion` root path:
> ```bash
> cd Amphion
> ```

## 1. Data Preparation

### Dataset Download
You can use commonly used TTS datasets to train the VALL-E model, e.g., LibriTTS. We strongly recommend using LibriTTS when training VALL-E for the first time. How to download the datasets is detailed [here](../../datasets/README.md).

### Configuration

After downloading the dataset, you can set the dataset paths in `exp_config.json`. Note that you can change the `dataset` list to use your preferred datasets.

```json
"dataset": [
"libritts",
],
"dataset_path": {
// TODO: Fill in your dataset path
"libritts": "[LibriTTS dataset path]",
},
```

## 2. Features Extraction

### Configuration

Specify the `processed_dir` and the `log_dir` for saving the processed data and the checkpoints in `exp_config.json`:

```json
// TODO: Fill in the output log path. The default value is "Amphion/ckpts/tts"
"log_dir": "ckpts/tts",
"preprocess": {
// TODO: Fill in the output data path. The default value is "Amphion/data"
"processed_dir": "data",
...
},
```

### Run

Run `run.sh` as the preprocessing stage (set `--stage 1`):

```bash
sh egs/tts/VALLE/run.sh --stage 1
```

> **NOTE:** `CUDA_VISIBLE_DEVICES` is set to `"0"` by default. You can change it when running `run.sh` by specifying, e.g., `--gpu "1"`.


## 3. Training

### Configuration

We provide the default hyperparameters in `exp_config.json`. They work on a single 24 GB NVIDIA GPU. You can adjust them based on your GPU machines.

```
"train": {
"batch_size": 4,
}
```

### Run

Run `run.sh` as the training stage (set `--stage 2`). Specify an experiment name to run the following command. The TensorBoard logs and checkpoints will be saved in `Amphion/ckpts/tts/[YourExptName]`.

Specifically, VALL-E needs to train an autoregressive (AR) model first and then a non-autoregressive (NAR) model. So you can set `--model_train_stage 1` to train the AR model, and set `--model_train_stage 2` to train the NAR model, where `--ar_model_ckpt_dir` should be set to the checkpoint path of the trained AR model.


To train the AR model, just run:

```bash
sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 1 --name [YourExptName]
```

To train the NAR model, just run:
```bash
sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 2 --ar_model_ckpt_dir [ARModelPath] --name [YourExptName]
```
<!-- > **NOTE:** To train a NAR model, `--checkpoint_path` should be set to the checkpoint path of the trained AR model. -->

> **NOTE:** `CUDA_VISIBLE_DEVICES` is set to `"0"` by default. You can change it when running `run.sh` by specifying, e.g., `--gpu "0,1,2,3"`.


## 4. Inference

### Configuration

For inference, you need to specify the following configurations when running `run.sh`:



| Parameters | Description | Example |
| --------------------- | -------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `--infer_expt_dir` | The experimental directory of NAR model which contains `checkpoint` | `Amphion/ckpts/tts/[YourExptName]` |
| `--infer_output_dir` | The output directory to save inferred audios. | `Amphion/ckpts/tts/[YourExptName]/result` |
| `--infer_mode` | The inference mode, e.g., "`single`", "`batch`". | "`single`" to generate a clip of speech, "`batch`" to generate a batch of speech at a time. |
| `--infer_text` | The text to be synthesized. | "`This is a clip of generated speech with the given text from a TTS model.`" |
| `--infer_text_prompt` | The text prompt for inference. | The text prompt should be aligned with the audio prompt. |
| `--infer_audio_prompt` | The audio prompt for inference. | The audio prompt should be aligned with text prompt.|
| `--test_list_file` | The test list file used for batch inference. | The format of the test list file is `text\|text_prompt\|audio_prompt` (see the example below).|
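For illustration, here is a hypothetical line of a test list file in the `text|text_prompt|audio_prompt` format described above, reusing the text, text prompt, and audio path from the single-clip example below:

```
This is a clip of generated speech with the given text from a TTS model.|But even the unsuccessful dramatist has his moments.|egs/tts/VALLE/prompt_examples/7176_92135_000004_000000.wav
```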


### Run
For example, if you want to generate a single clip of speech, just run:

```bash
sh egs/tts/VALLE/run.sh --stage 3 --gpu "0" \
--infer_expt_dir Amphion/ckpts/tts/[YourExptName] \
--infer_output_dir Amphion/ckpts/tts/[YourExptName]/result \
--infer_mode "single" \
--infer_text "This is a clip of generated speech with the given text from a TTS model." \
--infer_text_prompt "But even the unsuccessful dramatist has his moments." \
--infer_audio_prompt egs/tts/VALLE/prompt_examples/7176_92135_000004_000000.wav
```


We will release a pre-trained VALL-E model, so you can download it and generate speech following the above inference instructions.

```bibtex
@article{wang2023neural,
title={Neural codec language models are zero-shot text to speech synthesizers},
author={Wang, Chengyi and Chen, Sanyuan and Wu, Yu and Zhang, Ziqiang and Zhou, Long and Liu, Shujie and Chen, Zhuo and Liu, Yanqing and Wang, Huaming and Li, Jinyu and others},
journal={arXiv preprint arXiv:2301.02111},
year={2023}
}
```