Skip to content

Commit

Permalink
Add Tacotron2 + Waveglow TTS demo
Browse files Browse the repository at this point in the history
Signed-off-by: Rajeev Rao <rajeevrao@nvidia.com>
  • Loading branch information
rajeevsrao committed Jul 9, 2020
1 parent 7cb2863 commit 776651f
Show file tree
Hide file tree
Showing 53 changed files with 5,525 additions and 0 deletions.
110 changes: 110 additions & 0 deletions demo/Tacotron2/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Tacotron 2 and WaveGlow Inference with TensorRT

The Tacotron2 and WaveGlow models form a text-to-speech (TTS) system that enables users to synthesize natural sounding speech from raw transcripts without any additional information such as patterns and/or rhythms of speech. This is an implementation of Tacotron2 for PyTorch, tested and maintained by NVIDIA, and provides scripts to perform high-performance inference using NVIDIA TensorRT. More information about the TTS system and its training can be found in the
[NVIDIA DeepLearningExamples](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/Tacotron2).

NVIDIA TensorRT is a platform for high-performance deep learning inference. It includes a deep learning inference optimizer and runtime that delivers low latency and high-throughput for deep learning inference applications. After optimizing the compute-intensive acoustic model with NVIDIA TensorRT, inference throughput increased by up to 1.4x over native PyTorch in mixed precision.

### Software Versions

Software version configuration tested for the instructions that follow:

|Software|Version|
|--------|-------|
|Python|3.6.9|
|CUDA|11.0.171|
|Apex|0.1|
|TensorRT|7.1.3.4|
|PyTorch|1.5.1|


## Quick Start Guide

1. Build and launch the container as described in [TensorRT OSS README](https://github.com/NVIDIA/TensorRT/blob/master/README.md).

**Note:** After this point, all commands should be run from within the container.

2. Install prerequisite software for TTS sample:
```bash
cd $TRT_SOURCE/demo/Tacotron2

export LD_LIBRARY_PATH=$TRT_SOURCE/build/out:/tensorrt/lib:$LD_LIBRARY_PATH
pip3 install /tensorrt/python/tensorrt-7.1*-cp36-none-linux_x86_64.whl

bash ./scripts/install_prerequisites.sh
```
3. Download pretrained checkpoints from [NGC](https://ngc.nvidia.com/catalog/models) into the `./checkpoints` directory:

- [Tacotron2 checkpoint](https://ngc.nvidia.com/models/nvidia:tacotron2pyt_fp16)
- [WaveGlow checkpoint](https://ngc.nvidia.com/models/nvidia:waveglow256pyt_fp16)

```bash
cd $TRT_SOURCE/demo/Tacotron2
./scripts/download_checkpoints.sh
```

4. Export the models to ONNX intermediate representation (ONNX IR).
Export Tacotron 2 to three ONNX parts: Encoder, Decoder, and Postnet:

```bash
mkdir -p output
python exports/export_tacotron2_onnx.py --tacotron2 ./checkpoints/tacotron2pyt_fp16_v3/tacotron2_1032590_6000_amp -o output/ --fp16
```

Export WaveGlow to ONNX IR:

```bash
python exports/export_waveglow_onnx.py --waveglow checkpoints/waveglow256pyt_fp16_v2/waveglow_1076430_14000_amp --wn-channels 256 -o output/ --fp16
```

After running the above commands, there should be four new ONNX files in `./output/` directory:
`encoder.onnx`, `decoder_iter.onnx`, `postnet.onnx`, and `waveglow.onnx`.

5. Export the ONNX IRs to TensorRT engines with fp16 mode enabled:

```bash
python trt/export_onnx2trt.py --encoder output/encoder.onnx --decoder output/decoder_iter.onnx --postnet output/postnet.onnx --waveglow output/waveglow.onnx -o output/ --fp16
```

After running the command, there should be four new engine files in `./output/` directory:
`encoder_fp16.engine`, `decoder_iter_fp16.engine`, `postnet_fp16.engine`, and `waveglow_fp16.engine`.

6. Run TTS inference pipeline with fp16:

```bash
python trt/inference_trt.py -i phrases/phrase.txt --encoder output/encoder_fp16.engine --decoder output/decoder_iter_fp16.engine --postnet output/postnet_fp16.engine --waveglow output/waveglow_fp16.engine -o output/ --fp16
```

## Performance

### Benchmarking

The following section shows how to benchmark the TensorRT inference performance for our Tacotron2 + Waveglow TTS.

#### TensorRT inference benchmark

Before running the benchmark script, please download the checkpoints and build the TensorRT engines for the Tacotron2 and Waveglow models as prescribed in the [Quick Start Guide](#quick-start-guide) above.

The inference benchmark is performed on a single GPU by the `inference_benchmark.sh` script, which runs 3 warm-up iterations then runs timed inference for 1000 iterations.

```bash
bash scripts/inference_benchmark.sh
```

*Note*: For benchmarking we use WaveGlow with 256 residual channels.

### Results

#### Inference performance: NVIDIA T4 (16GB)

|Framework|Batch size|Input length|Precision|Avg latency (s)|Latency std (s)|Latency confidence interval 90% (s)|Latency confidence interval 95% (s)|Latency confidence interval 99% (s)|Throughput (samples/sec)|Speed-up PyT+TRT/TRT|Avg mels generated (81 mels=1 sec of speech)| Avg audio length (s)| Avg RTF|
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
|PyT+TRT|1| 128| FP16| 0.93| 0.15| 1.09| 1.13| 1.49| 169,104| 1.78| 602| 7.35| 7.9|
|PyT |1| 128| FP16| 1.58| 0.07| 1.65| 1.70| 1.76| 97,991| 1.00| 605| 6.94| 4.4|

#### Inference performance: NVIDIA V100 (16GB)

|Framework|Batch size|Input length|Precision|Avg latency (s)|Latency std (s)|Latency confidence interval 90% (s)|Latency confidence interval 95% (s)|Latency confidence interval 99% (s)|Throughput (samples/sec)|Speed-up PyT+TRT/TRT|Avg mels generated (81 mels=1 sec of speech)| Avg audio length (s)| Avg RTF|
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
|PyT+TRT|1| 128| FP16| 0.63| 0.02| 0.65| 0.66| 0.67| 242,466| 1.78| 599| 7.09| 10.9|
|PyT |1| 128| FP16| 1.13| 0.03| 1.17| 1.17| 1.21| 136,160| 1.00| 602| 7.10| 6.3|
107 changes: 107 additions & 0 deletions demo/Tacotron2/common/audio_processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np
from scipy.signal import get_window
import librosa.util as librosa_util


def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
n_fft=800, dtype=np.float32, norm=None):
"""
# from librosa 0.6
Compute the sum-square envelope of a window function at a given hop length.
This is used to estimate modulation effects induced by windowing
observations in short-time fourier transforms.
Parameters
----------
window : string, tuple, number, callable, or list-like
Window specification, as in `get_window`
n_frames : int > 0
The number of analysis frames
hop_length : int > 0
The number of samples to advance between frames
win_length : [optional]
The length of the window function. By default, this matches `n_fft`.
n_fft : int > 0
The length of each analysis frame.
dtype : np.dtype
The data type of the output
Returns
-------
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
The sum-squared envelope of the window function
"""
if win_length is None:
win_length = n_fft

n = n_fft + hop_length * (n_frames - 1)
x = np.zeros(n, dtype=dtype)

# Compute the squared window at the desired length
win_sq = get_window(window, win_length, fftbins=True)
win_sq = librosa_util.normalize(win_sq, norm=norm)**2
win_sq = librosa_util.pad_center(win_sq, n_fft)

# Fill the envelope
for i in range(n_frames):
sample = i * hop_length
x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
return x


def griffin_lim(magnitudes, stft_fn, n_iters=30):
"""
PARAMS
------
magnitudes: spectrogram magnitudes
stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
"""

angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
angles = angles.astype(np.float32)
angles = torch.autograd.Variable(torch.from_numpy(angles))
signal = stft_fn.inverse(magnitudes, angles).squeeze(1)

for i in range(n_iters):
_, angles = stft_fn.transform(signal)
signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
return signal


def dynamic_range_compression(x, C=1, clip_val=1e-5):
"""
PARAMS
------
C: compression factor
"""
return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression(x, C=1):
"""
PARAMS
------
C: compression factor used to compress
"""
return torch.exp(x) / C
93 changes: 93 additions & 0 deletions demo/Tacotron2/common/layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from librosa.filters import mel as librosa_mel_fn
from common.audio_processing import dynamic_range_compression, dynamic_range_decompression
from common.stft import STFT


class LinearNorm(torch.nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
super(LinearNorm, self).__init__()
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

torch.nn.init.xavier_uniform_(
self.linear_layer.weight,
gain=torch.nn.init.calculate_gain(w_init_gain))

def forward(self, x):
return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
padding=None, dilation=1, bias=True, w_init_gain='linear'):
super(ConvNorm, self).__init__()
if padding is None:
assert(kernel_size % 2 == 1)
padding = int(dilation * (kernel_size - 1) / 2)

self.conv = torch.nn.Conv1d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation,
bias=bias)

torch.nn.init.xavier_uniform_(
self.conv.weight,
gain=torch.nn.init.calculate_gain(w_init_gain))

def forward(self, signal):
return self.conv(signal)


class TacotronSTFT(torch.nn.Module):
def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
mel_fmax=8000.0):
super(TacotronSTFT, self).__init__()
self.n_mel_channels = n_mel_channels
self.sampling_rate = sampling_rate
self.stft_fn = STFT(filter_length, hop_length, win_length)
mel_basis = librosa_mel_fn(
sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
mel_basis = torch.from_numpy(mel_basis).float()
self.register_buffer('mel_basis', mel_basis)

def spectral_normalize(self, magnitudes):
output = dynamic_range_compression(magnitudes)
return output

def spectral_de_normalize(self, magnitudes):
output = dynamic_range_decompression(magnitudes)
return output

def mel_spectrogram(self, y):
"""Computes mel-spectrograms from a batch of waves
PARAMS
------
y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
RETURNS
-------
mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
"""
assert(torch.min(y.data) >= -1)
assert(torch.max(y.data) <= 1)

magnitudes, phases = self.stft_fn.transform(y)
magnitudes = magnitudes.data
mel_output = torch.matmul(self.mel_basis, magnitudes)
mel_output = self.spectral_normalize(mel_output)
return mel_output
Loading

0 comments on commit 776651f

Please sign in to comment.