Merge pull request #2390 from coqui-ai/dev
v0.12.0
erogol authored Mar 13, 2023
2 parents d488b4f + c10f9a3 commit 9bb62c5
Showing 12 changed files with 120 additions and 53 deletions.
18 changes: 8 additions & 10 deletions TTS/encoder/models/resnet.py
@@ -161,16 +161,14 @@ def forward(self, x, l2_norm=False):
         Shapes:
             - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
         """
-        with torch.no_grad():
-            with torch.cuda.amp.autocast(enabled=False):
-                x.squeeze_(1)
-                # if you torch spec compute it otherwise use the mel spec computed by the AP
-                if self.use_torch_spec:
-                    x = self.torch_spec(x)
-
-                if self.log_input:
-                    x = (x + 1e-6).log()
-                x = self.instancenorm(x).unsqueeze(1)
+        x.squeeze_(1)
+        # if you torch spec compute it otherwise use the mel spec computed by the AP
+        if self.use_torch_spec:
+            x = self.torch_spec(x)
+
+        if self.log_input:
+            x = (x + 1e-6).log()
+        x = self.instancenorm(x).unsqueeze(1)

         x = self.conv1(x)
         x = self.relu(x)
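As a side note (not part of the commit), a tiny illustrative sketch of what dropping the `torch.no_grad()` / autocast wrappers changes mechanically: the spectrogram front-end becomes part of the autograd graph, so gradients can reach the raw waveform when the encoder is trained end to end. `torchaudio.transforms.Spectrogram` stands in for the encoder's `torch_spec` module; all values below are placeholders.

```python
# Illustrative only; Spectrogram stands in for self.torch_spec.
import torch
import torchaudio

wav = torch.randn(1, 16000, requires_grad=True)  # fake 1-second batch
spec_fn = torchaudio.transforms.Spectrogram(n_fft=512, hop_length=256)

feat = spec_fn(wav)            # differentiable without the no_grad() wrapper
feat.sum().backward()
print(wav.grad is not None)    # True; under torch.no_grad() backward() would fail
```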
53 changes: 49 additions & 4 deletions TTS/server/server.py
@@ -7,8 +7,9 @@
 from pathlib import Path
 from threading import Lock
 from typing import Union
+from urllib.parse import parse_qs

-from flask import Flask, render_template, request, send_file
+from flask import Flask, render_template, render_template_string, request, send_file

 from TTS.config import load_config
 from TTS.utils.manage import ModelManager
@@ -187,15 +188,59 @@ def tts():
         language_idx = request.args.get("language_id", "")
         style_wav = request.args.get("style_wav", "")
         style_wav = style_wav_uri_to_dict(style_wav)
-        print(" > Model input: {}".format(text))
-        print(" > Speaker Idx: {}".format(speaker_idx))
-        print(" > Language Idx: {}".format(language_idx))
+        print(f" > Model input: {text}")
+        print(f" > Speaker Idx: {speaker_idx}")
+        print(f" > Language Idx: {language_idx}")
         wavs = synthesizer.tts(text, speaker_name=speaker_idx, language_name=language_idx, style_wav=style_wav)
         out = io.BytesIO()
         synthesizer.save_wav(wavs, out)
         return send_file(out, mimetype="audio/wav")


+# Basic MaryTTS compatibility layer
+
+
+@app.route("/locales", methods=["GET"])
+def mary_tts_api_locales():
+    """MaryTTS-compatible /locales endpoint"""
+    # NOTE: We currently assume there is only one model active at the same time
+    if args.model_name is not None:
+        model_details = args.model_name.split("/")
+    else:
+        model_details = ["", "en", "", "default"]
+    return render_template_string("{{ locale }}\n", locale=model_details[1])
+
+
+@app.route("/voices", methods=["GET"])
+def mary_tts_api_voices():
+    """MaryTTS-compatible /voices endpoint"""
+    # NOTE: We currently assume there is only one model active at the same time
+    if args.model_name is not None:
+        model_details = args.model_name.split("/")
+    else:
+        model_details = ["", "en", "", "default"]
+    return render_template_string(
+        "{{ name }} {{ locale }} {{ gender }}\n", name=model_details[3], locale=model_details[1], gender="u"
+    )
+
+
+@app.route("/process", methods=["GET", "POST"])
+def mary_tts_api_process():
+    """MaryTTS-compatible /process endpoint"""
+    with lock:
+        if request.method == "POST":
+            data = parse_qs(request.get_data(as_text=True))
+            # NOTE: we ignore param. LOCALE and VOICE for now since we have only one active model
+            text = data.get("INPUT_TEXT", [""])[0]
+        else:
+            text = request.args.get("INPUT_TEXT", "")
+        print(f" > Model input: {text}")
+        wavs = synthesizer.tts(text)
+        out = io.BytesIO()
+        synthesizer.save_wav(wavs, out)
+        return send_file(out, mimetype="audio/wav")
+
+
 def main():
     app.run(debug=args.debug, host="::", port=args.port)

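As a usage note (not part of the diff), a minimal client sketch for the new MaryTTS-compatible endpoints. It assumes a server started with `python TTS/server/server.py --model_name <model>` listening on port 5002, the server's default; the endpoint paths come from the routes above, everything else is illustrative.

```python
# Minimal sketch of querying the MaryTTS-compatible endpoints added above.
# Assumes a running server at localhost:5002 (assumed default port).
import requests

BASE = "http://localhost:5002"

print(requests.get(f"{BASE}/locales").text)   # e.g. "en"
print(requests.get(f"{BASE}/voices").text)    # e.g. "<voice-name> en u"

# /process accepts GET or POST with form-encoded INPUT_TEXT and returns a WAV.
resp = requests.post(f"{BASE}/process", data={"INPUT_TEXT": "Hello from Coqui TTS."})
with open("out.wav", "wb") as f:
    f.write(resp.content)
```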
2 changes: 1 addition & 1 deletion TTS/tts/configs/fastspeech2_config.py
@@ -123,7 +123,7 @@ class Fastspeech2Config(BaseTTSConfig):
     base_model: str = "forward_tts"

     # model specific params
-    model_args: ForwardTTSArgs = ForwardTTSArgs()
+    model_args: ForwardTTSArgs = ForwardTTSArgs(use_pitch=True, use_energy=True)

     # multi-speaker settings
     num_speakers: int = 0
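For reference (not part of the diff), a minimal config sketch showing how the new pitch/energy defaults combine with the dataset-side fields used in the test configs further down this commit; the paths are placeholders.

```python
# Sketch: FastSpeech2 now defaults to pitch + energy prediction; the dataset
# side is enabled with compute_energy / energy_cache_path (see the tests below).
from TTS.tts.configs.fastspeech2_config import Fastspeech2Config

config = Fastspeech2Config(
    compute_f0=True,
    f0_cache_path="data/ljspeech/f0_cache/",
    compute_energy=True,
    energy_cache_path="data/ljspeech/energy_cache/",
)
print(config.model_args.use_pitch, config.model_args.use_energy)  # True True
```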
18 changes: 10 additions & 8 deletions TTS/tts/datasets/dataset.py
Expand Up @@ -189,6 +189,8 @@ def samples(self, new_samples):
self._samples = new_samples
if hasattr(self, "f0_dataset"):
self.f0_dataset.samples = new_samples
if hasattr(self, "energy_dataset"):
self.energy_dataset.samples = new_samples
if hasattr(self, "phoneme_dataset"):
self.phoneme_dataset.samples = new_samples

@@ -856,11 +858,11 @@ def __init__(

     def __getitem__(self, idx):
         item = self.samples[idx]
-        energy = self.compute_or_load(item["audio_file"])
+        energy = self.compute_or_load(item["audio_file"], string2filename(item["audio_unique_name"]))
         if self.normalize_energy:
             assert self.mean is not None and self.std is not None, " [!] Mean and STD is not available"
             energy = self.normalize(energy)
-        return {"audio_file": item["audio_file"], "energy": energy}
+        return {"audio_unique_name": item["audio_unique_name"], "energy": energy}

     def __len__(self):
         return len(self.samples)
@@ -884,7 +886,7 @@ def precompute(self, num_workers=0):

         if self.normalize_energy:
             computed_data = [tensor for batch in computed_data for tensor in batch]  # flatten
-            energy_mean, energy_std = self.compute_pitch_stats(computed_data)
+            energy_mean, energy_std = self.compute_energy_stats(computed_data)
             energy_stats = {"mean": energy_mean, "std": energy_std}
             np.save(os.path.join(self.cache_path, "energy_stats"), energy_stats, allow_pickle=True)

@@ -900,7 +902,7 @@ def create_energy_file_path(wav_file, cache_path):
     @staticmethod
     def _compute_and_save_energy(ap, wav_file, energy_file=None):
         wav = ap.load_wav(wav_file)
-        energy = calculate_energy(wav)
+        energy = calculate_energy(wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length)
         if energy_file:
             np.save(energy_file, energy)
         return energy
@@ -931,26 +933,26 @@ def denormalize(self, energy):
         energy[zero_idxs] = 0.0
         return energy

-    def compute_or_load(self, wav_file):
+    def compute_or_load(self, wav_file, audio_unique_name):
         """
         compute energy and return a numpy array of energy values
         """
-        energy_file = self.create_Energy_file_path(wav_file, self.cache_path)
+        energy_file = self.create_energy_file_path(audio_unique_name, self.cache_path)
         if not os.path.exists(energy_file):
             energy = self._compute_and_save_energy(self.ap, wav_file, energy_file)
         else:
             energy = np.load(energy_file)
         return energy.astype(np.float32)

     def collate_fn(self, batch):
-        audio_file = [item["audio_file"] for item in batch]
+        audio_unique_name = [item["audio_unique_name"] for item in batch]
         energys = [item["energy"] for item in batch]
         energy_lens = [len(item["energy"]) for item in batch]
         energy_lens_max = max(energy_lens)
         energys_torch = torch.LongTensor(len(energys), energy_lens_max).fill_(self.get_pad_id())
         for i, energy_len in enumerate(energy_lens):
             energys_torch[i, :energy_len] = torch.LongTensor(energys[i])
-        return {"audio_file": audio_file, "energy": energys_torch, "energy_lens": energy_lens}
+        return {"audio_unique_name": audio_unique_name, "energy": energys_torch, "energy_lens": energy_lens}

     def print_logs(self, level: int = 0) -> None:
         indent = "\t" * level
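As a side note (not from the commit), a small sketch of the kind of per-frame energy `calculate_energy` is assumed to compute now that it receives the STFT parameters explicitly; the function below is a stand-in for illustration, not the library implementation.

```python
# Stand-in sketch of frame-level energy from STFT magnitudes (assumed behaviour).
import numpy as np


def frame_energy(wav, fft_size=1024, hop_length=256, win_length=1024):
    """RMS-style energy per frame, windowed like the model's spectrograms."""
    window = np.hanning(win_length)
    frames = [
        wav[start : start + win_length] * window
        for start in range(0, len(wav) - win_length + 1, hop_length)
    ]
    spec = np.abs(np.fft.rfft(frames, n=fft_size, axis=1))  # (T, fft_size // 2 + 1)
    return np.sqrt((spec**2).sum(axis=1))                   # (T,)


energy = frame_energy(np.random.randn(22050).astype(np.float32))
print(energy.shape)
```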
16 changes: 10 additions & 6 deletions TTS/tts/models/base_tts.py
@@ -111,7 +111,7 @@ def get_aux_input(self, **kwargs) -> Dict:
         """Prepare and return `aux_input` used by `forward()`"""
         return {"speaker_id": None, "style_wav": None, "d_vector": None, "language_id": None}

-    def get_aux_input_from_test_setences(self, sentence_info):
+    def get_aux_input_from_test_sentences(self, sentence_info):
         if hasattr(self.config, "model_args"):
             config = self.config.model_args
         else:
@@ -134,7 +134,7 @@ def get_aux_input_from_test_setences(self, sentence_info):

         # get speaker id/d_vector
         speaker_id, d_vector, language_id = None, None, None
-        if hasattr(self, "speaker_manager"):
+        if self.speaker_manager is not None:
             if config.use_d_vector_file:
                 if speaker_name is None:
                     d_vector = self.speaker_manager.get_random_embedding()
@@ -147,7 +147,7 @@ def get_aux_input_from_test_setences(self, sentence_info):
                 speaker_id = self.speaker_manager.name_to_id[speaker_name]

         # get language id
-        if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
+        if self.language_manager is not None and config.use_language_embedding and language_name is not None:
             language_id = self.language_manager.name_to_id[language_name]

         return {
@@ -183,6 +183,7 @@ def format_batch(self, batch: Dict) -> Dict:
         attn_mask = batch["attns"]
         waveform = batch["waveform"]
         pitch = batch["pitch"]
+        energy = batch["energy"]
         language_ids = batch["language_ids"]
         max_text_length = torch.max(text_lengths.float())
         max_spec_length = torch.max(mel_lengths.float())
@@ -231,6 +232,7 @@ def format_batch(self, batch: Dict) -> Dict:
             "item_idx": item_idx,
             "waveform": waveform,
             "pitch": pitch,
+            "energy": energy,
             "language_ids": language_ids,
             "audio_unique_names": batch["audio_unique_names"],
         }
@@ -287,7 +289,7 @@ def get_data_loader(
             loader = None
         else:
             # setup multi-speaker attributes
-            if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
+            if self.speaker_manager is not None:
                 if hasattr(config, "model_args"):
                     speaker_id_mapping = (
                         self.speaker_manager.name_to_id if config.model_args.use_speaker_embedding else None
@@ -302,7 +304,7 @@ def get_data_loader(
                 d_vector_mapping = None

             # setup multi-lingual attributes
-            if hasattr(self, "language_manager") and self.language_manager is not None:
+            if self.language_manager is not None:
                 language_id_mapping = self.language_manager.name_to_id if self.args.use_language_embedding else None
             else:
                 language_id_mapping = None
@@ -313,6 +315,8 @@ def get_data_loader(
                 compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                 compute_f0=config.get("compute_f0", False),
                 f0_cache_path=config.get("f0_cache_path", None),
+                compute_energy=config.get("compute_energy", False),
+                energy_cache_path=config.get("energy_cache_path", None),
                 samples=samples,
                 ap=self.ap,
                 return_wav=config.return_wav if "return_wav" in config else False,
@@ -424,7 +428,7 @@ def on_init_start(self, trainer):
             print(f" > `speakers.pth` is saved to {output_path}.")
             print(" > `speakers_file` is updated in the config.json.")

-        if hasattr(self, "language_manager") and self.language_manager is not None:
+        if self.language_manager is not None:
             output_path = os.path.join(trainer.output_path, "language_ids.json")
             self.language_manager.save_ids_to_file(output_path)
             trainer.config.language_ids_file = output_path
19 changes: 0 additions & 19 deletions docs/source/_templates/page.html
@@ -1,23 +1,4 @@
 {% extends "!page.html" %}
 {% block scripts %}
 {{ super() }}
-<!-- DocsQA integration start -->
-<script src="https://cdn.jsdelivr.net/npm/qabot@0.4"></script>
-
-<qa-bot
-  token="qAFjWNovwHUXKKkVhy4AN6tawSwCMfdb3HJNPLVM23ACdrBGxmBNObM="
-  title="🐸💬TTS Bot"
-  description="A library for advanced Text-to-Speech generation"
-  style="bottom: calc(1.25em + 80px);"
->
-  <template>
-    <dl>
-      <dt>You can ask questions about TTS. Try</dt>
-      <dd>What is VITS?</dd>
-      <dd>How to train a TTS model?</dd>
-      <dd>What is the format of training data?</dd>
-    </dl>
-  </template>
-</qa-bot>
-<!-- DocsQA integration end -->
 {% endblock %}
3 changes: 2 additions & 1 deletion docs/source/index.md
@@ -28,6 +28,7 @@
    formatting_your_dataset
    what_makes_a_good_dataset
    tts_datasets
+   marytts

 .. toctree::
     :maxdepth: 2
@@ -48,10 +49,10 @@
     models/vits.md
     models/forward_tts.md
     models/tacotron1-2.md
+    models/overflow.md

 .. toctree::
     :maxdepth: 2
     :caption: `vocoder` Models
 ```

Empty file added docs/source/marytts.md
36 changes: 36 additions & 0 deletions docs/source/models/overflow.md
@@ -0,0 +1,36 @@
# Overflow TTS

Neural HMMs are a type of neural transducer recently proposed for
sequence-to-sequence modelling in text-to-speech. They combine the best features
of classic statistical speech synthesis and modern neural TTS, requiring less
data and fewer training updates, and are less prone to gibberish output caused
by neural attention failures. In this paper, we combine neural HMM TTS with
normalising flows for describing the highly non-Gaussian distribution of speech
acoustics. The result is a powerful, fully probabilistic model of durations and
acoustics that can be trained using exact maximum likelihood. Compared to
dominant flow-based acoustic models, our approach integrates autoregression for
improved modelling of long-range dependences such as utterance-level prosody.
Experiments show that a system based on our proposal gives more accurate
pronunciations and better subjective speech quality than comparable methods,
whilst retaining the original advantages of neural HMMs. Audio examples and code
are available at https://shivammehta25.github.io/OverFlow/.


## Important resources & papers
- HMM: https://de.wikipedia.org/wiki/Hidden_Markov_Model
- OverflowTTS paper: https://arxiv.org/abs/2211.06892
- Neural HMM: https://arxiv.org/abs/2108.13320
- Audio Samples: https://shivammehta25.github.io/OverFlow/


## OverflowConfig
```{eval-rst}
.. autoclass:: TTS.tts.configs.overflow_config.OverflowConfig
:members:
```

## Overflow Model
```{eval-rst}
.. autoclass:: TTS.tts.models.overflow.Overflow
:members:
```
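As a pointer for readers (not part of the added page), a minimal sketch of instantiating the documented config. Only the class path is taken from the page above; the field values are ordinary training-config fields and purely illustrative.

```python
# Minimal sketch: build an Overflow training config (values are placeholders).
from TTS.tts.configs.overflow_config import OverflowConfig

config = OverflowConfig(
    run_name="overflow_ljspeech",
    batch_size=32,
    eval_batch_size=16,
    num_loader_workers=4,
    run_eval=True,
    epochs=1000,
    print_step=25,
)
print(config.model)  # expected: "overflow"
```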
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,6 +1,6 @@
 # core deps
 numpy==1.21.6;python_version<"3.10"
-numpy==1.22.4;python_version=="3.10"
+numpy;python_version=="3.10"
 cython==0.29.28
 scipy>=1.4.0
 torch>=1.7
@@ -39,4 +39,4 @@ gruut[de]==2.2.3
 # deps for korean
 jamo
 nltk
-g2pkk>=0.1.1
+g2pkk>=0.1.1
2 changes: 1 addition & 1 deletion tests/tts_tests/test_fastspeech_2_speaker_emb_train.py
@@ -38,7 +38,7 @@
     f0_cache_path="tests/data/ljspeech/f0_cache/",
     compute_f0=True,
     compute_energy=True,
-    energy_cache_path="tests/data/ljspeech/f0_cache/",
+    energy_cache_path="tests/data/ljspeech/energy_cache/",
     run_eval=True,
     test_delay_epochs=-1,
     epochs=1,
2 changes: 1 addition & 1 deletion tests/tts_tests/test_fastspeech_2_train.py
@@ -38,7 +38,7 @@
     f0_cache_path="tests/data/ljspeech/f0_cache/",
     compute_f0=True,
     compute_energy=True,
-    energy_cache_path="tests/data/ljspeech/f0_cache/",
+    energy_cache_path="tests/data/ljspeech/energy_cache/",
     run_eval=True,
     test_delay_epochs=-1,
     epochs=1,
