diff --git a/examples/tts/conf/fastpitch_align_44100_adapter.yaml b/examples/tts/conf/fastpitch_align_44100_adapter.yaml index bac6a64b06e9..b2957b057d28 100644 --- a/examples/tts/conf/fastpitch_align_44100_adapter.yaml +++ b/examples/tts/conf/fastpitch_align_44100_adapter.yaml @@ -208,7 +208,7 @@ model: dropatt: 0.1 dropemb: 0.0 d_embed: ${model.symbols_embedding_dim} - condition_types: [ "add", "layernorm" ] # options: [ "add", "cat", "layernorm" ] + condition_types: [ "add", "layernorm" ] # options: [ "add", "concat", "layernorm" ] output_fft: _target_: nemo.collections.tts.modules.transformer.FFTransformerDecoder @@ -221,12 +221,12 @@ model: dropout: 0.1 dropatt: 0.1 dropemb: 0.0 - condition_types: [ "add", "layernorm" ] # options: [ "add", "cat", "layernorm" ] + condition_types: [ "add", "layernorm" ] # options: [ "add", "concat", "layernorm" ] alignment_module: _target_: nemo.collections.tts.modules.aligner.AlignmentEncoder n_text_channels: ${model.symbols_embedding_dim} - condition_types: [ "add" ] # options: [ "add", "cat" ] + condition_types: [ "add" ] # options: [ "add", "concat" ] duration_predictor: _target_: nemo.collections.tts.modules.fastpitch.TemporalPredictor @@ -235,7 +235,7 @@ model: filter_size: 256 dropout: 0.1 n_layers: 2 - condition_types: [ "add", "layernorm" ] # options: [ "add", "cat", "layernorm" ] + condition_types: [ "add", "layernorm" ] # options: [ "add", "concat", "layernorm" ] pitch_predictor: _target_: nemo.collections.tts.modules.fastpitch.TemporalPredictor @@ -244,10 +244,11 @@ model: filter_size: 256 dropout: 0.1 n_layers: 2 - condition_types: [ "add", "layernorm" ] # options: [ "add", "cat", "layernorm" ] + condition_types: [ "add", "layernorm" ] # options: [ "add", "concat", "layernorm" ] speaker_encoder: _target_: nemo.collections.tts.modules.submodules.SpeakerEncoder + precomputed_embedding_dim: null lookup_module: _target_: nemo.collections.tts.modules.submodules.SpeakerLookupTable n_speakers: ??? 
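Note for reviewers: the `precomputed_embedding_dim` field added above is consumed by the `SpeakerEncoder` changes to `nemo/collections/tts/modules/submodules.py` further down in this diff. Below is a minimal sketch of that path, simplified from the diff (the class name `SpeakerEncoderSketch` is made up for illustration, the GST branch and neural-type decorators are omitted); it is illustrative only, not the exact module.

```python
import torch


class SpeakerEncoderSketch(torch.nn.Module):
    """Simplified sketch of SpeakerEncoder with the new precomputed-embedding path."""

    def __init__(self, lookup_module=None, gst_module=None, precomputed_embedding_dim=None):
        super().__init__()
        self.lookup_module = lookup_module  # speaker-id lookup table (optional)
        self.gst_module = gst_module        # GST module for reference audio (optional, omitted below)
        # When a dimension is given, reserve a single embedding that can later be
        # overwritten with an averaged, precomputed speaker embedding.
        if precomputed_embedding_dim is not None:
            self.precomputed_emb = torch.nn.Parameter(torch.empty(precomputed_embedding_dim))
        else:
            self.precomputed_emb = None

    def overwrite_precomputed_emb(self, emb):
        # e.g. the mean of GST embeddings over the adaptation data, as in the adapter tutorial
        self.precomputed_emb = torch.nn.Parameter(emb)

    def forward(self, batch_size=None, speaker=None, reference_spec=None, reference_spec_lens=None):
        # The precomputed embedding short-circuits lookup/GST and is repeated across the batch,
        # which is why callers now pass batch_size explicitly.
        if self.precomputed_emb is not None:
            return self.precomputed_emb.unsqueeze(0).repeat(batch_size, 1)
        embs = None
        if self.lookup_module is not None and speaker is not None:
            embs = self.lookup_module(speaker)
        # ... GST branch (reference_spec / reference_spec_lens) unchanged, omitted here ...
        return embs
```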
diff --git a/nemo/collections/tts/models/fastpitch.py b/nemo/collections/tts/models/fastpitch.py index 5502e69a3111..28185c8f8622 100644 --- a/nemo/collections/tts/models/fastpitch.py +++ b/nemo/collections/tts/models/fastpitch.py @@ -183,6 +183,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): if self.fastpitch.speaker_emb is not None: self.export_config["num_speakers"] = cfg.n_speakers + # Adapter modules setup (from FastPitchAdapterModelMixin) + self.setup_adapters() + def _get_default_text_tokenizer_conf(self): text_tokenizer: TextTokenizerConfig = TextTokenizerConfig() return OmegaConf.create(OmegaConf.to_yaml(text_tokenizer)) diff --git a/nemo/collections/tts/modules/fastpitch.py b/nemo/collections/tts/modules/fastpitch.py index 5f2227a999db..77dff7bc85ed 100644 --- a/nemo/collections/tts/modules/fastpitch.py +++ b/nemo/collections/tts/modules/fastpitch.py @@ -177,7 +177,6 @@ def __init__( self.learn_alignment = aligner is not None self.use_duration_predictor = True self.binarize = False - # TODO: combine self.speaker_emb with self.speaker_encoder # cfg: remove `n_speakers`, create `speaker_encoder.lookup_module` # state_dict: move `speaker_emb.weight` to `speaker_encoder.lookup_module.table.weight` @@ -244,10 +243,10 @@ def output_types(self): "energy_tgt": NeuralType(('B', 'T_audio'), RegressionValuesType()), } - def get_speaker_embedding(self, speaker, reference_spec, reference_spec_lens): + def get_speaker_embedding(self, batch_size, speaker, reference_spec, reference_spec_lens): """spk_emb: Bx1xD""" if self.speaker_encoder is not None: - spk_emb = self.speaker_encoder(speaker, reference_spec, reference_spec_lens).unsqueeze(1) + spk_emb = self.speaker_encoder(batch_size, speaker, reference_spec, reference_spec_lens).unsqueeze(1) elif self.speaker_emb is not None: if speaker is None: raise ValueError('Please give speaker id to get lookup speaker embedding.') @@ -281,7 +280,10 @@ def forward( # Calculate speaker embedding spk_emb = self.get_speaker_embedding( - speaker=speaker, reference_spec=reference_spec, reference_spec_lens=reference_spec_lens, + batch_size=text.shape[0], + speaker=speaker, + reference_spec=reference_spec, + reference_spec_lens=reference_spec_lens, ) # Input FFT @@ -379,10 +381,12 @@ def infer( reference_spec=None, reference_spec_lens=None, ): - # Calculate speaker embedding spk_emb = self.get_speaker_embedding( - speaker=speaker, reference_spec=reference_spec, reference_spec_lens=reference_spec_lens, + batch_size=text.shape[0], + speaker=speaker, + reference_spec=reference_spec, + reference_spec_lens=reference_spec_lens, ) # Input FFT diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py index dbf26f1ceeee..6efccf18eeea 100644 --- a/nemo/collections/tts/modules/submodules.py +++ b/nemo/collections/tts/modules/submodules.py @@ -709,18 +709,29 @@ class SpeakerEncoder represents speakers representation. This module can combine GST (global style token) based speaker embeddings and lookup table speaker embeddings. 
""" - def __init__(self, lookup_module=None, gst_module=None): + def __init__(self, lookup_module=None, gst_module=None, precomputed_embedding_dim=None): """ lookup_module: Torch module to get lookup based speaker embedding gst_module: Neural module to get GST based speaker embedding + precomputed_embedding_dim: Give precomputed speaker embedding dimension to use precompute speaker embedding """ super(SpeakerEncoder, self).__init__() + + # Multi-speaker embedding self.lookup_module = lookup_module + + # Reference speaker embedding self.gst_module = gst_module + if precomputed_embedding_dim is not None: + self.precomputed_emb = torch.nn.Parameter(torch.empty(precomputed_embedding_dim)) + else: + self.precomputed_emb = None + @property def input_types(self): return { + "batch_size": NeuralType(optional=True), "speaker": NeuralType(('B'), Index(), optional=True), "reference_spec": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType(), optional=True), "reference_spec_lens": NeuralType(('B'), LengthsType(), optional=True), @@ -732,9 +743,16 @@ def output_types(self): "embs": NeuralType(('B', 'D'), EncodedRepresentation()), } - def forward(self, speaker=None, reference_spec=None, reference_spec_lens=None): + def overwrite_precomputed_emb(self, emb): + self.precomputed_emb = torch.nn.Parameter(emb) + + def forward(self, batch_size=None, speaker=None, reference_spec=None, reference_spec_lens=None): embs = None + # Get Precomputed speaker embedding + if self.precomputed_emb is not None: + return self.precomputed_emb.unsqueeze(0).repeat(batch_size, 1) + # Get Lookup table speaker embedding if self.lookup_module is not None and speaker is not None: embs = self.lookup_module(speaker) diff --git a/scripts/dataset_processing/tts/extract_sup_data.py b/scripts/dataset_processing/tts/extract_sup_data.py index 57fa220a733c..9a5dcc223444 100644 --- a/scripts/dataset_processing/tts/extract_sup_data.py +++ b/scripts/dataset_processing/tts/extract_sup_data.py @@ -31,7 +31,7 @@ def get_pitch_stats(pitch_list): def preprocess_ds_for_fastpitch_align(dataloader): pitch_list = [] for batch in tqdm(dataloader, total=len(dataloader)): - audios, audio_lengths, tokens, tokens_lengths, align_prior_matrices, pitches, pitches_lengths = batch + audios, audio_lengths, tokens, tokens_lengths, align_prior_matrices, pitches, pitches_lengths, *_ = batch pitch = pitches.squeeze(0) pitch_list.append(pitch[pitch != 0]) diff --git a/scripts/dataset_processing/tts/resynthesize_dataset.py b/scripts/dataset_processing/tts/resynthesize_dataset.py index cacd41e93109..652fde299572 100644 --- a/scripts/dataset_processing/tts/resynthesize_dataset.py +++ b/scripts/dataset_processing/tts/resynthesize_dataset.py @@ -117,6 +117,15 @@ def resynthesize_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]: batch = to_device_recursive(batch, self.device) mels, mel_lens = self.model.preprocessor(input_signal=batch["audio"], length=batch["audio_lens"]) + + reference_audio = batch.get("reference_audio", None) + reference_audio_len = batch.get("reference_audio_lens", None) + reference_spec, reference_spec_len = None, None + if reference_audio is not None: + reference_spec, reference_spec_len = self.model.preprocessor( + input_signal=reference_audio, length=reference_audio_len + ) + outputs_tuple = self.model.forward( text=batch["text"], durs=None, @@ -127,6 +136,8 @@ def resynthesize_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]: attn_prior=batch.get("attn_prior"), mel_lens=mel_lens, input_lens=batch["text_lens"], + 
reference_spec=reference_spec, + reference_spec_lens=reference_spec_len, ) names = self.model.fastpitch.output_types.keys() return {"spec": mels, "mel_lens": mel_lens, **dict(zip(names, outputs_tuple))} diff --git a/tutorials/tts/FastPitch_Adapter_Finetuning.ipynb b/tutorials/tts/FastPitch_Adapter_Finetuning.ipynb new file mode 100644 index 000000000000..fa1b1bdc90c8 --- /dev/null +++ b/tutorials/tts/FastPitch_Adapter_Finetuning.ipynb @@ -0,0 +1,827 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ea49c0e5", + "metadata": {}, + "source": [ + "# FastPitch Adapter Finetuning\n", + "\n", + "This notebook is designed to provide a guide on how to run the FastPitch Adapter Finetuning pipeline. It contains the following sections:\n", + "1. **Fine-tune FastPitch on adaptation data**: fine-tune pre-trained multi-speaker FastPitch for a new speaker\n", + "* Dataset Preparation: download the dataset and extract manifest files (total duration should be more than 15 minutes).\n", + "* Preprocessing: add absolute audio paths to the manifest and extract Supplementary Data.\n", + "* **Model Setting: transform pre-trained checkpoint to adapter-compatible checkpoint and precompute speaker embedding**\n", + "* Training: fine-tune the frozen multi-speaker FastPitch with trainable adapters.\n", + "2. **Fine-tune HiFiGAN on adaptation data**: fine-tune a vocoder for the fine-tuned multi-speaker FastPitch\n", + "* Dataset Preparation: extract mel-spectrograms from the fine-tuned FastPitch.\n", + "* Training: fine-tune HiFiGAN on the adaptation data.\n", + "3. **Inference**: generate speech from the adapted FastPitch\n", + "* Load Model: load pre-trained multi-speaker FastPitch with **fine-tuned adapters**.\n", + "* Output Audio: generate audio files." + ] + }, + { + "cell_type": "markdown", + "id": "37259555", + "metadata": {}, + "source": [ + "# License\n", + "\n", + "> Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n", + "> \n", + "> Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "> you may not use this file except in compliance with the License.\n", + "> You may obtain a copy of the License at\n", + "> \n", + "> http://www.apache.org/licenses/LICENSE-2.0\n", + "> \n", + "> Unless required by applicable law or agreed to in writing, software\n", + "> distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "> WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "> See the License for the specific language governing permissions and\n", + "> limitations under the License." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d61cbea5", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "You can either run this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", + "Instructions for setting up Colab are as follows:\n", + "1. Open a new Python 3 notebook.\n", + "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", + "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", + "4. 
Run this cell to set up dependencies.\n", + "\"\"\"\n", + "# # If you're using Colab and not running locally, uncomment and run this cell.\n", + "# BRANCH = 'main'\n", + "# !apt-get install sox libsndfile1 ffmpeg\n", + "# !pip install wget unidecode pynini==2.1.4 scipy==1.7.3\n", + "# !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n", + "\n", + "# # Download local version of NeMo scripts. If you are running locally and want to use your own local NeMo code,\n", + "# # comment out the below lines and set `code_dir` to your local path.\n", + "code_dir = 'NeMoTTS' \n", + "!git clone https://github.com/NVIDIA/NeMo.git {code_dir}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fef9aba9", + "metadata": {}, + "outputs": [], + "source": [ + "!wandb login #PASTE_WANDB_APIKEY_HERE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49bc38ab", + "metadata": {}, + "outputs": [], + "source": [ + "# .nemo files for your pre-trained FastPitch and HiFiGAN\n", + "pretrained_fastpitch_checkpoint = \"\"\n", + "finetuned_hifigan_on_multispeaker_checkpoint = \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9459f9dc", + "metadata": {}, + "outputs": [], + "source": [ + "sample_rate = 44100\n", + "# Store all manifest and audios\n", + "data_dir = 'NeMoTTS_dataset'\n", + "# Store all supplementary files\n", + "supp_dir = \"NeMoTTS_sup_data\"\n", + "# Store all training logs\n", + "logs_dir = \"NeMoTTS_logs\"\n", + "# Store all mel-spectrograms for vocoder training\n", + "mels_dir = \"NeMoTTS_mels\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb26f54d", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "import shutil\n", + "import nemo\n", + "import torch\n", + "import numpy as np\n", + "\n", + "from pathlib import Path\n", + "from tqdm import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12b28329", + "metadata": {}, + "outputs": [], + "source": [ + "os.makedirs(code_dir, exist_ok=True)\n", + "code_dir = os.path.abspath(code_dir)\n", + "os.makedirs(data_dir, exist_ok=True)\n", + "data_dir = os.path.abspath(data_dir)\n", + "os.makedirs(supp_dir, exist_ok=True)\n", + "supp_dir = os.path.abspath(supp_dir)\n", + "os.makedirs(logs_dir, exist_ok=True)\n", + "logs_dir = os.path.abspath(logs_dir)\n", + "os.makedirs(mels_dir, exist_ok=True)\n", + "mels_dir = os.path.abspath(mels_dir)" + ] + }, + { + "cell_type": "markdown", + "id": "30996769", + "metadata": {}, + "source": [ + "# 1. Fine-tune FastPitch on adaptation data" + ] + }, + { + "cell_type": "markdown", + "id": "2f5f5945", + "metadata": {}, + "source": [ + "## a. Data Preparation\n", + "For our tutorial, we use a small part of the VCTK dataset with a new target speaker (p267). Usually, the audio should have a total duration of more than 15 minutes."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8047f988", + "metadata": {}, + "outputs": [], + "source": [ + "!cd {data_dir} && wget https://vctk-subset.s3.amazonaws.com/vctk_subset.tar.gz && tar zxf vctk_subset.tar.gz" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b8242769", + "metadata": {}, + "outputs": [], + "source": [ + "manidir = f\"{data_dir}/vctk_subset\"\n", + "!ls {manidir}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79cf8539", + "metadata": {}, + "outputs": [], + "source": [ + "train_manifest = os.path.abspath(os.path.join(manidir, 'train.json'))\n", + "valid_manifest = os.path.abspath(os.path.join(manidir, 'dev.json'))" + ] + }, + { + "cell_type": "markdown", + "id": "35c3b97b", + "metadata": {}, + "source": [ + "## b. Preprocessing" + ] + }, + { + "cell_type": "markdown", + "id": "ba3a7c3a", + "metadata": {}, + "source": [ + "### Add absolute file path in manifest\n", + "We use absoluate path for audio_filepath to get the audio during training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8bc485b5", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f9cb8ef5", + "metadata": {}, + "outputs": [], + "source": [ + "train_data = read_manifest(train_manifest)\n", + "for m in train_data: m['audio_filepath'] = os.path.abspath(os.path.join(manidir, m['audio_filepath']))\n", + "write_manifest(train_manifest, train_data)\n", + "\n", + "valid_data = read_manifest(valid_manifest)\n", + "for m in valid_data: m['audio_filepath'] = os.path.abspath(os.path.join(manidir, m['audio_filepath']))\n", + "write_manifest(valid_manifest, valid_data)" + ] + }, + { + "cell_type": "markdown", + "id": "f92054d5", + "metadata": {}, + "source": [ + "### Extract Supplementary Data\n", + "\n", + "As mentioned in the [FastPitch and MixerTTS training tutorial](https://github.com/NVIDIA/NeMo/blob/main/tutorials/tts/FastPitch_MixerTTS_Training.ipynb) - To accelerate and stabilize our training, we also need to extract pitch for every audio, estimate pitch statistics (mean, std, min, and max). To do this, all we need to do is iterate over our data one time, via `extract_sup_data.py` script." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0adc618b", + "metadata": {}, + "outputs": [], + "source": [ + "!cd {code_dir} && python scripts/dataset_processing/tts/extract_sup_data.py \\\n", + " manifest_filepath={train_manifest} \\\n", + " sup_data_path={supp_dir} \\\n", + " dataset.sample_rate={sample_rate} \\\n", + " dataset.n_fft=2048 \\\n", + " dataset.win_length=2048 \\\n", + " dataset.hop_length=512" + ] + }, + { + "cell_type": "markdown", + "id": "96dd5fe1", + "metadata": {}, + "source": [ + "After running the above command line, you will observe a new folder NeMoTTS_sup_data/pitch and printouts of pitch statistics like below. Specify these values to the FastPitch training configurations. 
We will be there in the following section.\n", + "```bash\n", + "PITCH_MEAN=175.48513793945312, PITCH_STD=42.3786735534668\n", + "PITCH_MIN=65.4063949584961, PITCH_MAX=270.8517761230469\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23703c76", + "metadata": {}, + "outputs": [], + "source": [ + "!cd {code_dir} && python scripts/dataset_processing/tts/extract_sup_data.py \\\n", + " manifest_filepath={valid_manifest} \\\n", + " sup_data_path={supp_dir} \\\n", + " dataset.sample_rate={sample_rate} \\\n", + " dataset.n_fft=2048 \\\n", + " dataset.win_length=2048 \\\n", + " dataset.hop_length=512" + ] + }, + { + "cell_type": "markdown", + "id": "7c70e5db", + "metadata": {}, + "source": [ + "## c. Model Setting\n", + "### Transform pre-trained checkpoint to adapter-compatible checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "439f2f82", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.tts.models import FastPitchModel\n", + "from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer\n", + "from nemo.core import adapter_mixins\n", + "from omegaconf import DictConfig, OmegaConf, open_dict" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30f865cb", + "metadata": {}, + "outputs": [], + "source": [ + "def update_model_config_to_support_adapter(config) -> DictConfig:\n", + " with open_dict(config):\n", + " enc_adapter_metadata = adapter_mixins.get_registered_adapter(config.input_fft._target_)\n", + " if enc_adapter_metadata is not None:\n", + " config.input_fft._target_ = enc_adapter_metadata.adapter_class_path\n", + "\n", + " dec_adapter_metadata = adapter_mixins.get_registered_adapter(config.output_fft._target_)\n", + " if dec_adapter_metadata is not None:\n", + " config.output_fft._target_ = dec_adapter_metadata.adapter_class_path\n", + "\n", + " pitch_predictor_adapter_metadata = adapter_mixins.get_registered_adapter(config.pitch_predictor._target_)\n", + " if pitch_predictor_adapter_metadata is not None:\n", + " config.pitch_predictor._target_ = pitch_predictor_adapter_metadata.adapter_class_path\n", + "\n", + " duration_predictor_adapter_metadata = adapter_mixins.get_registered_adapter(config.duration_predictor._target_)\n", + " if duration_predictor_adapter_metadata is not None:\n", + " config.duration_predictor._target_ = duration_predictor_adapter_metadata.adapter_class_path\n", + "\n", + " aligner_adapter_metadata = adapter_mixins.get_registered_adapter(config.alignment_module._target_)\n", + " if aligner_adapter_metadata is not None:\n", + " config.alignment_module._target_ = aligner_adapter_metadata.adapter_class_path\n", + "\n", + " return config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e92910b5", + "metadata": {}, + "outputs": [], + "source": [ + "spec_model = FastPitchModel.restore_from(pretrained_fastpitch_checkpoint).eval().cuda()\n", + "spec_model.cfg = update_model_config_to_support_adapter(spec_model.cfg)" + ] + }, + { + "cell_type": "markdown", + "id": "7f03219f", + "metadata": {}, + "source": [ + "### Precompute Speaker Embedding\n", + "Get all GST speaker embeddings from training data, take average, and save as `precomputed_emb` in FastPitch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2a35241", + "metadata": {}, + "outputs": [], + "source": [ + "wave_model = WaveformFeaturizer(sample_rate=sample_rate)\n", + "train_data = read_manifest(train_manifest)\n", + "\n", + "spk_embs = [] 
\n", + "for data in train_data:\n", + " with torch.no_grad():\n", + " audio = wave_model.process(data['audio_filepath'])\n", + " audio_length = torch.tensor(audio.shape[0]).long()\n", + " audio = audio.unsqueeze(0).to(device=spec_model.device)\n", + " audio_length = audio_length.unsqueeze(0).to(device=spec_model.device)\n", + " spec_ref, spec_ref_lens = spec_model.preprocessor(input_signal=audio, length=audio_length)\n", + " spk_emb = spec_model.fastpitch.get_speaker_embedding(batch_size=spec_ref.shape[0],\n", + " speaker=None,\n", + " reference_spec=spec_ref,\n", + " reference_spec_lens=spec_ref_lens)\n", + "\n", + " spk_embs.append(spk_emb.squeeze().cpu())\n", + "\n", + "spk_embs = torch.stack(spk_embs, dim=0)\n", + "spk_emb = torch.mean(spk_embs, dim=0)\n", + "spk_emb_dim = spk_emb.shape[0]\n", + "\n", + "with open_dict(spec_model.cfg):\n", + " spec_model.cfg.speaker_encoder.precomputed_embedding_dim = spec_model.cfg.symbols_embedding_dim\n", + "\n", + "spec_model.fastpitch.speaker_encoder.overwrite_precomputed_emb(spk_emb)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fa1b309", + "metadata": {}, + "outputs": [], + "source": [ + "spec_model.save_to('Pretrained-FastPitch.nemo')\n", + "shutil.copyfile(finetuned_hifigan_on_multispeaker_checkpoint, \"Pretrained-HifiGan.nemo\")\n", + "pretrained_fastpitch_checkpoint = os.path.abspath(\"Pretrained-FastPitch.nemo\")\n", + "finetuned_hifigan_on_multispeaker_checkpoint = os.path.abspath(\"Pretrained-HifiGan.nemo\")" + ] + }, + { + "cell_type": "markdown", + "id": "3b77e95f", + "metadata": {}, + "source": [ + "## d. Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e8c3740", + "metadata": {}, + "outputs": [], + "source": [ + "phoneme_dict_path = os.path.abspath(os.path.join(code_dir, \"scripts\", \"tts_dataset_files\", \"cmudict-0.7b_nv22.10\"))\n", + "heteronyms_path = os.path.abspath(os.path.join(code_dir, \"scripts\", \"tts_dataset_files\", \"heteronyms-052722\"))\n", + "\n", + "# Copy and Paste the PITCH_MEAN and PITCH_STD from previous steps (train_manifest) to overide pitch_mean and pitch_std configs below.\n", + "PITCH_MEAN=175.48513793945312\n", + "PITCH_STD=42.3786735534668" + ] + }, + { + "cell_type": "markdown", + "id": "19bb6d8b", + "metadata": {}, + "source": [ + "### Important notes\n", + "* `+init_from_nemo_model`: initialize with a multi-speaker FastPitch checkpoint\n", + "* `model.speaker_encoder.precomputed_embedding_dim={spk_emb_dim}`: use precomputed speaker embedding\n", + "* `~model.speaker_encoder.lookup_module`: we use precomputed speaker embedding, so we remove the pre-trained looked-up speaker embedding\n", + "* `~model.speaker_encoder.gst_module`: we use precomputed speaker embedding, so we remove the pre-trained gst speaker embedding\n", + "* Other optional arguments based on your preference:\n", + " * batch_size\n", + " * exp_manager\n", + " * trainer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c8cbea2", + "metadata": {}, + "outputs": [], + "source": [ + "# Normally 100 epochs\n", + "!cd {code_dir} && python examples/tts/fastpitch_finetune_adapters.py \\\n", + "--config-name=fastpitch_align_44100_adapter.yaml \\\n", + "+init_from_nemo_model={pretrained_fastpitch_checkpoint} \\\n", + "train_dataset={train_manifest} \\\n", + "validation_datasets={valid_manifest} \\\n", + "sup_data_types=\"['align_prior_matrix', 'pitch']\" \\\n", + "sup_data_path={supp_dir} \\\n", + "pitch_mean={PITCH_MEAN} \\\n", + "pitch_std={PITCH_STD} \\\n", + 
"model.speaker_encoder.precomputed_embedding_dim={spk_emb_dim} \\\n", + "~model.speaker_encoder.lookup_module \\\n", + "~model.speaker_encoder.gst_module \\\n", + "model.train_ds.dataloader_params.batch_size=8 \\\n", + "model.validation_ds.dataloader_params.batch_size=8 \\\n", + "model.optim.name=adam \\\n", + "model.optim.lr=2e-4 \\\n", + "~model.optim.sched \\\n", + "exp_manager.exp_dir={logs_dir} \\\n", + "+exp_manager.create_wandb_logger=True \\\n", + "+exp_manager.wandb_logger_kwargs.name=\"tutorial-FastPitch-finetune-adaptation\" \\\n", + "+exp_manager.wandb_logger_kwargs.project=\"NeMo\" \\\n", + "+exp_manager.checkpoint_callback_params.save_top_k=-1 \\\n", + "trainer.max_epochs=10 \\\n", + "trainer.check_val_every_n_epoch=10 \\\n", + "trainer.log_every_n_steps=1 \\\n", + "trainer.devices=1 \\\n", + "trainer.strategy=ddp \\\n", + "trainer.precision=32" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe5c7b2f", + "metadata": {}, + "outputs": [], + "source": [ + "# e.g. NeMoTTS_logs/FastPitch/Y-M-D_H-M-S/checkpoints/FastPitch.nemo\n", + "# e.g. NeMoTTS_logs/FastPitch/Y-M-D_H-M-S/checkpoints/adapters.pt\n", + "last_checkpoint_dir = sorted(list([i for i in (Path(logs_dir) / \"FastPitch\").iterdir() if i.is_dir()]))[-1] / \"checkpoints\"\n", + "finetuned_fastpitch_checkpoint = list(last_checkpoint_dir.glob('*.nemo'))[0]\n", + "finetuned_adapter_checkpoint = list(last_checkpoint_dir.glob('adapters.pt'))[0]\n", + "print(finetuned_fastpitch_checkpoint)\n", + "print(finetuned_adapter_checkpoint)" + ] + }, + { + "cell_type": "markdown", + "id": "75856d0e", + "metadata": {}, + "source": [ + "# 3. Fine-tune HiFiGAN on adaptation data" + ] + }, + { + "cell_type": "markdown", + "id": "3444698f", + "metadata": {}, + "source": [ + "## a. Dataset Preparation\n", + "Generate mel-spectrograms for HiFiGAN training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb2fd64d", + "metadata": {}, + "outputs": [], + "source": [ + "!cd {code_dir} \\\n", + "&& python scripts/dataset_processing/tts/resynthesize_dataset.py \\\n", + "--model-path={finetuned_fastpitch_checkpoint} \\\n", + "--input-json-manifest={train_manifest} \\\n", + "--input-sup-data-path={supp_dir} \\\n", + "--output-folder={mels_dir} \\\n", + "--device=\"cuda:0\" \\\n", + "--batch-size=1 \\\n", + "--num-workers=1 \\\n", + "&& python scripts/dataset_processing/tts/resynthesize_dataset.py \\\n", + "--model-path={finetuned_fastpitch_checkpoint} \\\n", + "--input-json-manifest={valid_manifest} \\\n", + "--input-sup-data-path={supp_dir} \\\n", + "--output-folder={mels_dir} \\\n", + "--device=\"cuda:0\" \\\n", + "--batch-size=1 \\\n", + "--num-workers=1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da69cb66", + "metadata": {}, + "outputs": [], + "source": [ + "train_manifest_mel = f\"{mels_dir}/train_mel.json\"\n", + "valid_manifest_mel = f\"{mels_dir}/dev_mel.json\"" + ] + }, + { + "cell_type": "markdown", + "id": "fa2cbb02", + "metadata": {}, + "source": [ + "## b. 
Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffdce5d5", + "metadata": {}, + "outputs": [], + "source": [ + "# Normally 500 epochs\n", + "!cd {code_dir} && python examples/tts/hifigan_finetune.py \\\n", + "--config-name=hifigan_44100.yaml \\\n", + "train_dataset={train_manifest_mel} \\\n", + "validation_datasets={valid_manifest_mel} \\\n", + "+init_from_nemo_model={finetuned_hifigan_on_multispeaker_checkpoint} \\\n", + "model.train_ds.dataloader_params.batch_size=32 \\\n", + "model.optim.lr=0.0001 \\\n", + "model/train_ds=train_ds_finetune \\\n", + "model/validation_ds=val_ds_finetune \\\n", + "+trainer.max_epochs=5 \\\n", + "trainer.check_val_every_n_epoch=5 \\\n", + "trainer.devices=-1 \\\n", + "trainer.strategy='ddp' \\\n", + "trainer.precision=16 \\\n", + "exp_manager.exp_dir={logs_dir} \\\n", + "exp_manager.create_wandb_logger=True \\\n", + "exp_manager.wandb_logger_kwargs.name=\"tutorial-HiFiGAN-finetune-multispeaker\" \\\n", + "exp_manager.wandb_logger_kwargs.project=\"NeMo\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e6376cf", + "metadata": {}, + "outputs": [], + "source": [ + "# e.g. NeMoTTS_logs/HifiGan/Y-M-D_H-M-S/checkpoints/HifiGan.nemo\n", + "last_checkpoint_dir = sorted(list([i for i in (Path(logs_dir) / \"HifiGan\").iterdir() if i.is_dir()]))[-1] / \"checkpoints\"\n", + "finetuned_hifigan_on_adaptation_checkpoint = list(last_checkpoint_dir.glob('*.nemo'))[0]\n", + "finetuned_hifigan_on_adaptation_checkpoint" + ] + }, + { + "cell_type": "markdown", + "id": "e5076e51", + "metadata": {}, + "source": [ + "# 4. Inference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52358549", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.tts.models import HifiGanModel\n", + "import IPython.display as ipd\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "id": "9e96ee13", + "metadata": {}, + "source": [ + "## a. Load Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2cb5d524", + "metadata": {}, + "outputs": [], + "source": [ + "wave_model = WaveformFeaturizer(sample_rate=sample_rate)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32dbd30c", + "metadata": {}, + "outputs": [], + "source": [ + "# Load from pretrained FastPitch and finetuned adapter\n", + "# spec_model = FastPitchModel.restore_from(pretrained_fastpitch_checkpoint)\n", + "# spec_model.load_adapters(finetuned_adapter_checkpoint)\n", + "\n", + "# Load from finetuned FastPitch\n", + "spec_model = FastPitchModel.restore_from(finetuned_fastpitch_checkpoint)\n", + "spec_model = spec_model.eval().cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "74a7ad03", + "metadata": {}, + "outputs": [], + "source": [ + "# HiFiGAN\n", + "vocoder_model = HifiGanModel.restore_from(finetuned_hifigan_on_adaptation_checkpoint).eval().cuda()" + ] + }, + { + "cell_type": "markdown", + "id": "4f882975", + "metadata": {}, + "source": [ + "## b. 
Output Audio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2178a8ef", + "metadata": {}, + "outputs": [], + "source": [ + "def gt_spectrogram(audio_path, wave_model, spec_gen_model):\n", + " features = wave_model.process(audio_path, trim=False)\n", + " audio, audio_length = features, torch.tensor(features.shape[0]).long()\n", + " audio = audio.unsqueeze(0).to(device=spec_gen_model.device)\n", + " audio_length = audio_length.unsqueeze(0).to(device=spec_gen_model.device)\n", + " with torch.no_grad():\n", + " spectrogram, spec_len = spec_gen_model.preprocessor(input_signal=audio, length=audio_length)\n", + " return spectrogram, spec_len\n", + "\n", + "def gen_spectrogram(text, spec_gen_model, reference_spec, reference_spec_lens):\n", + " parsed = spec_gen_model.parse(text)\n", + " with torch.no_grad(): \n", + " spectrogram = spec_gen_model.generate_spectrogram(tokens=parsed, \n", + " reference_spec=reference_spec, \n", + " reference_spec_lens=reference_spec_lens)\n", + "\n", + " return spectrogram\n", + " \n", + "def synth_audio(vocoder_model, spectrogram): \n", + " with torch.no_grad(): \n", + " audio = vocoder_model.convert_spectrogram_to_audio(spec=spectrogram)\n", + " if isinstance(audio, torch.Tensor):\n", + " audio = audio.to('cpu').numpy()\n", + " return audio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "766154e3", + "metadata": {}, + "outputs": [], + "source": [ + "# Reference Audio\n", + "with open(train_manifest, \"r\") as f:\n", + " for i, line in enumerate(f):\n", + " reference_record = json.loads(line)\n", + " break\n", + " \n", + "# Validatation Audio\n", + "num_val = 3\n", + "val_records = []\n", + "with open(valid_manifest, \"r\") as f:\n", + " for i, line in enumerate(f):\n", + " val_records.append(json.loads(line))\n", + " if len(val_records) >= num_val:\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dfa71ca6", + "metadata": {}, + "outputs": [], + "source": [ + "for i, val_record in enumerate(val_records):\n", + " reference_spec, reference_spec_lens = gt_spectrogram(reference_record['audio_filepath'], wave_model, spec_model)\n", + " reference_spec = reference_spec.to(spec_model.device)\n", + " spec_pred = gen_spectrogram(val_record['text'], spec_model,\n", + " reference_spec=reference_spec, \n", + " reference_spec_lens=reference_spec_lens)\n", + "\n", + " audio_gen = synth_audio(vocoder_model, spec_pred)\n", + " \n", + " audio_ref = ipd.Audio(reference_record['audio_filepath'], rate=sample_rate)\n", + " audio_gt = ipd.Audio(val_record['audio_filepath'], rate=sample_rate)\n", + " audio_gen = ipd.Audio(audio_gen, rate=sample_rate)\n", + " \n", + " print(\"------\")\n", + " print(f\"Text: {val_record['text']}\")\n", + " print('Reference Audio')\n", + " ipd.display(audio_ref)\n", + " print('Ground Truth Audio')\n", + " ipd.display(audio_gt)\n", + " print('Synthesized Audio')\n", + " ipd.display(audio_gen)\n", + " plt.imshow(spec_pred[0].to('cpu').numpy(), origin=\"lower\", aspect=\"auto\")\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51d9d176", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Pretraind FastPitch: {pretrained_fastpitch_checkpoint}\")\n", + "print(f\"Finetuned Adapter: {finetuned_adapter_checkpoint}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6180a7d2", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Finetuned FastPitch: {finetuned_fastpitch_checkpoint}\")\n", + 
"print(f\"Finetuned HiFi-Gan: {finetuned_hifigan_on_adaptation_checkpoint}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b33263b", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/tts/FastPitch_MultiSpeaker_Pretraining.ipynb b/tutorials/tts/FastPitch_MultiSpeaker_Pretraining.ipynb new file mode 100644 index 000000000000..defd0272d89d --- /dev/null +++ b/tutorials/tts/FastPitch_MultiSpeaker_Pretraining.ipynb @@ -0,0 +1,735 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "afd8cdc9", + "metadata": {}, + "source": [ + "# FastPitch MultiSpeaker Pretraining\n", + "\n", + "This notebook is designed to provide a guide on how to run FastPitch MultiSpeaker Pretraining Pipeline. It contains the following sections:\n", + "1. **Pre-train FastPitch on multi-speaker data**: pre-train a multi-speaker FastPitch\n", + "* Dataset Preparation: download dataset and extract manifest files.\n", + "* Preprocessing: add absolute audio paths in manifest, calibrate speaker id to start from 0, and extract Supplementary Data.\n", + "* Training: pre-train multispeaker FastPitch\n", + "2. **Fine-tune HiFiGAN on multi-speaker data**: fine-tune a vocoder for the pre-trained multi-speaker FastPitch\n", + "* Dataset Preparation: extract mel-spectrograms from pre-trained FastPitch.\n", + "* Training: fine-tune HiFiGAN with pre-trained multi-speaker data.\n", + "3. **Inference**: generate speech from pre-trained multi-speaker FastPitch\n", + "* Load Model: load pre-trained multi-speaker FastPitch.\n", + "* Output Audio: generate audio files." + ] + }, + { + "cell_type": "markdown", + "id": "4fc9c6b9", + "metadata": {}, + "source": [ + "# License\n", + "> Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n", + "> \n", + "> Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "> you may not use this file except in compliance with the License.\n", + "> You may obtain a copy of the License at\n", + "> \n", + "> http://www.apache.org/licenses/LICENSE-2.0\n", + "> \n", + "> Unless required by applicable law or agreed to in writing, software\n", + "> distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "> WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "> See the License for the specific language governing permissions and\n", + "> limitations under the License." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b81f6c14", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "You can either run this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", + "Instructions for setting up Colab are as follows:\n", + "1. Open a new Python 3 notebook.\n", + "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", + "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", + "4. 
Run this cell to set up dependencies.\n", + "\"\"\"\n", + "# BRANCH = 'main'\n", + "# # If you're using Colab and not running locally, uncomment and run this cell.\n", + "# !apt-get install sox libsndfile1 ffmpeg\n", + "# !pip install wget unidecode pynini==2.1.4 scipy==1.7.3\n", + "# !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n", + "\n", + "# # Download local version of NeMo scripts. If you are running locally and want to use your own local NeMo code,\n", + "# # comment out the below lines and set `code_dir` to your local path.\n", + "code_dir = 'NeMoTTS' \n", + "!git clone https://github.com/NVIDIA/NeMo.git {code_dir}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f2f1e3ac", + "metadata": {}, + "outputs": [], + "source": [ + "!wandb login #PASTE_WANDB_APIKEY_HERE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1acd141d", + "metadata": {}, + "outputs": [], + "source": [ + "sample_rate = 44100\n", + "# Store all manifest and audios\n", + "data_dir = 'NeMoTTS_dataset'\n", + "# Store all supplementary files\n", + "supp_dir = \"NeMoTTS_sup_data\"\n", + "# Store all training logs\n", + "logs_dir = \"NeMoTTS_logs\"\n", + "# Store all mel-spectrograms for vocoder training\n", + "mels_dir = \"NeMoTTS_mels\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b54c45e", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "import nemo\n", + "import torch\n", + "import numpy as np\n", + "\n", + "from pathlib import Path\n", + "from tqdm import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a119994b", + "metadata": {}, + "outputs": [], + "source": [ + "os.makedirs(code_dir, exist_ok=True)\n", + "code_dir = os.path.abspath(code_dir)\n", + "os.makedirs(data_dir, exist_ok=True)\n", + "data_dir = os.path.abspath(data_dir)\n", + "os.makedirs(supp_dir, exist_ok=True)\n", + "supp_dir = os.path.abspath(supp_dir)\n", + "os.makedirs(logs_dir, exist_ok=True)\n", + "logs_dir = os.path.abspath(logs_dir)\n", + "os.makedirs(mels_dir, exist_ok=True)\n", + "mels_dir = os.path.abspath(mels_dir)" + ] + }, + { + "cell_type": "markdown", + "id": "dbb3ac0e", + "metadata": {}, + "source": [ + "# 1. Pre-train FastPitch on multi-speaker data" + ] + }, + { + "cell_type": "markdown", + "id": "095a1fca", + "metadata": {}, + "source": [ + "## a. Dataset Preparation\n", + "For our tutorial, we use a subset of the VCTK dataset with 5 speakers (p225-p229). The audio has a 48 kHz sampling rate, which we downsample to 44.1 kHz in this tutorial.\n", + "You can read more about the dataset [here](https://datashare.ed.ac.uk/handle/10283/2950)."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "69b17b07", + "metadata": {}, + "outputs": [], + "source": [ + "!cd {data_dir} && wget https://vctk-subset.s3.amazonaws.com/vctk_subset_multispeaker.tar.gz && tar zxf vctk_subset_multispeaker.tar.gz" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a65e7938", + "metadata": {}, + "outputs": [], + "source": [ + "manidir = f\"{data_dir}/vctk_subset_multispeaker\"\n", + "!ls {manidir}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08b27b92", + "metadata": {}, + "outputs": [], + "source": [ + "train_manifest = os.path.abspath(os.path.join(manidir, 'train.json'))\n", + "valid_manifest = os.path.abspath(os.path.join(manidir, 'dev.json'))" + ] + }, + { + "cell_type": "markdown", + "id": "7cbf24d6", + "metadata": {}, + "source": [ + "## b. Preprocessing" + ] + }, + { + "cell_type": "markdown", + "id": "cae8567d", + "metadata": {}, + "source": [ + "### Add absolute audio path in manifest\n", + "We use an absolute path for `audio_filepath` to get the audio during training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "71d2fe63", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc51398c", + "metadata": {}, + "outputs": [], + "source": [ + "train_data = read_manifest(train_manifest)\n", + "for m in train_data: m['audio_filepath'] = os.path.abspath(os.path.join(manidir, m['audio_filepath']))\n", + "write_manifest(train_manifest, train_data)\n", + "\n", + "valid_data = read_manifest(valid_manifest)\n", + "for m in valid_data: m['audio_filepath'] = os.path.abspath(os.path.join(manidir, m['audio_filepath']))\n", + "write_manifest(valid_manifest, valid_data)" + ] + }, + { + "cell_type": "markdown", + "id": "678bb37c", + "metadata": {}, + "source": [ + "### Calibrate speaker id to start from 0\n", + "We make speaker ids start from 0, so that we can create a speaker look-up table whose size equals the number of speakers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "594c6f2d", + "metadata": {}, + "outputs": [], + "source": [ + "train_data = read_manifest(train_manifest)\n", + "speaker2id = {s: _id for _id, s in enumerate(set([m['speaker'] for m in train_data]))}\n", + "for m in train_data: m['old_speaker'], m['speaker'] = m['speaker'], speaker2id[m['speaker']]\n", + "write_manifest(train_manifest, train_data)\n", + "\n", + "valid_data = read_manifest(valid_manifest)\n", + "for m in valid_data: m['old_speaker'], m['speaker'] = m['speaker'], speaker2id[m['speaker']]\n", + "write_manifest(valid_manifest, valid_data)" + ] + }, + { + "cell_type": "markdown", + "id": "15b6cc65", + "metadata": {}, + "source": [ + "### Extract Supplementary Data\n", + "\n", + "As mentioned in the [FastPitch and MixerTTS training tutorial](https://github.com/NVIDIA/NeMo/blob/main/tutorials/tts/FastPitch_MixerTTS_Training.ipynb) - To accelerate and stabilize our training, we also need to extract pitch for every audio, estimate pitch statistics (mean, std, min, and max). To do this, all we need to do is iterate over our data one time, via `extract_sup_data.py` script."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3728ac9", + "metadata": {}, + "outputs": [], + "source": [ + "!cd {code_dir} && python scripts/dataset_processing/tts/extract_sup_data.py \\\n", + " manifest_filepath={train_manifest} \\\n", + " sup_data_path={supp_dir} \\\n", + " dataset.sample_rate={sample_rate} \\\n", + " dataset.n_fft=2048 \\\n", + " dataset.win_length=2048 \\\n", + " dataset.hop_length=512" + ] + }, + { + "cell_type": "markdown", + "id": "effd9182", + "metadata": {}, + "source": [ + "After running the above command line, you will observe a new folder NeMoTTS_sup_data/pitch and printouts of pitch statistics like below. Specify these values in the FastPitch training configuration; we will use them in the training section below.\n", + "```bash\n", + "PITCH_MEAN=140.84278869628906, PITCH_STD=50.97673034667969\n", + "PITCH_MIN=65.4063949584961, PITCH_MAX=285.3046875\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37e54cd4", + "metadata": {}, + "outputs": [], + "source": [ + "!cd {code_dir} && python scripts/dataset_processing/tts/extract_sup_data.py \\\n", + " manifest_filepath={valid_manifest} \\\n", + " sup_data_path={supp_dir} \\\n", + " dataset.sample_rate={sample_rate} \\\n", + " dataset.n_fft=2048 \\\n", + " dataset.win_length=2048 \\\n", + " dataset.hop_length=512" + ] + }, + { + "cell_type": "markdown", + "id": "82d2c99d", + "metadata": {}, + "source": [ + "* If you want to compute the pitch mean and std for each speaker, you can use the script `compute_speaker_stats.py`\n", + "```bash\n", + "!cd {code_dir} && python scripts/dataset_processing/tts/compute_speaker_stats.py \\\n", + " --manifest_path={train_manifest} \\\n", + " --sup_data_path={supp_dir} \\\n", + " --pitch_stats_path={data_dir}/pitch_stats.json\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "a7c8dfb6", + "metadata": {}, + "source": [ + "## c. Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e378a792", + "metadata": {}, + "outputs": [], + "source": [ + "phoneme_dict_path = os.path.abspath(os.path.join(code_dir, \"scripts\", \"tts_dataset_files\", \"cmudict-0.7b_nv22.10\"))\n", + "heteronyms_path = os.path.abspath(os.path.join(code_dir, \"scripts\", \"tts_dataset_files\", \"heteronyms-052722\"))\n", + "\n", + "# Copy and Paste the PITCH_MEAN and PITCH_STD from previous steps (train_manifest) to override pitch_mean and pitch_std configs below.\n", + "PITCH_MEAN=140.84278869628906\n", + "PITCH_STD=50.97673034667969" + ] + }, + { + "cell_type": "markdown", + "id": "a90ddfb3", + "metadata": {}, + "source": [ + "### Important notes\n", + "* `sup_data_types=\"['align_prior_matrix', 'pitch', 'speaker_id', 'reference_audio']\" `\n", + " * **speaker_id**: each sample has a unique speaker index (starting from 0) in the input.\n", + " * **reference_audio**: each sample has a reference audio (from the same speaker) in the input.\n", + " \n", + "* `model.speaker_encoder.lookup_module.n_speakers`\n", + " * if using **model.speaker_encoder.lookup_module**, please provide n_speakers to create the lookup table\n", + "\n", + "* `condition_types=\"['add', 'concat', 'layernorm']\"`\n", + " * use different operation types to condition a module (e.g. 
input_fft/output_fft/duration_predictor/pitch_predictor/alignment_module)\n", + " * **add**: add conditions to module input\n", + " * **concat**: concat conditions to module input\n", + " * **layernorm**: scale and shift layernorm outputs based on conditions\n", + " \n", + "* Other default arguments in config:\n", + " * `model.speaker_encoder.lookup_module`: model creates lookup table to get speaker embedding from speaker id.\n", + " * `model.speaker_encoder.gst_module`: model creates global style token to extract speaker information from reference audio.\n", + "\n", + "* Other optional arguments based on your preference:\n", + " * batch_size\n", + " * max_duration\n", + " * min_duration\n", + " * exp_manager\n", + " * trainer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac22f3a8", + "metadata": {}, + "outputs": [], + "source": [ + "# Normally 200 epochs\n", + "!(cd {code_dir} && python examples/tts/fastpitch.py \\\n", + "--config-name=fastpitch_align_44100_adapter.yaml \\\n", + "+init_from_pretrained_model=\"tts_en_fastpitch\" \\\n", + "train_dataset={train_manifest} \\\n", + "validation_datasets={valid_manifest} \\\n", + "sup_data_types=\"['align_prior_matrix', 'pitch', 'speaker_id', 'reference_audio']\" \\\n", + "sup_data_path={supp_dir} \\\n", + "pitch_mean={PITCH_MEAN} \\\n", + "pitch_std={PITCH_STD} \\\n", + "phoneme_dict_path={phoneme_dict_path} \\\n", + "heteronyms_path={heteronyms_path} \\\n", + "model.speaker_encoder.lookup_module.n_speakers=5 \\\n", + "model.input_fft.condition_types=\"['add', 'layernorm']\" \\\n", + "model.output_fft.condition_types=\"['add', 'layernorm']\" \\\n", + "model.duration_predictor.condition_types=\"['add', 'layernorm']\" \\\n", + "model.pitch_predictor.condition_types=\"['add', 'layernorm']\" \\\n", + "model.alignment_module.condition_types=\"['add']\" \\\n", + "model.train_ds.dataloader_params.batch_size=8 \\\n", + "model.validation_ds.dataloader_params.batch_size=8 \\\n", + "model.train_ds.dataset.max_duration=20 \\\n", + "model.validation_ds.dataset.max_duration=20 \\\n", + "model.validation_ds.dataset.min_duration=0.1 \\\n", + "exp_manager.exp_dir={logs_dir} \\\n", + "+exp_manager.create_wandb_logger=True \\\n", + "+exp_manager.wandb_logger_kwargs.name=\"tutorial-FastPitch-pretrain-multispeaker\" \\\n", + "+exp_manager.wandb_logger_kwargs.project=\"NeMo\" \\\n", + "trainer.max_epochs=20 \\\n", + "trainer.check_val_every_n_epoch=20 \\\n", + "trainer.log_every_n_steps=1 \\\n", + "trainer.devices=-1 \\\n", + "trainer.strategy=ddp \\\n", + "trainer.precision=32 \\\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b6fc98a5", + "metadata": {}, + "outputs": [], + "source": [ + "# e.g. NeMoTTS_logs/FastPitch/Y-M-D_H-M-S/checkpoints/FastPitch.nemo\n", + "last_checkpoint_dir = sorted(list([i for i in (Path(logs_dir) / \"FastPitch\").iterdir() if i.is_dir()]))[-1] / \"checkpoints\"\n", + "pretrained_fastpitch_checkpoint = os.path.abspath(list(last_checkpoint_dir.glob('*.nemo'))[0])\n", + "print(pretrained_fastpitch_checkpoint)" + ] + }, + { + "cell_type": "markdown", + "id": "b175f755", + "metadata": {}, + "source": [ + "# 2. Fine-tune HiFiGAN on multi-speaker data" + ] + }, + { + "cell_type": "markdown", + "id": "5749a0b8", + "metadata": {}, + "source": [ + "## a. Dataset Preparation\n", + "Generate mel-spectrograms for HiFiGAN training." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d77bda9", + "metadata": {}, + "outputs": [], + "source": [ + "!cd {code_dir} \\\n", + "&& python scripts/dataset_processing/tts/resynthesize_dataset.py \\\n", + "--model-path={pretrained_fastpitch_checkpoint} \\\n", + "--input-json-manifest={train_manifest} \\\n", + "--input-sup-data-path={supp_dir} \\\n", + "--output-folder={mels_dir} \\\n", + "--device=\"cuda:0\" \\\n", + "--batch-size=1 \\\n", + "--num-workers=1 \\\n", + "&& python scripts/dataset_processing/tts/resynthesize_dataset.py \\\n", + "--model-path={pretrained_fastpitch_checkpoint} \\\n", + "--input-json-manifest={valid_manifest} \\\n", + "--input-sup-data-path={supp_dir} \\\n", + "--output-folder={mels_dir} \\\n", + "--device=\"cuda:0\" \\\n", + "--batch-size=1 \\\n", + "--num-workers=1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c9159a1", + "metadata": {}, + "outputs": [], + "source": [ + "train_manifest_mel = f\"{mels_dir}/train_mel.json\"\n", + "valid_manifest_mel = f\"{mels_dir}/dev_mel.json\"" + ] + }, + { + "cell_type": "markdown", + "id": "24653f24", + "metadata": {}, + "source": [ + "## b. Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fadc0410", + "metadata": {}, + "outputs": [], + "source": [ + "# Normally 100 epochs\n", + "!cd {code_dir} && python examples/tts/hifigan_finetune.py \\\n", + "--config-name=hifigan_44100.yaml \\\n", + "train_dataset={train_manifest_mel} \\\n", + "validation_datasets={valid_manifest_mel} \\\n", + "+init_from_pretrained_model=\"tts_en_hifitts_hifigan_ft_fastpitch\" \\\n", + "model.train_ds.dataloader_params.batch_size=32 \\\n", + "model.optim.lr=0.0001 \\\n", + "model/train_ds=train_ds_finetune \\\n", + "model/validation_ds=val_ds_finetune \\\n", + "+trainer.max_epochs=5 \\\n", + "trainer.check_val_every_n_epoch=5 \\\n", + "trainer.devices=1 \\\n", + "trainer.strategy='ddp' \\\n", + "trainer.precision=16 \\\n", + "exp_manager.exp_dir={logs_dir} \\\n", + "exp_manager.create_wandb_logger=True \\\n", + "exp_manager.wandb_logger_kwargs.name=\"tutorial-HiFiGAN-finetune-multispeaker\" \\\n", + "exp_manager.wandb_logger_kwargs.project=\"NeMo\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "864fe5ba", + "metadata": {}, + "outputs": [], + "source": [ + "# e.g. NeMoTTS_logs/HifiGan/Y-M-D_H-M-S/checkpoints/HifiGan.nemo\n", + "last_checkpoint_dir = sorted(list([i for i in (Path(logs_dir) / \"HifiGan\").iterdir() if i.is_dir()]))[-1] / \"checkpoints\"\n", + "finetuned_hifigan_on_multispeaker_checkpoint = os.path.abspath(list(last_checkpoint_dir.glob('*.nemo'))[0])\n", + "finetuned_hifigan_on_multispeaker_checkpoint" + ] + }, + { + "cell_type": "markdown", + "id": "e04540b6", + "metadata": {}, + "source": [ + "# 3. Inference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdf662f7", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer\n", + "from nemo.collections.tts.models import FastPitchModel\n", + "from nemo.collections.tts.models import HifiGanModel\n", + "from collections import defaultdict\n", + "import IPython.display as ipd\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "id": "270a3264", + "metadata": {}, + "source": [ + "## a. 
Load Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01315a66", + "metadata": {}, + "outputs": [], + "source": [ + "wave_model = WaveformFeaturizer(sample_rate=sample_rate)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "536c8fdc", + "metadata": {}, + "outputs": [], + "source": [ + "# FastPitch\n", + "spec_model = FastPitchModel.restore_from(pretrained_fastpitch_checkpoint).eval().cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2ace7c4", + "metadata": {}, + "outputs": [], + "source": [ + "# HiFiGAN\n", + "vocoder_model = HifiGanModel.restore_from(finetuned_hifigan_on_multispeaker_checkpoint).eval().cuda()" + ] + }, + { + "cell_type": "markdown", + "id": "cf4a42fa", + "metadata": {}, + "source": [ + "## b. Output Audio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b376468", + "metadata": {}, + "outputs": [], + "source": [ + "def gt_spectrogram(audio_path, wave_model, spec_gen_model):\n", + " features = wave_model.process(audio_path, trim=False)\n", + " audio, audio_length = features, torch.tensor(features.shape[0]).long()\n", + " audio = audio.unsqueeze(0).to(device=spec_gen_model.device)\n", + " audio_length = audio_length.unsqueeze(0).to(device=spec_gen_model.device)\n", + " with torch.no_grad():\n", + " spectrogram, spec_len = spec_gen_model.preprocessor(input_signal=audio, length=audio_length)\n", + " return spectrogram, spec_len\n", + "\n", + "def gen_spectrogram(text, spec_gen_model, speaker, reference_spec, reference_spec_lens):\n", + " parsed = spec_gen_model.parse(text)\n", + " speaker = torch.tensor([speaker]).long().to(device=spec_gen_model.device)\n", + " with torch.no_grad(): \n", + " spectrogram = spec_gen_model.generate_spectrogram(tokens=parsed, \n", + " speaker=speaker, \n", + " reference_spec=reference_spec, \n", + " reference_spec_lens=reference_spec_lens)\n", + "\n", + " return spectrogram\n", + " \n", + "def synth_audio(vocoder_model, spectrogram): \n", + " with torch.no_grad(): \n", + " audio = vocoder_model.convert_spectrogram_to_audio(spec=spectrogram)\n", + " if isinstance(audio, torch.Tensor):\n", + " audio = audio.to('cpu').numpy()\n", + " return audio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f93f73a6", + "metadata": {}, + "outputs": [], + "source": [ + "# Reference Audio\n", + "reference_records = []\n", + "with open(train_manifest, \"r\") as f:\n", + " for i, line in enumerate(f):\n", + " reference_records.append(json.loads(line))\n", + "\n", + "speaker_to_index = defaultdict(list)\n", + "for i, d in enumerate(reference_records): speaker_to_index[d.get('speaker', None)].append(i)\n", + " \n", + "# Validatation Audio\n", + "num_val = 3\n", + "val_records = []\n", + "with open(valid_manifest, \"r\") as f:\n", + " for i, line in enumerate(f):\n", + " val_records.append(json.loads(line))\n", + " if len(val_records) >= num_val:\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77590752", + "metadata": {}, + "outputs": [], + "source": [ + "for i, val_record in enumerate(val_records):\n", + " reference_record = reference_records[speaker_to_index[val_record['speaker']][0]]\n", + " reference_spec, reference_spec_lens = gt_spectrogram(reference_record['audio_filepath'], wave_model, spec_model)\n", + " reference_spec = reference_spec.to(spec_model.device)\n", + " spec_pred = gen_spectrogram(val_record['text'], \n", + " spec_model,\n", + " speaker=val_record['speaker'], \n", + " 
reference_spec=reference_spec, \n", + " reference_spec_lens=reference_spec_lens)\n", + "\n", + " audio_gen = synth_audio(vocoder_model, spec_pred)\n", + " \n", + " audio_ref = ipd.Audio(reference_record['audio_filepath'], rate=sample_rate)\n", + " audio_gt = ipd.Audio(val_record['audio_filepath'], rate=sample_rate)\n", + " audio_gen = ipd.Audio(audio_gen, rate=sample_rate)\n", + " \n", + " print(\"------\")\n", + " print(f\"Text: {val_record['text']}\")\n", + " print('Reference Audio')\n", + " ipd.display(audio_ref)\n", + " print('Ground Truth Audio')\n", + " ipd.display(audio_gt)\n", + " print('Synthesized Audio')\n", + " ipd.display(audio_gen)\n", + " plt.imshow(spec_pred[0].to('cpu').numpy(), origin=\"lower\", aspect=\"auto\")\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8cd156e4", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"FastPitch checkpoint: {pretrained_fastpitch_checkpoint}\")\n", + "print(f\"HiFi-Gan checkpoint: {finetuned_hifigan_on_multispeaker_checkpoint}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
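For reference, the inference path exercised at the end of both tutorials condenses to the sketch below (the checkpoint and reference-audio paths are placeholders; all calls mirror the notebook cells above, so this is a convenience summary rather than additional API surface):

```python
import torch
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
from nemo.collections.tts.models import FastPitchModel, HifiGanModel

sample_rate = 44100
spec_model = FastPitchModel.restore_from("FastPitch.nemo").eval().cuda()   # placeholder path
vocoder_model = HifiGanModel.restore_from("HifiGan.nemo").eval().cuda()    # placeholder path
wave_model = WaveformFeaturizer(sample_rate=sample_rate)

# Build a reference spectrogram from one utterance of the target speaker.
audio = wave_model.process("reference.wav", trim=False)                    # placeholder path
audio_len = torch.tensor(audio.shape[0]).long().unsqueeze(0).to(spec_model.device)
audio = audio.unsqueeze(0).to(spec_model.device)
reference_spec, reference_spec_lens = spec_model.preprocessor(input_signal=audio, length=audio_len)

with torch.no_grad():
    tokens = spec_model.parse("Hello from the new speaker.")
    spectrogram = spec_model.generate_spectrogram(
        tokens=tokens,
        reference_spec=reference_spec,
        reference_spec_lens=reference_spec_lens,
        # speaker=torch.tensor([0]).long().cuda(),  # add for the multi-speaker (lookup-table) model
    )
    audio_out = vocoder_model.convert_spectrogram_to_audio(spec=spectrogram)
```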