Graph RNNT: Grid- and Compose-Transducer. W-Transducer loss (NVIDIA#6168)

* add GraphTransducerLossBase abstract class with the interface for graph-based losses
* add RNN-T implementation in GraphRnntLoss with tests
* add W-Transducer implementation in GraphWTransducerLoss with tests
* add GraphRnntLoss and GraphWTransducerLoss to the RNN-T loss resolver (see the usage sketch below)
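
A minimal selection sketch, not part of this commit: it assumes only the RNNTLoss wrapper signature and resolver behavior visible in the rnnt.py diff below; the vocabulary size is illustrative.

# Sketch only: selecting the new graph-based losses through the RNN-T loss resolver.
# Requires k2 to be installed (otherwise the resolver reports the installation message).
from nemo.collections.asr.losses.rnnt import RNNTLoss

vocab_size = 128  # illustrative; the blank index becomes vocab_size internally
graph_rnnt_loss = RNNTLoss(num_classes=vocab_size, loss_name="graph_rnnt")
w_transducer_loss = RNNTLoss(num_classes=vocab_size, loss_name="graph_w_transducer")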

---------

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>
artbataev authored and hsiehjackson committed Jun 2, 2023
1 parent 2bac13d commit 5831405
Showing 8 changed files with 1,796 additions and 13 deletions.
73 changes: 70 additions & 3 deletions nemo/collections/asr/losses/rnnt.py
@@ -27,16 +27,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import operator
from dataclasses import dataclass
from typing import List, Optional
from typing import Any, Callable, Dict, List, Optional, Set

import torch
from omegaconf import DictConfig, OmegaConf

from nemo.collections.asr.losses.rnnt_pytorch import MultiblankRNNTLossPytorch, RNNTLossPytorch
from nemo.core.classes import Loss, typecheck
from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType
from nemo.core.utils.k2_utils import K2_INSTALLATION_MESSAGE
from nemo.core.utils.numba_utils import NUMBA_INSTALLATION_MESSAGE
from nemo.utils import logging, model_utils

@@ -54,6 +56,13 @@
except (ImportError, ModuleNotFoundError):
NUMBA_RNNT_AVAILABLE = False

try:
from nemo.collections.asr.parts.k2.graph_transducer import GraphRnntLoss
from nemo.collections.asr.parts.k2.w_transducer import GraphWTransducerLoss

K2_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
K2_AVAILABLE = False

WARP_RNNT_INSTALLATION_MESSAGE = (
"Could not import `warprnnt_pytorch`.\n"
@@ -71,6 +80,7 @@ class RNNTLossConfig:
is_available: bool = False
installation_msg: str = ""
min_version: Optional[str] = None
force_float32: bool = True # default True for now for all losses except graph-based


# Resolved list of available RNNT losses
@@ -80,34 +90,53 @@ class RNNTLossConfig:
lib_name="warprnnt_pytorch",
is_available=WARP_RNNT_AVAILABLE,
installation_msg=WARP_RNNT_INSTALLATION_MESSAGE,
force_float32=True,
),
"warprnnt_numba": RNNTLossConfig(
loss_name="warprnnt_numba",
lib_name="numba",
min_version='0.53.0',
is_available=NUMBA_RNNT_AVAILABLE,
installation_msg=NUMBA_INSTALLATION_MESSAGE,
force_float32=True,
),
"pytorch": RNNTLossConfig(
loss_name="pytorch",
lib_name="torch",
min_version='0.0',
is_available=True,
installation_msg="Pure Pytorch implementation of RNN-T loss. Slow and for debugging purposes only.",
force_float32=True,
),
"multiblank_rnnt": RNNTLossConfig(
loss_name="multiblank_rnnt",
lib_name="numba",
min_version='0.53.0',
is_available=NUMBA_RNNT_AVAILABLE,
installation_msg=NUMBA_INSTALLATION_MESSAGE,
force_float32=True,
),
"multiblank_rnnt_pytorch": RNNTLossConfig(
loss_name="pytorch",
lib_name="torch",
min_version='0.0',
is_available=True,
installation_msg="Pure Pytorch implementation of Multiblank RNN-T loss. Slow and for debugging purposes only.",
force_float32=True,
),
"graph_w_transducer": RNNTLossConfig(
loss_name="graph_w_transducer",
lib_name="k2",
is_available=K2_AVAILABLE,
installation_msg=K2_INSTALLATION_MESSAGE,
force_float32=False,
),
"graph_rnnt": RNNTLossConfig(
loss_name="graph_rnnt",
lib_name="k2",
is_available=K2_AVAILABLE,
installation_msg=K2_INSTALLATION_MESSAGE,
force_float32=False,
),
}
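
A small inspection sketch, not part of the diff; the field names come from RNNTLossConfig above.

# Sketch only: checking the new k2-backed resolver entries before constructing a loss.
from nemo.collections.asr.losses.rnnt import RNNT_LOSS_RESOLVER

cfg = RNNT_LOSS_RESOLVER["graph_rnnt"]
if not cfg.is_available:
    print(cfg.installation_msg)  # k2 installation instructions
assert cfg.force_float32 is False  # graph-based losses do not force an fp32 cast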

@@ -123,6 +152,38 @@ def _warn_unused_additional_kwargs(loss_name, kwargs):
)


def _clean_kwargs(
loss_name: str, kwargs: Optional[Dict[str, Any]], init_method: Callable, ignore_params: Optional[Set[str]] = None
) -> Dict[str, Any]:
"""
Cleans kwargs for the given loss function. Warn if there are unused kwargs.
Args:
loss_name: name of the loss function
kwargs: kwargs to clean
init_method: LossClass.__init__ method
ignore_params: set of argument names for init_method to ignore
Returns:
only used kwargs for the given `init_method`
"""
if not kwargs:
return {}
init_params = set(inspect.signature(init_method).parameters.keys()) - {"self"}
if ignore_params is not None:
init_params -= ignore_params
unused_kwargs = dict()
used_kwargs = dict()
for key, value in kwargs.items():
if key not in init_params:
unused_kwargs[key] = value
else:
used_kwargs[key] = value
if len(unused_kwargs) > 0:
_warn_unused_additional_kwargs(loss_name, unused_kwargs)
return used_kwargs
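
A hedged illustration of _clean_kwargs, not part of the diff; the example class is hypothetical.

# Sketch only: kwargs matching the target __init__ signature are kept,
# the rest are dropped with an "unused kwargs" warning.
class _FakeLoss:
    def __init__(self, blank: int, alpha: float = 0.0):
        ...

used = _clean_kwargs("fake_loss", {"alpha": 0.1, "unused_flag": True}, _FakeLoss.__init__, ignore_params={"blank"})
# used == {"alpha": 0.1}; "unused_flag" is reported via _warn_unused_additional_kwargs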


def resolve_rnnt_default_loss_name() -> str:
return RNNT_LOSS_RESOLVER['default'].loss_name

@@ -213,7 +274,12 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None)
blank=blank_idx, big_blank_durations=big_blank_durations, reduction='none', sigma=sigma
)
_warn_unused_additional_kwargs(loss_name, loss_kwargs)

elif loss_name == "graph_rnnt":
loss_kwargs = _clean_kwargs(loss_name, loss_kwargs, GraphRnntLoss.__init__, ignore_params={"blank"})
loss_func = GraphRnntLoss(blank=blank_idx, **loss_kwargs)
elif loss_name == "graph_w_transducer":
loss_kwargs = _clean_kwargs(loss_name, loss_kwargs, GraphWTransducerLoss.__init__, ignore_params={"blank"})
loss_func = GraphWTransducerLoss(blank=blank_idx, **loss_kwargs)
else:
raise ValueError(
f"Invalid value of `loss_name`: {loss_name}. Allowed loss names are :" f"{loss_function_names}"
@@ -302,6 +368,7 @@ def __init__(self, num_classes, reduction: str = 'mean_batch', loss_name: str =
self._blank = num_classes
self.reduction = reduction
self._loss = resolve_rnnt_loss(loss_name, blank_idx=self._blank, loss_kwargs=loss_kwargs)
self._force_float32 = RNNT_LOSS_RESOLVER[loss_name].force_float32

def reduce(self, losses, target_lengths):

@@ -332,7 +399,7 @@ def forward(self, log_probs, targets, input_lengths, target_lengths):

# Force cast joint to float32
# TODO: Remove once Numba supports FP16
if log_probs.dtype != torch.float32:
if self._force_float32 and log_probs.dtype != torch.float32:
logits_orig = log_probs
log_probs = log_probs.float()
del logits_orig # save memory *before* computing the loss
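
A sketch of the effect of the force_float32 flag, not part of the diff; shapes and dtypes are illustrative.

# Sketch only: graph-based losses keep half-precision joint outputs, while the
# Numba/warp-rnnt entries still force a float32 cast in forward().
import torch
from nemo.collections.asr.losses.rnnt import RNNT_LOSS_RESOLVER

log_probs = torch.randn(2, 10, 5, 129, dtype=torch.bfloat16)  # (B, T, U+1, V+1)
if RNNT_LOSS_RESOLVER["graph_rnnt"].force_float32 and log_probs.dtype != torch.float32:
    log_probs = log_probs.float()
# log_probs stays bfloat16 for "graph_rnnt"; with "warprnnt_numba" it would be cast to float32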