[ASR] RNN-T greedy decoding max_frames fix for alignment and confidence (#7635)

* decoding and test fix

Signed-off-by: Aleksandr Laptev <alaptev@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Aleksandr Laptev <alaptev@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and yaoyu-33 committed Oct 13, 2023
1 parent 1e4c2b2 commit 798f6fc
Showing 2 changed files with 35 additions and 20 deletions.
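The gist of the fix: the old inner loop kept a single batch-wide flag, all_blanks, and only recorded a sample's blank prediction on a step where every sample had gone blank. If the loop was cut short by the max_symbols cap first, blanks were recorded at the wrong step or dropped from alignments and frame confidence. The new code keeps a per-sample blank_mask_prev and records each blank exactly once, at the step it first occurs. Below is a minimal standalone sketch of that bookkeeping, not NeMo's actual loop; only the tensor names follow the diff.

import torch

# A minimal sketch (not NeMo's actual loop) of the bookkeeping this commit
# introduces: per frame, blank_mask_prev remembers which samples have already
# emitted a blank, so alignment/confidence entries are appended only for
# non-blank labels or for a sample's first blank, never twice per frame.
def record_mask(time_idx, out_len, blank_mask, blank_mask_prev):
    in_range = time_idx < out_len                    # sample still has frames left
    not_duplicate = ~blank_mask_prev | ~blank_mask   # non-blank now, or first blank
    return in_range & not_duplicate

# Toy trace for one frame of a 3-sample batch; sample 2 is already past its
# length, so it starts pre-masked as blank (mirrors `time_idx >= out_len`).
out_len = torch.tensor([5, 5, 2])
time_idx = 3
blank_mask_prev = torch.tensor([False, False, True])
for is_blank_now in (torch.tensor([False, True, True]),   # sample 1 goes blank
                     torch.tensor([True, True, True])):   # everyone blank, loop exits
    blank_mask = blank_mask_prev | is_blank_now
    print(record_mask(time_idx, out_len, blank_mask, blank_mask_prev))
    blank_mask_prev = blank_mask_prev | blank_mask
# tensor([ True,  True, False])  -> sample 0's label and sample 1's first blank
# tensor([ True, False, False])  -> sample 0's first blank; no duplicates, and
#                                   nothing for the finished sample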
31 changes: 18 additions & 13 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
@@ -649,6 +649,7 @@ def _greedy_decode_blank_as_pad(
 
             # Mask buffers
             blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device)
+            blank_mask_prev = None
 
             # Get max sequence length
             max_out_len = out_len.max()
@@ -666,6 +667,8 @@ def _greedy_decode_blank_as_pad(
                 # Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch)
                 # Forcibly mask with "blank" tokens, for all sample where current time step T > seq_len
                 blank_mask = time_idx >= out_len
+                blank_mask_prev = blank_mask.clone()
+
                 # Start inner loop
                 while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols):
                     # Batch prediction and joint network steps
@@ -694,7 +697,6 @@ def _greedy_decode_blank_as_pad(
                     # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U)
                     k_is_blank = k == self._blank_index
                     blank_mask.bitwise_or_(k_is_blank)
-                    all_blanks = torch.all(blank_mask)
 
                     del k_is_blank
 
@@ -705,10 +707,9 @@ def _greedy_decode_blank_as_pad(
                         logp_vals = logp.to('cpu')
                         logp_ids = logp_vals.max(1)[1]
                         for batch_idx, is_blank in enumerate(blank_mask):
-                            # we only want to update non-blanks, unless we are at the last step in the loop where
-                            # all elements produced blanks, otherwise there will be duplicate predictions
-                            # saved in alignments
-                            if time_idx < out_len[batch_idx] and (all_blanks or not is_blank):
+                            # we only want to update non-blanks and first-time blanks,
+                            # otherwise alignments will contain duplicate predictions
+                            if time_idx < out_len[batch_idx] and (not blank_mask_prev[batch_idx] or not is_blank):
                                 hypotheses[batch_idx].alignments[-1].append(
                                     (logp_vals[batch_idx], logp_ids[batch_idx])
                                 )
@@ -720,13 +721,15 @@ def _greedy_decode_blank_as_pad(
                         # Insert probabilities into last timestep per sample
                         confidence = self._get_confidence(logp)
                         for batch_idx, is_blank in enumerate(blank_mask):
-                            if time_idx < out_len[batch_idx] and (all_blanks or not is_blank):
+                            if time_idx < out_len[batch_idx] and (not blank_mask_prev[batch_idx] or not is_blank):
                                 hypotheses[batch_idx].frame_confidence[-1].append(confidence[batch_idx])
                     del logp
 
+                    blank_mask_prev.bitwise_or_(blank_mask)
+
                     # If all samples predict / have predicted prior blanks, exit loop early
                     # This is equivalent to if single sample predicted k
-                    if all_blanks:
+                    if blank_mask.all():
                         not_blank = False
                     else:
                         # Collect batch indices where blanks occurred now/past
@@ -847,6 +850,7 @@ def _greedy_decode_masked(
 
             # Mask buffers
             blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device)
+            blank_mask_prev = None
 
             # Get max sequence length
             max_out_len = out_len.max()
@@ -866,6 +870,7 @@ def _greedy_decode_masked(
                 # Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch)
                 # Forcibly mask with "blank" tokens, for all sample where current time step T > seq_len
                 blank_mask = time_idx >= out_len
+                blank_mask_prev = blank_mask.clone()
 
                 # Start inner loop
                 while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols):
@@ -904,7 +909,6 @@ def _greedy_decode_masked(
                     # This is accumulating blanks over all time steps T and all target steps min(max_symbols, U)
                     k_is_blank = k == self._blank_index
                     blank_mask.bitwise_or_(k_is_blank)
-                    all_blanks = torch.all(blank_mask)
 
                     # If preserving alignments, check if sequence length of sample has been reached
                     # before adding alignment
@@ -913,10 +917,9 @@ def _greedy_decode_masked(
                         logp_vals = logp.to('cpu')
                         logp_ids = logp_vals.max(1)[1]
                         for batch_idx, is_blank in enumerate(blank_mask):
-                            # we only want to update non-blanks, unless we are at the last step in the loop where
-                            # all elements produced blanks, otherwise there will be duplicate predictions
-                            # saved in alignments
-                            if time_idx < out_len[batch_idx] and (all_blanks or not is_blank):
+                            # we only want to update non-blanks and first-time blanks,
+                            # otherwise alignments will contain duplicate predictions
+                            if time_idx < out_len[batch_idx] and (not blank_mask_prev[batch_idx] or not is_blank):
                                 hypotheses[batch_idx].alignments[-1].append(
                                     (logp_vals[batch_idx], logp_ids[batch_idx])
                                 )
@@ -929,10 +932,12 @@ def _greedy_decode_masked(
                         # Insert probabilities into last timestep per sample
                         confidence = self._get_confidence(logp)
                         for batch_idx, is_blank in enumerate(blank_mask):
-                            if time_idx < out_len[batch_idx] and (all_blanks or not is_blank):
+                            if time_idx < out_len[batch_idx] and (not blank_mask_prev[batch_idx] or not is_blank):
                                 hypotheses[batch_idx].frame_confidence[-1].append(confidence[batch_idx])
                     del logp
 
+                    blank_mask_prev.bitwise_or_(blank_mask)
+
                     # If all samples predict / have predicted prior blanks, exit loop early
                     # This is equivalent to if single sample predicted k
                     if blank_mask.all():
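Both _greedy_decode_blank_as_pad and _greedy_decode_masked receive the same change. The case the old all_blanks flag mishandled is easiest to see when the inner loop exits via max_symbols before every sample has gone blank: a sample that went blank early never satisfies `all_blanks or not is_blank`, so its blank never reaches the alignment. A hypothetical trace comparing the two conditions (a simplified pure-Python stand-in, not the real decoder):

import torch

# Hypothetical inner loop for one frame of a 2-sample batch, capped at
# max_symbols = 2: sample 0 emits a real label on both steps, sample 1 goes
# blank immediately.
steps = [torch.tensor([False, True]), torch.tensor([False, True])]
out_len = torch.tensor([4, 4])
time_idx = 0

def alignment_entries(new_logic):
    entries = torch.zeros(2, dtype=torch.long)
    blank_mask = time_idx >= out_len   # pre-mask finished samples
    blank_mask_prev = blank_mask.clone()
    for is_blank_now in steps:         # len(steps) == max_symbols cap
        blank_mask = blank_mask | is_blank_now
        all_blanks = bool(blank_mask.all())
        for b in range(2):
            if time_idx >= int(out_len[b]):
                continue
            if new_logic:  # record non-blanks and first-time blanks
                record = (not blank_mask_prev[b]) or (not blank_mask[b])
            else:          # old: blanks recorded only on an all-blank step
                record = all_blanks or not blank_mask[b]
            entries[b] += int(record)
        blank_mask_prev = blank_mask_prev | blank_mask
        if blank_mask.all():
            break
    return entries

print(alignment_entries(new_logic=False))  # tensor([2, 0]): sample 1's blank is lost
print(alignment_entries(new_logic=True))   # tensor([2, 1]): one entry per emission

Under the old condition, sample 1's blank is dropped whenever the cap fires before the batch is all-blank; the new per-sample check records it on the step it occurred, which also keeps the stored log-probabilities and confidence values aligned with the right step.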
24 changes: 17 additions & 7 deletions tests/collections/asr/test_asr_rnnt_encdec_model.py
@@ -45,15 +45,19 @@ def predict(
         add_sos: bool = False,
         batch_size: Optional[int] = None,
     ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        if batch_size is None:
+            batch_size = 1
+        if y is not None:
+            y = y + torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat(y.size())
         if y is not None and state is not None:
             return (y + state) / 2, y * state
         elif state is not None:
             return torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat(state.size()), state
         elif y is not None:
             return y, torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat(y.size())
         return (
-            torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat([1, 1, 1]),
-            torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat([1, 1, 1]),
+            torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat([1, batch_size, 1]),
+            torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat([1, batch_size, 1]),
         )
 
     def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]:
@@ -66,8 +70,11 @@ def score_hypothesis(
 
     def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List[List[torch.Tensor]]:
         if batch_states is not None:
-            states = batch_states[0][idx]
-            states = states.long()
+            try:
+                states = batch_states[0][idx]
+                states = states.long()
+            except Exception as e:
+                raise Exception(batch_states, idx)
             return [states]
         else:
             return None
@@ -92,8 +99,12 @@ def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
     setup["decoder"] = DummyRNNTDecoder(vocab_size=2, blank_idx=2, blank_as_pad=True)
     setup["decoder_masked"] = DummyRNNTDecoder(vocab_size=2, blank_idx=2, blank_as_pad=False)
     setup["joint"] = DummyRNNTJoint()
-    setup["encoder_output"] = torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]], dtype=torch.float32).transpose(1, 2)
-    setup["encoded_lengths"] = torch.tensor([3])
+    # expected timesteps for max_symbols_per_step=5 are [[0, 0, 0, 0, 0, 1, 1], [1, 1, 1, 1, 1]],
+    # so we have both looped and regular iteration on the second frame
+    setup["encoder_output"] = torch.tensor(
+        [[[1, 0, 0], [0, 1, 0], [0, 0, 1]], [[0, 0, 1], [2, 0, 0], [0, 0, 0]]], dtype=torch.float32
+    ).transpose(1, 2)
+    setup["encoded_lengths"] = torch.tensor([3, 2])
     return setup


@@ -726,7 +737,6 @@ def test_greedy_decoding_preserve_frame_confidence(self, greedy_class):
             decoder,
             joint_net,
             blank_index=len(token_list),
-            preserve_alignments=True,
             preserve_frame_confidence=True,
             max_symbols_per_step=max_symbols_per_step,
         )
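On the test side, the dummy decoder previously hard-coded a single-sample state, which the new two-sample fixture (one utterance hitting the max_symbols_per_step cap on its first frame, one decoding normally) would break; dropping preserve_alignments=True from the frame-confidence test lets it exercise the confidence path on its own. A quick shape check of the repeat change, assuming vocab_size=2 as in the test:

import torch

# Shape check for the dummy decoder's state fix, assuming vocab_size=2 as in
# the test: the old code always produced a single state vector, the new code
# produces one per batch element.
vocab_size = 2
base = torch.tensor([0] * vocab_size + [1], dtype=torch.float32)  # blank-peaked

old_state = base.repeat([1, 1, 1])  # always [1, 1, vocab_size + 1]
new_state = base.repeat([1, 2, 1])  # [1, batch_size, vocab_size + 1], batch_size=2

print(old_state.shape)  # torch.Size([1, 1, 3])
print(new_state.shape)  # torch.Size([1, 2, 3])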
