"Loop labels" greedy decoding: faster implementation (NVIDIA#8286)
* Loop labels greedy decoding v2

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>

* Add comments. Clean up

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>

* Add comments

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>

* Add comments

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>

* Add tests for batched hypotheses

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>

* Add tests for batched alignments

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>

* Add comments

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>

* Fix comment

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>

* Fix test

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>

* Method -> classmethod (self is not needed)

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>

---------

Signed-off-by: Vladimir Bataev <vbataev@nvidia.com>
Co-authored-by: Somshubra Majumdar <titu1994@gmail.com>
Signed-off-by: Sasha Meister <ameister@nvidia.com>
2 people authored and sashameister committed Feb 15, 2024
1 parent 47545aa commit 9771f3e
Showing 7 changed files with 663 additions and 166 deletions.
53 changes: 46 additions & 7 deletions nemo/collections/asr/modules/rnnt.py
@@ -368,6 +368,21 @@ def batch_concat_states(self, batch_states: List[List[torch.Tensor]]) -> List[torch.Tensor]:

return state_list

+ @classmethod
+ def batch_replace_states_mask(
+ cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor], mask: torch.Tensor,
+ ):
+ """Replace states in dst_states with states from src_states using the mask"""
+ # same as `dst_states[0][mask] = src_states[0][mask]`, but non-blocking
+ torch.where(mask.unsqueeze(-1), src_states[0], dst_states[0], out=dst_states[0])
+
+ def batch_split_states(self, batch_states: list[torch.Tensor]) -> list[list[torch.Tensor]]:
+ """
+ Split states into a list of states.
+ Useful for splitting the final state when converting decoding results to Hypothesis objects.
+ """
+ return [sub_state.split(1, dim=0) for sub_state in batch_states]

def batch_copy_states(
self,
old_states: List[torch.Tensor],
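The `torch.where(..., out=...)` pattern in the added `batch_replace_states_mask` above is the vectorized, non-blocking counterpart of boolean-mask assignment, which on CUDA forces a host synchronization. A minimal sketch of the equivalence (shapes and values here are illustrative assumptions, not the actual decoder state layout):

```python
import torch

B, H = 4, 8  # assumed batch and hidden sizes for illustration
src = torch.randn(B, H)
dst = torch.randn(B, H)
mask = torch.tensor([True, False, True, False])  # rows to overwrite

reference = dst.clone()
reference[mask] = src[mask]  # boolean indexing: simple, but can sync with the host

# Broadcast the mask over the hidden dim; where True take src, else keep dst.
# Writing into `out=dst` updates the states in place without a sync point.
torch.where(mask.unsqueeze(-1), src, dst, out=dst)

assert torch.equal(dst, reference)
```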
@@ -790,31 +805,32 @@ def _predict_modules(
)
return layers

- def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]:
+ def initialize_state(self, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
- Initialize the state of the RNN layers, with same dtype and device as input `y`.
+ Initialize the state of the LSTM layers, with same dtype and device as input `y`.
+ LSTM accepts a tuple of 2 tensors as a state.
Args:
y: A torch.Tensor whose device the generated states will be placed on.
Returns:
- List of torch.Tensor, each of shape [L, B, H], where
+ Tuple of 2 tensors, each of shape [L, B, H], where
L = Number of RNN layers
B = Batch size
H = Hidden size of RNN.
"""
batch = y.size(0)
if self.random_state_sampling and self.training:
- state = [
+ state = (
torch.randn(self.pred_rnn_layers, batch, self.pred_hidden, dtype=y.dtype, device=y.device),
torch.randn(self.pred_rnn_layers, batch, self.pred_hidden, dtype=y.dtype, device=y.device),
- ]
+ )

else:
- state = [
+ state = (
torch.zeros(self.pred_rnn_layers, batch, self.pred_hidden, dtype=y.dtype, device=y.device),
torch.zeros(self.pred_rnn_layers, batch, self.pred_hidden, dtype=y.dtype, device=y.device),
- ]
+ )
return state

def score_hypothesis(
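The List to Tuple change above is not cosmetic: `torch.nn.LSTM` takes its initial state as a tuple `(h_0, c_0)`, and TorchScript type checking requires the tuple type rather than a list. A short sketch with assumed sizes:

```python
import torch

L, B, H, D = 2, 3, 16, 8  # assumed layer/batch/hidden/input sizes
lstm = torch.nn.LSTM(input_size=D, hidden_size=H, num_layers=L)

# State is a tuple of two [L, B, H] tensors, matching the new return type.
state = (
    torch.zeros(L, B, H),
    torch.zeros(L, B, H),
)

y = torch.randn(5, B, D)  # [T, B, D]
output, (h_n, c_n) = lstm(y, state)
assert h_n.shape == (L, B, H) and c_n.shape == (L, B, H)
```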
Expand Down Expand Up @@ -1030,6 +1046,29 @@ def batch_concat_states(self, batch_states: List[List[torch.Tensor]]) -> List[to

return state_list

+ @classmethod
+ def batch_replace_states_mask(
+ cls,
+ src_states: Tuple[torch.Tensor, torch.Tensor],
+ dst_states: Tuple[torch.Tensor, torch.Tensor],
+ mask: torch.Tensor,
+ ):
+ """Replace states in dst_states with states from src_states using the mask"""
+ # same as `dst_states[i][mask] = src_states[i][mask]`, but non-blocking
+ # we need to cast, since LSTM is calculated in fp16 even if autocast to bfloat16 is enabled
+ dtype = dst_states[0].dtype
+ torch.where(mask.unsqueeze(0).unsqueeze(-1), src_states[0].to(dtype), dst_states[0], out=dst_states[0])
+ torch.where(mask.unsqueeze(0).unsqueeze(-1), src_states[1].to(dtype), dst_states[1], out=dst_states[1])
+
+ def batch_split_states(
+ self, batch_states: Tuple[torch.Tensor, torch.Tensor]
+ ) -> list[Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ Split states into a list of states.
+ Useful for splitting the final state when converting decoding results to Hypothesis objects.
+ """
+ return list(zip(batch_states[0].split(1, dim=1), batch_states[1].split(1, dim=1)))

def batch_copy_states(
self,
old_states: List[torch.Tensor],
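`batch_split_states` for the LSTM decoder undoes the batching in one shot: each [L, B, H] state tensor is split along the batch dimension and the pieces are re-paired, giving one `(h, c)` tuple per utterance. A small sketch with made-up sizes:

```python
import torch

L, B, H = 2, 3, 4  # assumed sizes
batch_states = (torch.randn(L, B, H), torch.randn(L, B, H))  # (h, c)

# split(1, dim=1) yields B views of shape [L, 1, H]; zip pairs the
# i-th h-chunk with the i-th c-chunk into a per-utterance state.
per_utterance = list(zip(batch_states[0].split(1, dim=1), batch_states[1].split(1, dim=1)))

assert len(per_utterance) == B
h0, c0 = per_utterance[0]
assert h0.shape == (L, 1, H) and c0.shape == (L, 1, H)
```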
14 changes: 14 additions & 0 deletions nemo/collections/asr/modules/rnnt_abstract.py
@@ -278,6 +278,20 @@ def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List
"""
raise NotImplementedError()

+ @classmethod
+ def batch_replace_states_mask(
+ cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor], mask: torch.Tensor,
+ ):
+ """Replace states in dst_states with states from src_states using the mask"""
+ raise NotImplementedError()
+
+ def batch_split_states(self, batch_states: list[torch.Tensor]) -> list[list[torch.Tensor]]:
+ """
+ Split states into a list of states.
+ Useful for splitting the final state when converting decoding results to Hypothesis objects.
+ """
+ raise NotImplementedError()

def batch_concat_states(self, batch_states: List[List[torch.Tensor]]) -> List[torch.Tensor]:
"""Concatenate a batch of decoder state to a packed state.
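Raising `NotImplementedError` in the abstract decoder defines a contract: a batched decoding loop can drive any prediction network, stateless or LSTM, through the same two calls without knowing how the state is laid out. A hedged sketch of that pattern (the helpers below are illustrative and made up, not part of the NeMo API):

```python
import torch

def keep_previous_states_where_blank(decoder, prev_states, new_states, blank_mask: torch.Tensor):
    # Hypothetical helper: roll back the state update for utterances whose
    # last emitted label was blank, using only the abstract decoder API.
    decoder.batch_replace_states_mask(src_states=prev_states, dst_states=new_states, mask=blank_mask)
    return new_states

def attach_final_states(decoder, hyps, last_decoder_state):
    # Attach per-utterance final states to hypotheses, decoder-agnostically.
    for hyp, dec_state in zip(hyps, decoder.batch_split_states(last_decoder_state)):
        hyp.dec_state = dec_state
    return hyps
```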
162 changes: 14 additions & 148 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
@@ -34,6 +34,7 @@
from omegaconf import DictConfig, OmegaConf

from nemo.collections.asr.modules import rnnt_abstract
+ from nemo.collections.asr.parts.submodules.rnnt_loop_labels_computer import GreedyBatchedRNNTLoopLabelsComputer
from nemo.collections.asr.parts.utils import rnnt_utils
from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceMethodConfig, ConfidenceMethodMixin
from nemo.collections.common.parts.rnn import label_collate
@@ -600,10 +601,20 @@ def __init__(

# Depending on availability of `blank_as_pad` support
# switch between more efficient batch decoding technique
+ self._decoding_computer = None
if self.decoder.blank_as_pad:
if loop_labels:
# default (faster) algo: loop over labels
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_labels
+ self._decoding_computer = GreedyBatchedRNNTLoopLabelsComputer(
+ decoder=self.decoder,
+ joint=self.joint,
+ blank_index=self._blank_index,
+ max_symbols_per_step=self.max_symbols,
+ preserve_alignments=preserve_alignments,
+ preserve_frame_confidence=preserve_frame_confidence,
+ confidence_method_cfg=confidence_method_cfg,
+ )
else:
# previous algo: loop over frames
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames
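The two algorithms named in these comments differ in what the outer loop iterates over: loop-frames advances one encoder frame per outer step and emits symbols in an inner loop, while loop-labels makes label emissions the outer loop, so blanks only advance per-utterance time indices and the prediction network is re-run only after non-blank labels, which keeps a batch of utterances synchronized. A single-utterance conceptual sketch (the actual `GreedyBatchedRNNTLoopLabelsComputer` is fully batched and masked; `pred_step`/`joint_step` are stand-ins, and the max-symbols safeguard is omitted in the second function):

```python
def greedy_loop_frames(frames, pred_step, joint_step, blank, max_symbols):
    # Outer loop over time; inner loop over symbols emitted at each frame.
    hyp = []
    dec_out, state = pred_step(None, None)  # SOS step
    for f in frames:
        for _ in range(max_symbols):
            label = joint_step(f, dec_out)
            if label == blank:
                break  # frame exhausted, move to the next one
            hyp.append(label)
            dec_out, state = pred_step(label, state)
    return hyp

def greedy_loop_labels(frames, pred_step, joint_step, blank):
    # Outer loop over label emissions; a blank only advances the time index,
    # so the prediction network runs exactly once per emitted label.
    hyp = []
    dec_out, state = pred_step(None, None)  # SOS step
    t = 0
    while t < len(frames):
        label = joint_step(frames[t], dec_out)
        if label == blank:
            t += 1  # advance time, keep decoder output and state
        else:
            hyp.append(label)
            dec_out, state = pred_step(label, state)
    return hyp
```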
@@ -670,155 +681,10 @@ def _greedy_decode_blank_as_pad_loop_labels(
if partial_hypotheses is not None:
raise NotImplementedError("`partial_hypotheses` support is not implemented")

- batch_size, max_time, _ = x.shape
-
- x = self.joint.project_encoder(x) # do not recalculate joint projection, project only once
-
- # Initialize empty hypotheses and all necessary tensors
- batched_hyps = rnnt_utils.BatchedHyps(
- batch_size=batch_size, init_length=max_time, device=x.device, float_dtype=x.dtype
- )
- time_indices = torch.zeros([batch_size], dtype=torch.long, device=device) # always of batch_size
- active_indices = torch.arange(batch_size, dtype=torch.long, device=device) # initial: all indices
- labels = torch.full([batch_size], fill_value=self._blank_index, dtype=torch.long, device=device)
- state = None
-
- # init additional structs for hypotheses: last decoder state, alignments, frame_confidence
- last_decoder_state = [None for _ in range(batch_size)]
-
- alignments: Optional[rnnt_utils.BatchedAlignments]
- if self.preserve_alignments or self.preserve_frame_confidence:
- alignments = rnnt_utils.BatchedAlignments(
- batch_size=batch_size,
- logits_dim=self.joint.num_classes_with_blank,
- init_length=max_time * 2, # blank for each timestep + text tokens
- device=x.device,
- float_dtype=x.dtype,
- store_alignments=self.preserve_alignments,
- store_frame_confidence=self.preserve_frame_confidence,
- )
- else:
- alignments = None
-
- # loop while there are active indices
- while (current_batch_size := active_indices.shape[0]) > 0:
- # stage 1: get decoder (prediction network) output
- if state is None:
- # start of the loop, SOS symbol is passed into prediction network
- decoder_output, state, *_ = self._pred_step(self._SOS, state, batch_size=current_batch_size)
- else:
- # pass the labels (found in the inner loop) to the prediction network
- decoder_output, state, *_ = self._pred_step(labels.unsqueeze(1), state, batch_size=current_batch_size)
- decoder_output = self.joint.project_prednet(decoder_output) # do not recalculate joint projection
-
- # stage 2: get joint output, iteratively seeking for non-blank labels
- # blank label in `labels` tensor means "end of hypothesis" (for this index)
- logits = (
- self._joint_step_after_projection(
- x[active_indices, time_indices[active_indices]].unsqueeze(1),
- decoder_output,
- log_normalize=True if self.preserve_frame_confidence else None,
- )
- .squeeze(1)
- .squeeze(1)
- )
- scores, labels = logits.max(-1)
-
- # search for non-blank labels using joint, advancing time indices for blank labels
- # checking max_symbols is not needed, since we already forced advancing time indices for such cases
- blank_mask = labels == self._blank_index
- if alignments is not None:
- alignments.add_results_(
- active_indices=active_indices,
- time_indices=time_indices[active_indices],
- logits=logits if self.preserve_alignments else None,
- labels=labels if self.preserve_alignments else None,
- confidence=torch.tensor(self._get_confidence(logits), device=device)
- if self.preserve_frame_confidence
- else None,
- )
- # advance_mask is a mask for current batch for searching non-blank labels;
- # each element is True if non-blank symbol is not yet found AND we can increase the time index
- advance_mask = torch.logical_and(blank_mask, (time_indices[active_indices] + 1 < out_len[active_indices]))
- while advance_mask.any():
- advance_indices = active_indices[advance_mask]
- time_indices[advance_indices] += 1
- logits = (
- self._joint_step_after_projection(
- x[advance_indices, time_indices[advance_indices]].unsqueeze(1),
- decoder_output[advance_mask],
- log_normalize=True if self.preserve_frame_confidence else None,
- )
- .squeeze(1)
- .squeeze(1)
- )
- # get labels (greedy) and scores from current logits, replace labels/scores with new
- # labels[advance_mask] are blank, and we are looking for non-blank labels
- more_scores, more_labels = logits.max(-1)
- labels[advance_mask] = more_labels
- scores[advance_mask] = more_scores
- if alignments is not None:
- alignments.add_results_(
- active_indices=advance_indices,
- time_indices=time_indices[advance_indices],
- logits=logits if self.preserve_alignments else None,
- labels=more_labels if self.preserve_alignments else None,
- confidence=torch.tensor(self._get_confidence(logits), device=device)
- if self.preserve_frame_confidence
- else None,
- )
- blank_mask = labels == self._blank_index
- advance_mask = torch.logical_and(
- blank_mask, (time_indices[active_indices] + 1 < out_len[active_indices])
- )
-
- # stage 3: filter labels and state, store hypotheses
- # the only case, when there are blank labels in predictions - when we found the end for some utterances
- if blank_mask.any():
- non_blank_mask = ~blank_mask
- labels = labels[non_blank_mask]
- scores = scores[non_blank_mask]
-
- # select states for hyps that became inactive (is it necessary?)
- # this seems to be redundant, but used in the `loop_frames` output
- inactive_global_indices = active_indices[blank_mask]
- inactive_inner_indices = torch.arange(current_batch_size, device=device, dtype=torch.long)[blank_mask]
- for idx, batch_idx in zip(inactive_global_indices.cpu().numpy(), inactive_inner_indices.cpu().numpy()):
- last_decoder_state[idx] = self.decoder.batch_select_state(state, batch_idx)
-
- # update active indices and state
- active_indices = active_indices[non_blank_mask]
- state = self.decoder.mask_select_states(state, non_blank_mask)
- # store hypotheses
- batched_hyps.add_results_(
- active_indices, labels, time_indices[active_indices].clone(), scores,
- )
-
- # stage 4: to avoid looping, go to next frame after max_symbols emission
- if self.max_symbols is not None:
- # if labels are non-blank (not end-of-utterance), check that last observed timestep with label:
- # if it is equal to the current time index, and number of observations is >= max_symbols, force blank
- force_blank_mask = torch.logical_and(
- torch.logical_and(
- labels != self._blank_index,
- batched_hyps.last_timestep_lasts[active_indices] >= self.max_symbols,
- ),
- batched_hyps.last_timestep[active_indices] == time_indices[active_indices],
- )
- if force_blank_mask.any():
- # forced blank is not stored in the alignments following the original implementation
- time_indices[active_indices[force_blank_mask]] += 1 # emit blank => advance time indices
- # elements with time indices >= out_len become inactive, remove them from batch
- still_active_mask = time_indices[active_indices] < out_len[active_indices]
- active_indices = active_indices[still_active_mask]
- labels = labels[still_active_mask]
- state = self.decoder.mask_select_states(state, still_active_mask)
-
+ batched_hyps, alignments, last_decoder_state = self._decoding_computer(x=x, out_len=out_len)
hyps = rnnt_utils.batched_hyps_to_hypotheses(batched_hyps, alignments)
- # preserve last decoder state (is it necessary?)
- for i, last_state in enumerate(last_decoder_state):
- # assert last_state is not None
- hyps[i].dec_state = last_state
+ for hyp, state in zip(hyps, self.decoder.batch_split_states(last_decoder_state)):
+ hyp.dec_state = state
return hyps

def _greedy_decode_blank_as_pad_loop_frames(
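One more consequence of the refactor is visible in the hunk above: the old per-sample `batch_select_state` loop, which moved index tensors to the host on every iteration via `.cpu().numpy()`, is replaced by a single `batch_split_states` call after decoding finishes. A rough sketch of the two access patterns, with plain tensors standing in for decoder states (sizes are assumptions):

```python
import torch

L, B, H = 2, 8, 16  # assumed sizes
h, c = torch.randn(L, B, H), torch.randn(L, B, H)

# Old pattern: pick out states one utterance at a time, inside the decoding loop.
per_utt_old = [(h[:, i : i + 1], c[:, i : i + 1]) for i in range(B)]

# New pattern: one split per tensor after decoding is done, then pair the views.
per_utt_new = list(zip(h.split(1, dim=1), c.split(1, dim=1)))

for (h_old, c_old), (h_new, c_new) in zip(per_utt_old, per_utt_new):
    assert torch.equal(h_old, h_new) and torch.equal(c_old, c_new)
```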