Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Numba FP16 RNNT Loss (#6991) #7038

Merged
merged 1 commit into from
Jul 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions nemo/collections/asr/losses/rnnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@
from nemo.collections.asr.losses.rnnt_pytorch import MultiblankRNNTLossPytorch, RNNTLossPytorch, TDTLossPytorch
from nemo.core.classes import Loss, typecheck
from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType
from nemo.core.utils import numba_utils
from nemo.core.utils.k2_utils import K2_INSTALLATION_MESSAGE
from nemo.core.utils.numba_utils import NUMBA_INSTALLATION_MESSAGE
from nemo.utils import logging, model_utils
from nemo.utils import logging, logging_mode, model_utils

try:
import warprnnt_pytorch as warprnnt
Expand Down Expand Up @@ -98,7 +99,7 @@ class RNNTLossConfig:
min_version='0.53.0',
is_available=NUMBA_RNNT_AVAILABLE,
installation_msg=NUMBA_INSTALLATION_MESSAGE,
force_float32=True,
force_float32=not numba_utils.NUMBA_FP16_SUPPORTED,
),
"pytorch": RNNTLossConfig(
loss_name="pytorch",
Expand Down Expand Up @@ -387,7 +388,7 @@ def __init__(self, num_classes, reduction: str = 'mean_batch', loss_name: str =
for the standard "blank" symbol. In particular, say V is the number of non-blank tokens in
the vocabulary, then in the case of,
standard RNNT: num_classes = V
multiblank RNNT: num_classes = V + number-big-blanks (since we store big-blanks before
multiblank RNNT: num_classes = V + number-big-blanks (since we store big-blanks before
standard blank, and the standard blank is the last symbol in the vocab)
TDT: num_classes = V. Note, V here does not include any of the "duration outputs".

Expand All @@ -413,6 +414,7 @@ def __init__(self, num_classes, reduction: str = 'mean_batch', loss_name: str =
self.reduction = reduction
self._loss = resolve_rnnt_loss(loss_name, blank_idx=self._blank, loss_kwargs=loss_kwargs)
self._force_float32 = RNNT_LOSS_RESOLVER[loss_name].force_float32
self._fp16_compat_checked = False

def reduce(self, losses, target_lengths):

Expand Down Expand Up @@ -442,8 +444,22 @@ def forward(self, log_probs, targets, input_lengths, target_lengths):
max_targets_len = target_lengths.max()

# Force cast joint to float32
# TODO: Remove once Numba supports FP16
if self._force_float32 and log_probs.dtype != torch.float32:
if not self._force_float32 and numba_utils.NUMBA_FP16_SUPPORTED:
# Execute the kernel in fp16
pass
elif self._force_float32 and log_probs.dtype != torch.float32:
# Log just once if fp16 tensor was passed and fp16 Numba CUDA loss could not be used.
if log_probs.dtype == torch.float16 and not self._fp16_compat_checked:
_, reason = numba_utils.is_numba_cuda_fp16_supported(return_reason=True)
logging.warning(
f"Provided RNNT Joint tensor is of dtype {log_probs.dtype}, but RNNT loss could not be calculated "
f"in fp16 due to following reason stated below. Loss will be calculated in fp32. \n\n"
f"{reason}",
mode=logging_mode.ONCE,
)
self._fp16_compat_checked = True

# Upcast the activation tensor and compute loss and grads in fp32
logits_orig = log_probs
log_probs = log_probs.float()
del logits_orig # save memory *before* computing the loss
Expand Down
5 changes: 5 additions & 0 deletions nemo/collections/asr/losses/rnnt_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,12 @@ def __init__(self, blank, reduction):
self.reduction = reduction

def forward(self, acts, labels, act_lens, label_lens):
# CPU patch for FP16
if not acts.is_cuda and acts.dtype == torch.float16:
acts = acts.float()

acts = torch.log_softmax(acts, -1)

forward_logprob = self.compute_forward_prob(acts, labels, act_lens, label_lens)
losses = -forward_logprob
if self.reduction == 'mean_batch':
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def rnnt_loss_gpu(

# Select GPU index
cuda.select_device(acts.device.index)
gpu_workspace = torch.zeros(gpu_size, device=acts.device, dtype=acts.dtype, requires_grad=False)
gpu_workspace = torch.zeros(gpu_size, device=acts.device, dtype=torch.float32, requires_grad=False)

### VIEW TENSORS AS VECTORS FOR POINTER INDEXING ###
acts, acts_shape = rnnt_helper.flatten_tensor(acts)
Expand Down
5 changes: 5 additions & 0 deletions nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,15 @@ def forward(self, acts, labels, act_lens, label_lens):
_assert_no_grad(label_lens)
certify_inputs(acts, labels, act_lens, label_lens)

# CPU Patch for fp16 - force cast to fp32
if not acts.is_cuda and acts.dtype == torch.float16:
acts = acts.float()

if self.clamp > 0.0:
acts = LogSoftmaxGradModification.apply(acts, self.clamp)

acts = torch.nn.functional.log_softmax(acts, -1)

return self.rnnt(acts, labels, act_lens, label_lens, self.blank, self.fastemit_lambda)


Expand Down
7 changes: 5 additions & 2 deletions nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_
loss_func = rnnt.rnnt_loss_gpu if is_cuda else rnnt.rnnt_loss_cpu
grads = torch.zeros_like(acts) if acts.requires_grad else None
minibatch_size = acts.size(0)
costs = torch.zeros(minibatch_size, device=acts.device, dtype=acts.dtype)
costs = torch.zeros(minibatch_size, device=acts.device, dtype=torch.float32)

loss_func(
acts,
Expand Down Expand Up @@ -119,7 +119,6 @@ def forward(
label_lens: Tensor of (batch) containing label length of each example
fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization.

durations: list of durations for TDT model, must include 0 and 1, e.g.
[0, 1, 2, 3, 4].
sigma: hyper-parameter for logit under-normalization method for training
Expand Down Expand Up @@ -417,6 +416,10 @@ def forward(self, acts, labels, act_lens, label_lens):
label_lens: Tensor of (batch) containing label length of each example
"""
if not acts.is_cuda:
# Force FP32 until log_softmax() is implemented for fp16 on CPU
if acts.dtype == torch.float16:
acts = acts.float()

# Since CPU requires log_softmax to be computed explicitly, we need to perform grad clipping
# *after* we have obtained the gradients of loss(logsoftmax()).
# This is highly wasteful since it requires a copy of the entire joint tensor which is expensive.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ def cost_and_grad_kernel(
)

# Scale llForward by FastEmit lambda
llForward *= 1.0 + self.fastemit_lambda_
llBackward *= 1.0 + self.fastemit_lambda_
llForward += llForward * self.fastemit_lambda_
llBackward += llBackward * self.fastemit_lambda_

diff = (llForward - llBackward).abs()
if diff > 0.1:
Expand Down Expand Up @@ -300,6 +300,10 @@ def compute_betas_and_grads(
Returns:
Loglikelihood of the forward variable and inplace updates the grad tensor.
"""
# Patch for CPU + fp16
if log_probs.dtype == torch.float16 and not log_probs.is_cuda:
log_probs = log_probs.float()

idx = CpuRNNT_index(U, self.maxU_, self.minibatch_, self.alphabet_size_, self.batch_first)
betas[idx(T - 1, U - 1)] = log_probs[idx(T - 1, U - 1) * 2]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import math
from typing import Optional, Tuple

import numba
import torch
from numba import cuda

Expand Down Expand Up @@ -112,7 +113,7 @@ def compute_costs_data(source: torch.Tensor, dest: torch.Tensor, fastemit_lambda
if idx < length:
copy_data_1d(source, dest, idx)
dest[idx] *= -1.0
dest[idx] *= 1.0 + fastemit_lambda
dest[idx] *= numba.float32(1.0 + fastemit_lambda)


def get_workspace_size(
Expand Down
36 changes: 36 additions & 0 deletions nemo/core/utils/numba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import operator
import os

from typing import Tuple, Union

from nemo.utils import model_utils

# Prevent Numba CUDA logs from showing at info level
Expand All @@ -26,6 +28,11 @@
__NUMBA_DEFAULT_MINIMUM_VERSION__ = "0.53.0"
__NUMBA_MINIMUM_VERSION__ = os.environ.get("NEMO_NUMBA_MINVER", __NUMBA_DEFAULT_MINIMUM_VERSION__)

# Minimum numba version that is compared against the installed one to derive the FP16 flag below.
__NUMBA_MINIMUM_VERSION_FP16_SUPPORTED__ = "0.57.0"
# Module-level flag: first element of the `check_lib_version` result for `numba >= 0.57.0`.
# NOTE(review): presumably a bool meaning "installed numba satisfies the minimum" - confirm against
# `model_utils.check_lib_version`, whose return contract is not visible here.
NUMBA_FP16_SUPPORTED = model_utils.check_lib_version(
'numba', __NUMBA_MINIMUM_VERSION_FP16_SUPPORTED__, operator=operator.ge
)[0]


NUMBA_INSTALLATION_MESSAGE = (
"Could not import `numba`.\n"
Expand Down Expand Up @@ -148,6 +155,35 @@ def numba_cuda_is_supported(min_version: str) -> bool:
return False


def is_numba_cuda_fp16_supported(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]:
    """
    Utility method that returns a bool, stating if FP16 is supported for numba cuda kernels or not.

    FP16 is reported as supported only when both of the following hold:
        - the env variable `NUMBA_CUDA_USE_NVIDIA_BINDING` is set to `1`, and
        - the module-level flag `NUMBA_FP16_SUPPORTED` is truthy.

    Args:
        return_reason: When True, additionally return a human readable string that describes
            the state of both conditions (useful for logging why fp16 could not be used).

    Returns:
        bool, whether Numba CUDA will support fp16 or not.
        If `return_reason` is True, a tuple of (bool, str) where the str is the explanation.
    """
    env_name = 'NUMBA_CUDA_USE_NVIDIA_BINDING'
    env_value = os.environ.get(env_name, None)
    use_nvidia_binding = env_value is not None and env_value.lower() == "1"

    # Describe the actual state of the env variable; a variable that is defined but holds a
    # value other than `1` must not be reported as "set to `1`".
    if env_value is None:
        reason = f"Env variable `{env_name}` is not available or has not been set to `1`. "
    elif use_nvidia_binding:
        reason = f"Env variable `{env_name}` is available and set to `1`. "
    else:
        reason = f"Env variable `{env_name}` is available but is set to `{env_value}` instead of `1`. "

    if NUMBA_FP16_SUPPORTED:
        reason += "Numba CUDA FP16 is supported in installed numba version."
    else:
        reason += "Numba CUDA FP16 is not supported in installed numba version."

    result = use_nvidia_binding and NUMBA_FP16_SUPPORTED

    if return_reason:
        return result, reason
    return result


def skip_numba_cuda_test_if_unsupported(min_version: str):
"""
Helper method to skip pytest test case if numba cuda is not supported.
Expand Down
Loading