Skip to content

Commit

Permalink
Move RNNT Loss out of prototype (pytorch#1711)
Browse files Browse the repository at this point in the history
  • Loading branch information
Caroline Chen authored Aug 19, 2021
1 parent b7d44d9 commit 2c11582
Show file tree
Hide file tree
Showing 40 changed files with 513 additions and 653 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ endif()
# Options
option(BUILD_SOX "Build libsox statically" OFF)
option(BUILD_KALDI "Build kaldi statically" ON)
option(BUILD_RNNT "Enable RNN transducer" OFF)
option(BUILD_RNNT "Enable RNN transducer" ON)
option(BUILD_LIBTORCHAUDIO "Build C++ Library" ON)
option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF)
option(USE_CUDA "Enable CUDA support" OFF)
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ conda install -y -c pytorch-nightly torchaudio

The build process builds libsox and some codecs that torchaudio needs to link to. This is achieved by setting the environment variable `BUILD_SOX=1`.
The build process will fetch and build libmad, lame, flac, vorbis, opus, and libsox before building extension. This process requires `cmake` and `pkg-config`.
The build process also builds the RNN transducer loss. This functionality can be disabled by setting the environment variable `BUILD_RNNT=0`.

```bash
# Linux
Expand Down
2 changes: 1 addition & 1 deletion build_tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _get_build(var, default=False):

_BUILD_SOX = False if platform.system() == 'Windows' else _get_build("BUILD_SOX")
_BUILD_KALDI = False if platform.system() == 'Windows' else _get_build("BUILD_KALDI", True)
_BUILD_RNNT = _get_build("BUILD_RNNT")
_BUILD_RNNT = _get_build("BUILD_RNNT", True)
_USE_ROCM = _get_build("USE_ROCM")
_USE_CUDA = _get_build("USE_CUDA", torch.cuda.is_available())

Expand Down
8 changes: 8 additions & 0 deletions docs/source/functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,14 @@ vad

.. autofunction:: spectral_centroid

:hidden:`Loss`
~~~~~~~~~~~~~~

rnnt_loss
---------

.. autofunction:: rnnt_loss

References
~~~~~~~~~~

Expand Down
1 change: 0 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ The :mod:`torchaudio` package consists of I/O, popular datasets and common audio
compliance.kaldi
kaldi_io
utils
rnnt_loss


.. toctree::
Expand Down
28 changes: 0 additions & 28 deletions docs/source/rnnt_loss.rst

This file was deleted.

6 changes: 6 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ Transforms are common audio transforms. They can be chained together using :clas

.. automethod:: forward

:hidden:`RNNTLoss`
~~~~~~~~~~~~~~~~~~

.. autoclass:: RNNTLoss

.. automethod:: forward

References
~~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion examples/libtorchaudio/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ SET(BUILD_LIBTORCHAUDIO ON CACHE BOOL "Build libtorchaudio")
SET(BUILD_SOX ON CACHE BOOL "Build libsox into libtorchaudio")

SET(BUILD_KALDI OFF CACHE BOOL "Build Kaldi into libtorchaudio")
SET(BUILD_RNNT OFF CACHE BOOL "Build RNN transducer into libtorchaudio")
SET(BUILD_RNNT ON CACHE BOOL "Build RNN transducer into libtorchaudio")
SET(BUILD_TORCHAUDIO_PYTHON_EXTENSION OFF CACHE BOOL "Build Python binding")

find_package(Torch REQUIRED)
Expand Down
1 change: 1 addition & 0 deletions examples/libtorchaudio/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ cmake -GNinja \
-DCMAKE_PREFIX_PATH="$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')" \
-DBUILD_SOX=ON \
-DBUILD_KALDI=OFF \
-DBUILD_RNNT=ON \
..
cmake --build .
```
Expand Down
2 changes: 1 addition & 1 deletion packaging/build_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ if [[ "$OSTYPE" == "msys" ]]; then
python_tag="$(echo "cp$PYTHON_VERSION" | tr -d '.')"
"$script_dir/vc_env_helper.bat" python setup.py bdist_wheel --plat-name win_amd64 --python-tag $python_tag
else
BUILD_RNNT=1 BUILD_SOX=1 python setup.py bdist_wheel
BUILD_SOX=1 python setup.py bdist_wheel
fi
2 changes: 1 addition & 1 deletion packaging/torchaudio/build.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/usr/bin/env bash
set -ex

BUILD_RNNT=1 BUILD_SOX=1 python setup.py install --single-version-externally-managed --record=record.txt
BUILD_SOX=1 python setup.py install --single-version-externally-managed --record=record.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,174 @@
import unittest
import random
import torch
from torchaudio.prototype.rnnt_loss import RNNTLoss

from .numpy_transducer import NumpyTransducerLoss
import numpy as np
from torchaudio.functional import rnnt_loss


class _NumpyTransducer(torch.autograd.Function):
@staticmethod
def forward(
ctx,
log_probs,
logit_lengths,
target_lengths,
targets,
blank=-1,
):
device = log_probs.device
log_probs = log_probs.cpu().data.numpy()
logit_lengths = logit_lengths.cpu().data.numpy()
target_lengths = target_lengths.cpu().data.numpy()
targets = targets.cpu().data.numpy()

gradients, costs, _, _ = __class__.compute(
log_probs=log_probs,
logit_lengths=logit_lengths,
target_lengths=target_lengths,
targets=targets,
blank=blank,
)

costs = torch.FloatTensor(costs).to(device=device)
gradients = torch.FloatTensor(gradients).to(device=device)
ctx.grads = torch.autograd.Variable(gradients)

return costs

@staticmethod
def backward(ctx, grad_output):
grad_output = grad_output.view(-1, 1, 1, 1).to(ctx.grads)
return ctx.grads.mul(grad_output), None, None, None, None, None, None, None, None

@staticmethod
def compute_alpha_one_sequence(log_probs, targets, blank=-1):
max_T, max_U, D = log_probs.shape
alpha = np.zeros((max_T, max_U), dtype=np.float32)
for t in range(1, max_T):
alpha[t, 0] = alpha[t - 1, 0] + log_probs[t - 1, 0, blank]

for u in range(1, max_U):
alpha[0, u] = alpha[0, u - 1] + log_probs[0, u - 1, targets[u - 1]]

for t in range(1, max_T):
for u in range(1, max_U):
skip = alpha[t - 1, u] + log_probs[t - 1, u, blank]
emit = alpha[t, u - 1] + log_probs[t, u - 1, targets[u - 1]]
alpha[t, u] = np.logaddexp(skip, emit)

cost = -(alpha[-1, -1] + log_probs[-1, -1, blank])
return alpha, cost

@staticmethod
def compute_beta_one_sequence(log_probs, targets, blank=-1):
max_T, max_U, D = log_probs.shape
beta = np.zeros((max_T, max_U), dtype=np.float32)
beta[-1, -1] = log_probs[-1, -1, blank]

for t in reversed(range(max_T - 1)):
beta[t, -1] = beta[t + 1, -1] + log_probs[t, -1, blank]

for u in reversed(range(max_U - 1)):
beta[-1, u] = beta[-1, u + 1] + log_probs[-1, u, targets[u]]

for t in reversed(range(max_T - 1)):
for u in reversed(range(max_U - 1)):
skip = beta[t + 1, u] + log_probs[t, u, blank]
emit = beta[t, u + 1] + log_probs[t, u, targets[u]]
beta[t, u] = np.logaddexp(skip, emit)

cost = -beta[0, 0]
return beta, cost

@staticmethod
def compute_gradients_one_sequence(
log_probs, alpha, beta, targets, blank=-1
):
max_T, max_U, D = log_probs.shape
gradients = np.full(log_probs.shape, float("-inf"))
cost = -beta[0, 0]

gradients[-1, -1, blank] = alpha[-1, -1]

gradients[:-1, :, blank] = alpha[:-1, :] + beta[1:, :]

for u, l in enumerate(targets):
gradients[:, u, l] = alpha[:, u] + beta[:, u + 1]

gradients = -(np.exp(gradients + log_probs + cost))
return gradients

@staticmethod
def compute(
log_probs,
logit_lengths,
target_lengths,
targets,
blank=-1,
):
gradients = np.zeros_like(log_probs)
B_tgt, max_T, max_U, D = log_probs.shape
B_src = logit_lengths.shape[0]

H = int(B_tgt / B_src)

alphas = np.zeros((B_tgt, max_T, max_U))
betas = np.zeros((B_tgt, max_T, max_U))
betas.fill(float("-inf"))
alphas.fill(float("-inf"))
costs = np.zeros(B_tgt)
for b_tgt in range(B_tgt):
b_src = int(b_tgt / H)
T = int(logit_lengths[b_src])
# NOTE: see https://arxiv.org/pdf/1211.3711.pdf Section 2.1
U = int(target_lengths[b_tgt]) + 1

seq_log_probs = log_probs[b_tgt, :T, :U, :]
seq_targets = targets[b_tgt, : int(target_lengths[b_tgt])]
alpha, alpha_cost = __class__.compute_alpha_one_sequence(
log_probs=seq_log_probs, targets=seq_targets, blank=blank
)

beta, beta_cost = __class__.compute_beta_one_sequence(
log_probs=seq_log_probs, targets=seq_targets, blank=blank
)

seq_gradients = __class__.compute_gradients_one_sequence(
log_probs=seq_log_probs,
alpha=alpha,
beta=beta,
targets=seq_targets,
blank=blank,
)
np.testing.assert_almost_equal(alpha_cost, beta_cost, decimal=2)
gradients[b_tgt, :T, :U, :] = seq_gradients
costs[b_tgt] = beta_cost
alphas[b_tgt, :T, :U] = alpha
betas[b_tgt, :T, :U] = beta

return gradients, costs, alphas, betas


class NumpyTransducerLoss(torch.nn.Module):
    """Transducer loss module backed by the NumPy reference implementation.

    Args:
        blank (int): index of the blank label (default ``-1``, the last one).
    """

    def __init__(self, blank=-1):
        super().__init__()
        self.blank = blank

    def forward(self, logits, logit_lengths, target_lengths, targets):
        """Return per-sequence transducer costs for raw (unnormalized) logits."""
        normalized = torch.nn.functional.log_softmax(logits, dim=-1)
        return _NumpyTransducer.apply(
            normalized, logit_lengths, target_lengths, targets, self.blank
        )


def compute_with_numpy_transducer(data):
Expand All @@ -24,14 +189,13 @@ def compute_with_numpy_transducer(data):


def compute_with_pytorch_transducer(data):
costs = RNNTLoss(
blank=data["blank"],
reduction="none",
)(
costs = rnnt_loss(
logits=data["logits"],
logit_lengths=data["logit_lengths"],
target_lengths=data["target_lengths"],
targets=data["targets"],
blank=data["blank"],
reduction="none",
)

loss = torch.sum(costs)
Expand Down
7 changes: 6 additions & 1 deletion test/torchaudio_unittest/functional/autograd_cpu_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import torch
from .autograd_impl import Autograd
from .autograd_impl import Autograd, AutogradFloat32
from torchaudio_unittest import common_utils


class TestAutogradLfilterCPU(Autograd, common_utils.PytorchTestCase):
    """Runs the `Autograd` suite on CPU in double precision."""
    device = torch.device('cpu')
    dtype = torch.float64


class TestAutogradRNNTCPU(AutogradFloat32, common_utils.PytorchTestCase):
    """Runs the `AutogradFloat32` suite on CPU in single precision."""
    device = torch.device('cpu')
    dtype = torch.float32
8 changes: 7 additions & 1 deletion test/torchaudio_unittest/functional/autograd_cuda_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import torch
from .autograd_impl import Autograd
from .autograd_impl import Autograd, AutogradFloat32
from torchaudio_unittest import common_utils


@common_utils.skipIfNoCuda
class TestAutogradLfilterCUDA(Autograd, common_utils.PytorchTestCase):
    """Runs the `Autograd` suite on CUDA; skipped when no GPU is available."""
    device = torch.device('cuda')
    dtype = torch.float64


@common_utils.skipIfNoCuda
class TestAutogradRNNTCUDA(AutogradFloat32, common_utils.PytorchTestCase):
    """Runs the `AutogradFloat32` suite on CUDA; skipped when no GPU is available."""
    device = torch.device('cuda')
    dtype = torch.float32
Loading

0 comments on commit 2c11582

Please sign in to comment.