Skip to content

Commit

Permalink
fix(ExtractTTSpectrogram.ipynb): adapt to current TTS code
Browse files Browse the repository at this point in the history
Fixes #2447, #2574
  • Loading branch information
eginhard committed Nov 15, 2023
1 parent aa0fbdf commit ddbaecd
Showing 1 changed file with 69 additions and 110 deletions.
179 changes: 69 additions & 110 deletions notebooks/ExtractTTSpectrogram.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,27 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import torch\n",
"import importlib\n",
"import os\n",
"import pickle\n",
"\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"from torch.utils.data import DataLoader\n",
"import soundfile as sf\n",
"import pickle\n",
"import torch\n",
"from matplotlib import pylab as plt\n",
"from torch.utils.data import DataLoader\n",
"from tqdm import tqdm\n",
"\n",
"from TTS.config import load_config\n",
"from TTS.tts.configs.shared_configs import BaseDatasetConfig\n",
"from TTS.tts.datasets import load_tts_samples\n",
"from TTS.tts.datasets.dataset import TTSDataset\n",
"from TTS.tts.layers.losses import L1LossMasked\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.config import load_config\n",
"from TTS.tts.utils.visual import plot_spectrogram\n",
"from TTS.tts.utils.helpers import sequence_mask\n",
"from TTS.tts.models import setup_model\n",
"from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes\n",
"from TTS.tts.utils.helpers import sequence_mask\n",
"from TTS.tts.utils.text.tokenizer import TTSTokenizer\n",
"from TTS.tts.utils.visual import plot_spectrogram\n",
"from TTS.utils.audio import AudioProcessor\n",
"\n",
"%matplotlib inline\n",
"\n",
Expand All @@ -49,11 +53,9 @@
" file_name = wav_file.split('.')[0]\n",
" os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n",
" os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n",
" os.makedirs(os.path.join(out_path, \"wav_gl\"), exist_ok=True)\n",
" wavq_path = os.path.join(out_path, \"quant\", file_name)\n",
" mel_path = os.path.join(out_path, \"mel\", file_name)\n",
" wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n",
" return file_name, wavq_path, mel_path, wav_path"
" return file_name, wavq_path, mel_path"
]
},
{
Expand All @@ -65,25 +67,25 @@
"# Paths and configurations\n",
"OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
"DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
"PHONEME_CACHE_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/phoneme_cache\"\n",
"DATASET = \"ljspeech\"\n",
"METADATA_FILE = \"metadata.csv\"\n",
"CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n",
"MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth\"\n",
"BATCH_SIZE = 32\n",
"\n",
"QUANTIZED_WAV = False\n",
"QUANTIZE_BIT = None\n",
"QUANTIZE_BITS = 0 # if non-zero, quantize wav files with the given number of bits\n",
"DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n",
"\n",
"# Check CUDA availability\n",
"use_cuda = torch.cuda.is_available()\n",
"print(\" > CUDA enabled: \", use_cuda)\n",
"\n",
"# Load the configuration\n",
"dataset_config = BaseDatasetConfig(formatter=DATASET, meta_file_train=METADATA_FILE, path=DATA_PATH)\n",
"C = load_config(CONFIG_PATH)\n",
"C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n",
"print(C['r'])"
"ap = AudioProcessor(**C.audio)"
]
},
{
Expand All @@ -92,12 +94,10 @@
"metadata": {},
"outputs": [],
"source": [
"# If the vocabulary was passed, replace the default\n",
"if 'characters' in C and C['characters']:\n",
" symbols, phonemes = make_symbols(**C.characters)\n",
"# Initialize the tokenizer\n",
"tokenizer, C = TTSTokenizer.init_from_config(C)\n",
"\n",
"# Load the model\n",
"num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
"# TODO: multiple speakers\n",
"model = setup_model(C)\n",
"model.load_checkpoint(C, MODEL_FILE, eval=True)"
Expand All @@ -109,42 +109,21 @@
"metadata": {},
"outputs": [],
"source": [
"# Load the preprocessor based on the dataset\n",
"preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
"meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
"# Load data instances\n",
"meta_data_train, meta_data_eval = load_tts_samples(dataset_config)\n",
"meta_data = meta_data_train + meta_data_eval\n",
"\n",
"dataset = TTSDataset(\n",
" C,\n",
" C.text_cleaner,\n",
" False,\n",
" ap,\n",
" meta_data,\n",
" characters=C.get('characters', None),\n",
" use_phonemes=C.use_phonemes,\n",
" phoneme_cache_path=C.phoneme_cache_path,\n",
" enable_eos_bos=C.enable_eos_bos_chars,\n",
" outputs_per_step=C[\"r\"],\n",
" compute_linear_spec=False,\n",
" ap=ap,\n",
" samples=meta_data,\n",
" tokenizer=tokenizer,\n",
" phoneme_cache_path=PHONEME_CACHE_PATH,\n",
")\n",
"loader = DataLoader(\n",
" dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize lists for storing results\n",
"file_idxs = []\n",
"metadata = []\n",
"losses = []\n",
"postnet_losses = []\n",
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
"\n",
"# Create log file\n",
"log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
"log_file = open(log_file_path, \"w\")"
")"
]
},
{
Expand All @@ -160,26 +139,33 @@
"metadata": {},
"outputs": [],
"source": [
"# Initialize lists for storing results\n",
"file_idxs = []\n",
"metadata = []\n",
"losses = []\n",
"postnet_losses = []\n",
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
"\n",
"# Start processing with a progress bar\n",
"with torch.no_grad():\n",
"log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
"with torch.no_grad() and open(log_file_path, \"w\") as log_file:\n",
" for data in tqdm(loader, desc=\"Processing\"):\n",
" try:\n",
" # setup input data\n",
" text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n",
"\n",
" # dispatch data to GPU\n",
" if use_cuda:\n",
" text_input = text_input.cuda()\n",
" text_lengths = text_lengths.cuda()\n",
" mel_input = mel_input.cuda()\n",
" mel_lengths = mel_lengths.cuda()\n",
" data[\"token_id\"] = data[\"token_id\"].cuda()\n",
" data[\"token_id_lengths\"] = data[\"token_id_lengths\"].cuda()\n",
" data[\"mel\"] = data[\"mel\"].cuda()\n",
" data[\"mel_lengths\"] = data[\"mel_lengths\"].cuda()\n",
"\n",
" mask = sequence_mask(text_lengths)\n",
" mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
" mask = sequence_mask(data[\"token_id_lengths\"])\n",
" outputs = model.forward(data[\"token_id\"], data[\"token_id_lengths\"], data[\"mel\"])\n",
" mel_outputs = outputs[\"decoder_outputs\"]\n",
" postnet_outputs = outputs[\"model_outputs\"]\n",
"\n",
" # compute loss\n",
" loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
" loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
" loss = criterion(mel_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
" loss_postnet = criterion(postnet_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
" losses.append(loss.item())\n",
" postnet_losses.append(loss_postnet.item())\n",
"\n",
Expand All @@ -193,28 +179,27 @@
" postnet_outputs = torch.stack(mel_specs)\n",
" elif C.model == \"Tacotron2\":\n",
" postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
" alignments = alignments.detach().cpu().numpy()\n",
" alignments = outputs[\"alignments\"].detach().cpu().numpy()\n",
"\n",
" if not DRY_RUN:\n",
" for idx in range(text_input.shape[0]):\n",
" wav_file_path = item_idx[idx]\n",
" for idx in range(data[\"token_id\"].shape[0]):\n",
" wav_file_path = data[\"item_idxs\"][idx]\n",
" wav = ap.load_wav(wav_file_path)\n",
" file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
" file_name, wavq_path, mel_path = set_filename(wav_file_path, OUT_PATH)\n",
" file_idxs.append(file_name)\n",
"\n",
" # quantize and save wav\n",
" if QUANTIZED_WAV:\n",
" wavq = ap.quantize(wav)\n",
" if QUANTIZE_BITS > 0:\n",
" wavq = ap.quantize(wav, QUANTIZE_BITS)\n",
" np.save(wavq_path, wavq)\n",
"\n",
" # save TTS mel\n",
" mel = postnet_outputs[idx]\n",
" mel_length = mel_lengths[idx]\n",
" mel_length = data[\"mel_lengths\"][idx]\n",
" mel = mel[:mel_length, :].T\n",
" np.save(mel_path, mel)\n",
"\n",
" metadata.append([wav_file_path, mel_path])\n",
"\n",
" except Exception as e:\n",
" log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
"\n",
Expand All @@ -224,35 +209,20 @@
" log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
" log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
"\n",
"# Close the log file\n",
"log_file.close()\n",
"\n",
"# For wavernn\n",
"if not DRY_RUN:\n",
" pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
"\n",
"# For pwgan\n",
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
" for data in metadata:\n",
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
" for wav_file_path, mel_path in metadata:\n",
" f.write(f\"{wav_file_path[0]}|{mel_path[1]+'.npy'}\\n\")\n",
"\n",
"# Print mean losses\n",
"print(f\"Mean Loss: {mean_loss}\")\n",
"print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# for pwgan\n",
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
" for data in metadata:\n",
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -267,7 +237,7 @@
"outputs": [],
"source": [
"idx = 1\n",
"ap.melspectrogram(ap.load_wav(item_idx[idx])).shape"
"ap.melspectrogram(ap.load_wav(data[\"item_idxs\"][idx])).shape"
]
},
{
Expand All @@ -276,10 +246,9 @@
"metadata": {},
"outputs": [],
"source": [
"import soundfile as sf\n",
"wav, sr = sf.read(item_idx[idx])\n",
"mel_postnet = postnet_outputs[idx][:mel_lengths[idx], :]\n",
"mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n",
"wav, sr = sf.read(data[\"item_idxs\"][idx])\n",
"mel_postnet = postnet_outputs[idx][:data[\"mel_lengths\"][idx], :]\n",
"mel_decoder = mel_outputs[idx][:data[\"mel_lengths\"][idx], :].detach().cpu().numpy()\n",
"mel_truth = ap.melspectrogram(wav)\n",
"print(mel_truth.shape)"
]
Expand All @@ -291,7 +260,7 @@
"outputs": [],
"source": [
"# plot posnet output\n",
"print(mel_postnet[:mel_lengths[idx], :].shape)\n",
"print(mel_postnet[:data[\"mel_lengths\"][idx], :].shape)\n",
"plot_spectrogram(mel_postnet, ap)"
]
},
Expand Down Expand Up @@ -324,10 +293,9 @@
"outputs": [],
"source": [
"# postnet, decoder diff\n",
"from matplotlib import pylab as plt\n",
"mel_diff = mel_decoder - mel_postnet\n",
"plt.figure(figsize=(16, 10))\n",
"plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.imshow(abs(mel_diff[:data[\"mel_lengths\"][idx],:]).T,aspect=\"auto\", origin=\"lower\")\n",
"plt.colorbar()\n",
"plt.tight_layout()"
]
Expand All @@ -339,10 +307,9 @@
"outputs": [],
"source": [
"# PLOT GT SPECTROGRAM diff\n",
"from matplotlib import pylab as plt\n",
"mel_diff2 = mel_truth.T - mel_decoder\n",
"plt.figure(figsize=(16, 10))\n",
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n",
"plt.colorbar()\n",
"plt.tight_layout()"
]
Expand All @@ -354,21 +321,13 @@
"outputs": [],
"source": [
"# PLOT GT SPECTROGRAM diff\n",
"from matplotlib import pylab as plt\n",
"mel = postnet_outputs[idx]\n",
"mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n",
"plt.figure(figsize=(16, 10))\n",
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n",
"plt.colorbar()\n",
"plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down

0 comments on commit ddbaecd

Please sign in to comment.