Graph RNNT: Grid- and Compose-Transducer. W-Transducer loss #6168

Merged: 53 commits, May 26, 2023
Changes from all commits (53 commits)
a265fb2
Add graph transducer: grid and compose transducer
artbataev Mar 10, 2023
10bd265
Add w-transducer
artbataev Mar 10, 2023
a1edc5d
Add Graph-RNNT to RNNT loss resolver
artbataev Mar 10, 2023
549ffc1
Merge branch 'main' into transucer_compose_grid_wildcard
artbataev Mar 22, 2023
5055d77
Merge branch 'main' into transucer_compose_grid_wildcard
artbataev Mar 27, 2023
969b9f8
Merge branch 'main' into transucer_compose_grid_wildcard
artbataev Apr 21, 2023
b86c4d2
Merge branch 'main' into transucer_compose_grid_wildcard
artbataev May 7, 2023
187053f
Merge branch 'main' into transucer_compose_grid_wildcard
artbataev May 17, 2023
6865d10
Merge branch 'main' into transucer_compose_grid_wildcard
artbataev May 18, 2023
240a566
Temporary fix for k2 installation
artbataev May 18, 2023
e024776
Improve graph-transducer: docstrings, refactoring
artbataev May 18, 2023
646e0b8
Add graph-transducer losses to rnnt loss resolver
artbataev May 19, 2023
9e32838
Unify interface
artbataev May 19, 2023
8088e33
Test Graph-RNNT
artbataev May 19, 2023
202d581
Fix device usage in tests
artbataev May 19, 2023
efeeb7e
Fix W-Transducer, add tests
artbataev May 19, 2023
2c67b61
Merge branch 'main' into transucer_compose_grid_wildcard
artbataev May 22, 2023
89b4c92
Switch to k2 original repo (fix for CUDA 12)
artbataev May 22, 2023
726e62f
Fix docstring
artbataev May 22, 2023
a0758bf
Fix docstring
artbataev May 22, 2023
f073e8f
Clean up W-Transducer
artbataev May 22, 2023
efcd5c8
Fix tests
artbataev May 22, 2023
06ec45c
Fix W-Transducer. Improve docstrings.
artbataev May 22, 2023
1726def
Improve docstrings.
artbataev May 22, 2023
7d62dd9
Add more tests
artbataev May 22, 2023
8f0852c
Add more tests for W-Transducer
artbataev May 22, 2023
6a0eb1d
Test with variable-size inputs
artbataev May 22, 2023
deb53ba
Fix failing test
artbataev May 22, 2023
cd21ad9
Merge branch 'main' into transucer_compose_grid_wildcard
artbataev May 23, 2023
7da63b5
Add tests for temporal schemas
artbataev May 23, 2023
56e9eaf
Tests for Compose-Transducer
artbataev May 23, 2023
9b98d5b
Fix W-Transducer
artbataev May 23, 2023
98689c6
Refactoring
artbataev May 23, 2023
f17d644
Clean up k2 installation script
artbataev May 23, 2023
ee47808
Fix temporal schema
artbataev May 23, 2023
0e2f275
Fix temporal scheme for W-Transducer
artbataev May 23, 2023
a4439ca
Refactoring. Test Grid-Transducer for RNNT.
artbataev May 24, 2023
c58f8c8
W-Transducer: test grid.
artbataev May 24, 2023
41646dc
Merge branch 'main' into transucer_compose_grid_wildcard
artbataev May 24, 2023
2761b12
Improve docstrings.
artbataev May 24, 2023
8bcbd54
Improve docstrings. Fix tests.
artbataev May 24, 2023
ccc014c
Improve docstrings
artbataev May 24, 2023
276751a
Fix loss name in RNN-T resolver
artbataev May 24, 2023
f185bc2
Merge branch 'main' into transucer_compose_grid_wildcard
artbataev May 24, 2023
2330641
Add quotes to type annotations
artbataev May 24, 2023
6ed956e
scheme -> schema
artbataev May 24, 2023
f2f56aa
Merge branch 'main' into transucer_compose_grid_wildcard
artbataev May 25, 2023
3a327cd
Merge branch 'main' into transucer_compose_grid_wildcard
artbataev May 26, 2023
7f43fa6
force_last -> force_final
artbataev May 26, 2023
7a75f37
Merge remote-tracking branch 'origin/transucer_compose_grid_wildcard'…
artbataev May 26, 2023
d32e0e3
Add comments in `forward` method
artbataev May 26, 2023
b0af4c2
W-Transducer: fix temporal scheme construction (remove redundant arc)
artbataev May 26, 2023
607601e
RNNT_LOSS_RESOLVER: set force_float32 explicitly for all losses
artbataev May 26, 2023
73 changes: 70 additions & 3 deletions nemo/collections/asr/losses/rnnt.py
@@ -27,16 +27,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import operator
from dataclasses import dataclass
from typing import List, Optional
from typing import Any, Callable, Dict, List, Optional, Set

import torch
from omegaconf import DictConfig, OmegaConf

from nemo.collections.asr.losses.rnnt_pytorch import MultiblankRNNTLossPytorch, RNNTLossPytorch
from nemo.core.classes import Loss, typecheck
from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType
from nemo.core.utils.k2_utils import K2_INSTALLATION_MESSAGE
from nemo.core.utils.numba_utils import NUMBA_INSTALLATION_MESSAGE
from nemo.utils import logging, model_utils

@@ -54,6 +56,13 @@
except (ImportError, ModuleNotFoundError):
NUMBA_RNNT_AVAILABLE = False

try:
from nemo.collections.asr.parts.k2.graph_transducer import GraphRnntLoss
from nemo.collections.asr.parts.k2.w_transducer import GraphWTransducerLoss

K2_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
K2_AVAILABLE = False

WARP_RNNT_INSTALLATION_MESSAGE = (
"Could not import `warprnnt_pytorch`.\n"
@@ -71,6 +80,7 @@ class RNNTLossConfig:
is_available: bool = False
installation_msg: str = ""
min_version: Optional[str] = None
force_float32: bool = True # default True for now for all losses except graph-based


# Resolved list of available RNNT losses
@@ -80,34 +90,53 @@ class RNNTLossConfig:
lib_name="warprnnt_pytorch",
is_available=WARP_RNNT_AVAILABLE,
installation_msg=WARP_RNNT_INSTALLATION_MESSAGE,
force_float32=True,
),
"warprnnt_numba": RNNTLossConfig(
loss_name="warprnnt_numba",
lib_name="numba",
min_version='0.53.0',
is_available=NUMBA_RNNT_AVAILABLE,
installation_msg=NUMBA_INSTALLATION_MESSAGE,
force_float32=True,
),
"pytorch": RNNTLossConfig(
loss_name="pytorch",
lib_name="torch",
min_version='0.0',
is_available=True,
installation_msg="Pure Pytorch implementation of RNN-T loss. Slow and for debugging purposes only.",
force_float32=True,
),
"multiblank_rnnt": RNNTLossConfig(
loss_name="multiblank_rnnt",
lib_name="numba",
min_version='0.53.0',
is_available=NUMBA_RNNT_AVAILABLE,
installation_msg=NUMBA_INSTALLATION_MESSAGE,
force_float32=True,
),
"multiblank_rnnt_pytorch": RNNTLossConfig(
loss_name="pytorch",
lib_name="torch",
min_version='0.0',
is_available=True,
installation_msg="Pure Pytorch implementation of Multiblank RNN-T loss. Slow and for debugging purposes only.",
force_float32=True,
),
"graph_w_transducer": RNNTLossConfig(
[Review comment, Collaborator]
Update the rest of the configs to explicitly show force fp32 is true

[Reply, Collaborator Author]
Fixed (I updated all the configs)

loss_name="graph_w_transducer",
lib_name="k2",
is_available=K2_AVAILABLE,
installation_msg=K2_INSTALLATION_MESSAGE,
force_float32=False,
),
"graph_rnnt": RNNTLossConfig(
loss_name="graph_rnnt",
lib_name="k2",
is_available=K2_AVAILABLE,
installation_msg=K2_INSTALLATION_MESSAGE,
force_float32=False,
),
}
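
For context, a minimal sketch of how a resolver entry can be inspected (the dict and config fields are the ones defined above; the import path follows this file's location, and the check itself is illustrative):

from nemo.collections.asr.losses.rnnt import RNNT_LOSS_RESOLVER

# Illustrative only: check whether the k2-based Graph-RNNT loss can be used and
# whether the joint output must be cast to float32 before computing the loss.
cfg = RNNT_LOSS_RESOLVER["graph_rnnt"]
if not cfg.is_available:
    print(cfg.installation_msg)  # how to install k2
else:
    # graph-based losses keep the joint in its original dtype (force_float32=False)
    print(cfg.loss_name, cfg.lib_name, cfg.force_float32)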

@@ -123,6 +152,38 @@ def _warn_unused_additional_kwargs(loss_name, kwargs):
)


def _clean_kwargs(
loss_name: str, kwargs: Optional[Dict[str, Any]], init_method: Callable, ignore_params: Optional[Set[str]] = None
) -> Dict[str, Any]:
"""
Cleans kwargs for the given loss function. Warns if there are unused kwargs.

Args:
loss_name: name of the loss function
kwargs: kwargs to clean
init_method: LossClass.__init__ method
ignore_params: set of argument names for init_method to ignore

Returns:
only used kwargs for the given `init_method`
"""
if not kwargs:
return {}
init_params = set(inspect.signature(init_method).parameters.keys()) - {"self"}
if ignore_params is not None:
init_params -= ignore_params
unused_kwargs = dict()
used_kwargs = dict()
for key, value in kwargs.items():
if key not in init_params:
unused_kwargs[key] = value
else:
used_kwargs[key] = value
if len(unused_kwargs) > 0:
_warn_unused_additional_kwargs(loss_name, unused_kwargs)
return used_kwargs
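
A small usage sketch of `_clean_kwargs` (the `DummyLoss` class and the values below are hypothetical, used only to illustrate the filtering described in the docstring):

class DummyLoss:
    def __init__(self, blank: int, sigma: float = 0.0):
        self.blank, self.sigma = blank, sigma

raw_kwargs = {"sigma": 0.1, "unknown_arg": 42, "blank": 0}
# "blank" is ignored explicitly; "unknown_arg" is dropped with a warning
used = _clean_kwargs("dummy_loss", raw_kwargs, DummyLoss.__init__, ignore_params={"blank"})
assert used == {"sigma": 0.1}
loss = DummyLoss(blank=1024, **used)  # only accepted kwargs remain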


def resolve_rnnt_default_loss_name() -> str:
return RNNT_LOSS_RESOLVER['default'].loss_name

@@ -213,7 +274,12 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None)
blank=blank_idx, big_blank_durations=big_blank_durations, reduction='none', sigma=sigma
)
_warn_unused_additional_kwargs(loss_name, loss_kwargs)

elif loss_name == "graph_rnnt":
loss_kwargs = _clean_kwargs(loss_name, loss_kwargs, GraphRnntLoss.__init__, ignore_params={"blank"})
loss_func = GraphRnntLoss(blank=blank_idx, **loss_kwargs)
elif loss_name == "graph_w_transducer":
loss_kwargs = _clean_kwargs(loss_name, loss_kwargs, GraphWTransducerLoss.__init__, ignore_params={"blank"})
loss_func = GraphWTransducerLoss(blank=blank_idx, **loss_kwargs)
else:
raise ValueError(
f"Invalid value of `loss_name`: {loss_name}. Allowed loss names are :" f"{loss_function_names}"
@@ -302,6 +368,7 @@ def __init__(self, num_classes, reduction: str = 'mean_batch', loss_name: str =
self._blank = num_classes
self.reduction = reduction
self._loss = resolve_rnnt_loss(loss_name, blank_idx=self._blank, loss_kwargs=loss_kwargs)
self._force_float32 = RNNT_LOSS_RESOLVER[loss_name].force_float32

def reduce(self, losses, target_lengths):

@@ -332,7 +399,7 @@ def forward(self, log_probs, targets, input_lengths, target_lengths):

# Force cast joint to float32
# TODO: Remove once Numba supports FP16
if log_probs.dtype != torch.float32:
if self._force_float32 and log_probs.dtype != torch.float32:
logits_orig = log_probs
log_probs = log_probs.float()
del logits_orig # save memory *before* computing the loss
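
A short sketch of the effect of the new flag (shapes, dtype, and values below are illustrative): Numba/warp-based losses still upcast the joint to float32, while the graph-based k2 losses keep the original precision.

import torch
from nemo.collections.asr.losses.rnnt import RNNT_LOSS_RESOLVER

# Illustrative only: with force_float32=False (graph losses) a bf16 joint output
# is passed to the loss unchanged; with force_float32=True it is upcast first.
log_probs = torch.randn(2, 50, 11, 129, dtype=torch.bfloat16)  # (B, T, U+1, V+1), made-up shape
force_float32 = RNNT_LOSS_RESOLVER["graph_rnnt"].force_float32  # False for graph losses
if force_float32 and log_probs.dtype != torch.float32:
    log_probs = log_probs.float()  # not taken for graph_rnnt / graph_w_transducer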