[ASR] RNN-T greedy decoding max_frames fix for alignment and confidence #7635

Merged (2 commits), Oct 6, 2023
31 changes: 18 additions & 13 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
@@ -649,6 +649,7 @@ def _greedy_decode_blank_as_pad(

# Mask buffers
blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device)
blank_mask_prev = None

# Get max sequence length
max_out_len = out_len.max()
@@ -666,6 +667,8 @@ def _greedy_decode_blank_as_pad(
# Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch)
# Forcibly mask with "blank" tokens, for all samples where current time step T > seq_len
blank_mask = time_idx >= out_len
blank_mask_prev = blank_mask.clone()

# Start inner loop
while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols):
# Batch prediction and joint network steps
@@ -694,7 +697,6 @@ def _greedy_decode_blank_as_pad(
# This is accumulating blanks over all time steps T and all target steps min(max_symbols, U)
k_is_blank = k == self._blank_index
blank_mask.bitwise_or_(k_is_blank)
all_blanks = torch.all(blank_mask)

del k_is_blank

@@ -705,10 +707,9 @@ def _greedy_decode_blank_as_pad(
logp_vals = logp.to('cpu')
logp_ids = logp_vals.max(1)[1]
for batch_idx, is_blank in enumerate(blank_mask):
# we only want to update non-blanks, unless we are at the last step in the loop where
# all elements produced blanks, otherwise there will be duplicate predictions
# saved in alignments
if time_idx < out_len[batch_idx] and (all_blanks or not is_blank):
# we only want to update non-blanks and first-time blanks,
# otherwise alignments will contain duplicate predictions
if time_idx < out_len[batch_idx] and (not blank_mask_prev[batch_idx] or not is_blank):
hypotheses[batch_idx].alignments[-1].append(
(logp_vals[batch_idx], logp_ids[batch_idx])
)
@@ -720,13 +721,15 @@ def _greedy_decode_blank_as_pad(
# Insert probabilities into last timestep per sample
confidence = self._get_confidence(logp)
for batch_idx, is_blank in enumerate(blank_mask):
if time_idx < out_len[batch_idx] and (all_blanks or not is_blank):
if time_idx < out_len[batch_idx] and (not blank_mask_prev[batch_idx] or not is_blank):
hypotheses[batch_idx].frame_confidence[-1].append(confidence[batch_idx])
del logp

blank_mask_prev.bitwise_or_(blank_mask)

# If all samples predict / have predicted prior blanks, exit loop early
# This is equivalent to if single sample predicted k
if all_blanks:
if blank_mask.all():
not_blank = False
else:
# Collect batch indices where blanks occurred now/past
@@ -847,6 +850,7 @@ def _greedy_decode_masked(

# Mask buffers
blank_mask = torch.full([batchsize], fill_value=0, dtype=torch.bool, device=device)
blank_mask_prev = None

# Get max sequence length
max_out_len = out_len.max()
@@ -866,6 +870,7 @@ def _greedy_decode_masked(
# Batch: [B, T, D], but Bi may have seq len < max(seq_lens_in_batch)
# Forcibly mask with "blank" tokens, for all samples where current time step T > seq_len
blank_mask = time_idx >= out_len
blank_mask_prev = blank_mask.clone()

# Start inner loop
while not_blank and (self.max_symbols is None or symbols_added < self.max_symbols):
@@ -904,7 +909,6 @@ def _greedy_decode_masked(
# This is accumulating blanks over all time steps T and all target steps min(max_symbols, U)
k_is_blank = k == self._blank_index
blank_mask.bitwise_or_(k_is_blank)
all_blanks = torch.all(blank_mask)

# If preserving alignments, check if sequence length of sample has been reached
# before adding alignment
@@ -913,10 +917,9 @@ def _greedy_decode_masked(
logp_vals = logp.to('cpu')
logp_ids = logp_vals.max(1)[1]
for batch_idx, is_blank in enumerate(blank_mask):
# we only want to update non-blanks, unless we are at the last step in the loop where
# all elements produced blanks, otherwise there will be duplicate predictions
# saved in alignments
if time_idx < out_len[batch_idx] and (all_blanks or not is_blank):
# we only want to update non-blanks and first-time blanks,
# otherwise alignments will contain duplicate predictions
if time_idx < out_len[batch_idx] and (not blank_mask_prev[batch_idx] or not is_blank):
hypotheses[batch_idx].alignments[-1].append(
(logp_vals[batch_idx], logp_ids[batch_idx])
)
@@ -929,10 +932,12 @@ def _greedy_decode_masked(
# Insert probabilities into last timestep per sample
confidence = self._get_confidence(logp)
for batch_idx, is_blank in enumerate(blank_mask):
if time_idx < out_len[batch_idx] and (all_blanks or not is_blank):
if time_idx < out_len[batch_idx] and (not blank_mask_prev[batch_idx] or not is_blank):
hypotheses[batch_idx].frame_confidence[-1].append(confidence[batch_idx])
del logp

blank_mask_prev.bitwise_or_(blank_mask)

# If all samples predict / have predicted prior blanks, exit loop early
# This is equivalent to if single sample predicted k
if blank_mask.all():
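A note on the core change above: the removed all_blanks flag recorded blank samples' alignment and frame-confidence entries only on the step where the whole batch had gone blank, while the new blank_mask_prev buffer records each sample exactly once, at the inner-loop step where it is non-blank or goes blank for the first time, and skips samples that were already blank. Below is a minimal standalone sketch of that bookkeeping with toy tensors; it mirrors the condition in the diff but is not the actual NeMo decoder, and the helper name is illustrative.

import torch

def indices_to_record(blank_mask, blank_mask_prev, out_len, time_idx):
    # Batch indices whose alignment/confidence should be appended at this
    # inner-loop step: samples still inside their encoder length that are
    # either non-blank or blank for the first time in the current frame.
    record = []
    for batch_idx in range(blank_mask.numel()):
        if time_idx < out_len[batch_idx] and (
            not blank_mask_prev[batch_idx] or not blank_mask[batch_idx]
        ):
            record.append(batch_idx)
    return record

# Toy frame: sample 2 is already past its encoder length.
out_len = torch.tensor([3, 3, 1])
time_idx = 1
blank_index = 2
blank_mask = time_idx >= out_len        # [False, False, True]
blank_mask_prev = blank_mask.clone()    # blanks seen so far in this frame

steps = [torch.tensor([0, 2, 2]),       # step 1: sample 1 emits its first blank
         torch.tensor([2, 2, 2])]       # step 2: every sample is blank
for k in steps:
    blank_mask = blank_mask | (k == blank_index)
    print(indices_to_record(blank_mask, blank_mask_prev, out_len, time_idx))
    blank_mask_prev = blank_mask_prev | blank_mask   # absorb this step's blanks
    if bool(blank_mask.all()):
        break                           # frame is done once all samples are blank

The sketch prints [0, 1] and then [0]: sample 1's blank is recorded once, on the step that produced it, and is skipped on the final all-blank step, so a frame's alignments and frame_confidence never receive a second entry for a sample that has already gone blank.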
24 changes: 17 additions & 7 deletions tests/collections/asr/test_asr_rnnt_encdec_model.py
@@ -45,15 +45,19 @@ def predict(
add_sos: bool = False,
batch_size: Optional[int] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
if batch_size is None:
batch_size = 1
if y is not None:
y = y + torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat(y.size())
if y is not None and state is not None:
return (y + state) / 2, y * state
elif state is not None:
return torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat(state.size()), state
elif y is not None:
return y, torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat(y.size())
return (
torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat([1, 1, 1]),
torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat([1, 1, 1]),
torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat([1, batch_size, 1]),
torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32).repeat([1, batch_size, 1]),
)

def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]:
@@ -66,8 +70,11 @@ def score_hypothesis(

def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> List[List[torch.Tensor]]:
if batch_states is not None:
states = batch_states[0][idx]
states = states.long()
try:
states = batch_states[0][idx]
states = states.long()
except Exception as e:
raise Exception(batch_states, idx)
return [states]
else:
return None
@@ -92,8 +99,12 @@ def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
setup["decoder"] = DummyRNNTDecoder(vocab_size=2, blank_idx=2, blank_as_pad=True)
setup["decoder_masked"] = DummyRNNTDecoder(vocab_size=2, blank_idx=2, blank_as_pad=False)
setup["joint"] = DummyRNNTJoint()
setup["encoder_output"] = torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]], dtype=torch.float32).transpose(1, 2)
setup["encoded_lengths"] = torch.tensor([3])
# expected timesteps for max_symbols_per_step=5 are [[0, 0, 0, 0, 0, 1, 1], [1, 1, 1, 1, 1]],
# so we have both looped and regular iteration on the second frame
setup["encoder_output"] = torch.tensor(
[[[1, 0, 0], [0, 1, 0], [0, 0, 1]], [[0, 0, 1], [2, 0, 0], [0, 0, 0]]], dtype=torch.float32
).transpose(1, 2)
setup["encoded_lengths"] = torch.tensor([3, 2])
return setup


@@ -726,7 +737,6 @@ def test_greedy_decoding_preserve_frame_confidence(self, greedy_class):
decoder,
joint_net,
blank_index=len(token_list),
preserve_alignments=True,
preserve_frame_confidence=True,
max_symbols_per_step=max_symbols_per_step,
)
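The new two-sample encoder output in the test setup gives the batch different encoded lengths (3 and 2) so that, per the comment in the setup, the second frame exercises both a looped (max-symbols) and a regular inner iteration. A hypothetical invariant such a setup makes easy to assert, assuming both preserve_alignments and preserve_frame_confidence are enabled, is that the per-frame alignment and confidence lists stay in lockstep; the helper below is only an illustrative sketch, not part of the PR's test file.

def assert_frame_bookkeeping_consistent(hyp):
    # Illustrative check: alignments and frame_confidence are parallel
    # per-frame lists, and within a frame every recorded inner-loop step
    # contributes exactly one (log-probs, label) pair per confidence value.
    assert len(hyp.alignments) == len(hyp.frame_confidence)
    for frame_alignments, frame_confidence in zip(hyp.alignments, hyp.frame_confidence):
        assert len(frame_alignments) == len(frame_confidence)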