diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index 948760e68b30..5a7457f6379d 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -310,7 +310,9 @@ def score_hypothesis( def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]: batch = y.size(0) - state = [torch.ones([batch, self.context_size], dtype=torch.long, device=y.device) * self.blank_idx] + # state contains context_size - 1 elements for each utterance in batch, + # consistent with the state returned from StatelessNet.forward + state = [torch.ones([batch, self.context_size - 1], dtype=torch.long, device=y.device) * self.blank_idx] return state def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]): diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index ad71e5371f01..7a260f3c6c89 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -319,7 +319,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int): preserve_alignments=self.preserve_alignments, preserve_frame_confidence=self.preserve_frame_confidence, confidence_method_cfg=self.confidence_method_cfg, - loop_labels=self.cfg.greedy.get('loop_labels', False), + loop_labels=self.cfg.greedy.get('loop_labels', True), use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False), ) else: diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index d69ed1c41049..464dc46e358c 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -568,9 +568,9 @@ class GreedyBatchedRNNTInfer(_GreedyRNNTInfer): - 'lin' for using the linear mapping. - 'exp' for using exponential mapping with linear shift. loop_labels: Switching between decoding algorithms. Both algorithms produce equivalent results. - loop_labels=True algorithm is faster (especially for large batches) but can use a bit more memory + loop_labels=True (default) algorithm is faster (especially for large batches) but can use a bit more memory (negligible overhead compared to the amount of memory used by the encoder). - loop_labels=False (default) is an implementation of a traditional decoding algorithm, which iterates over + loop_labels=False is an implementation of a traditional decoding algorithm, which iterates over frames (encoder output vectors), and in the inner loop, decodes labels for the current frame one by one, stopping when is found. loop_labels=True iterates over labels, on each step finding the next non-blank label @@ -588,7 +588,7 @@ def __init__( preserve_alignments: bool = False, preserve_frame_confidence: bool = False, confidence_method_cfg: Optional[DictConfig] = None, - loop_labels: bool = False, + loop_labels: bool = True, use_cuda_graph_decoder: bool = False, ): super().__init__( @@ -2299,7 +2299,7 @@ class GreedyBatchedRNNTInferConfig: preserve_alignments: bool = False preserve_frame_confidence: bool = False confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig()) - loop_labels: bool = False + loop_labels: bool = True use_cuda_graph_decoder: bool = False def __post_init__(self): diff --git a/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py b/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py index 89b474e0f8ba..92cb8a36aeb5 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py @@ -283,21 +283,12 @@ def loop_labels_torch( became_inactive_mask = torch.empty_like(active_mask) # loop while there are active utterances - first_step = True while active_mask.any(): active_mask_prev.copy_(active_mask, non_blocking=True) # stage 1: get decoder (prediction network) output - if first_step: - # start of the loop, SOS symbol is passed into prediction network, state is None - # we need to separate this for torch.jit - decoder_output, state, *_ = self.decoder.predict( - labels.unsqueeze(1), None, add_sos=False, batch_size=batch_size - ) - first_step = False - else: - decoder_output, state, *_ = self.decoder.predict( - labels.unsqueeze(1), state, add_sos=False, batch_size=batch_size - ) + decoder_output, state, *_ = self.decoder.predict( + labels.unsqueeze(1), state, add_sos=False, batch_size=batch_size + ) decoder_output = self.joint.project_prednet(decoder_output) # do not recalculate joint projection # stage 2: get joint output, iteratively seeking for non-blank labels diff --git a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py index e95ea48d15fe..c289ce06cdfa 100644 --- a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py +++ b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py @@ -294,21 +294,12 @@ def loop_labels_torch( became_inactive_mask = torch.empty_like(active_mask) # loop while there are active utterances - first_step = True while active_mask.any(): active_mask_prev.copy_(active_mask, non_blocking=True) # stage 1: get decoder (prediction network) output - if first_step: - # start of the loop, SOS symbol is passed into prediction network, state is None - # we need to separate this for torch.jit - decoder_output, state, *_ = self.decoder.predict( - labels.unsqueeze(1), None, add_sos=False, batch_size=batch_size - ) - first_step = False - else: - decoder_output, state, *_ = self.decoder.predict( - labels.unsqueeze(1), state, add_sos=False, batch_size=batch_size - ) + decoder_output, state, *_ = self.decoder.predict( + labels.unsqueeze(1), state, add_sos=False, batch_size=batch_size + ) decoder_output = self.joint.project_prednet(decoder_output) # do not recalculate joint projection # stage 2: get joint output, iteratively seeking for non-blank labels diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py index 60f807dc7b3e..85156bf9e2c5 100644 --- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py +++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy +from typing import Optional import pytest import torch @@ -309,9 +310,14 @@ def test_BeamRNNTInferConfig(self): ) @pytest.mark.unit @pytest.mark.parametrize( - "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], + ("greedy_class", "loop_labels"), + [ + (greedy_decode.GreedyRNNTInfer, None), + (greedy_decode.GreedyBatchedRNNTInfer, True), + (greedy_decode.GreedyBatchedRNNTInfer, False), + ], ) - def test_greedy_decoding(self, greedy_class): + def test_greedy_decoding(self, greedy_class, loop_labels: Optional[bool]): token_list = [" ", "a", "b", "c"] vocab_size = len(token_list) @@ -330,7 +336,10 @@ def test_greedy_decoding(self, greedy_class): decoder = RNNTDecoder(prednet_cfg, vocab_size) joint_net = RNNTJoint(jointnet_cfg, vocab_size, vocabulary=token_list) - greedy = greedy_class(decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5) + additional_decoding_kwargs = {} if loop_labels is None else {"loop_labels": loop_labels} + greedy = greedy_class( + decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5, **additional_decoding_kwargs + ) # (B, D, T) enc_out = torch.randn(1, encoder_output_size, 30) @@ -381,9 +390,15 @@ def test_greedy_multi_decoding(self, greedy_class): ) @pytest.mark.unit @pytest.mark.parametrize( - "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], + ("greedy_class", "loop_labels"), + [ + (greedy_decode.GreedyRNNTInfer, None), + (greedy_decode.GreedyBatchedRNNTInfer, True), + (greedy_decode.GreedyBatchedRNNTInfer, False), + ], ) - def test_greedy_decoding_stateless_decoder(self, greedy_class): + @pytest.mark.parametrize("context_size", [1, 2]) + def test_greedy_decoding_stateless_decoder(self, greedy_class, loop_labels: Optional[bool], context_size: int): token_list = [" ", "a", "b", "c"] vocab_size = len(token_list) @@ -391,7 +406,7 @@ def test_greedy_decoding_stateless_decoder(self, greedy_class): decoder_output_size = 4 joint_output_shape = 4 - prednet_cfg = {'pred_hidden': decoder_output_size, 'pred_rnn_layers': 1} + prednet_cfg = {'pred_hidden': decoder_output_size, 'pred_rnn_layers': 1, 'context_size': context_size} jointnet_cfg = { 'encoder_hidden': encoder_output_size, 'pred_hidden': decoder_output_size, @@ -402,7 +417,10 @@ def test_greedy_decoding_stateless_decoder(self, greedy_class): decoder = StatelessTransducerDecoder(prednet_cfg, vocab_size) joint_net = RNNTJoint(jointnet_cfg, vocab_size, vocabulary=token_list) - greedy = greedy_class(decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5) + additional_decoding_kwargs = {} if loop_labels is None else {"loop_labels": loop_labels} + greedy = greedy_class( + decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5, **additional_decoding_kwargs + ) # (B, D, T) enc_out = torch.randn(1, encoder_output_size, 30) @@ -453,9 +471,14 @@ def test_greedy_multi_decoding_stateless_decoder(self, greedy_class): ) @pytest.mark.unit @pytest.mark.parametrize( - "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], + ("greedy_class", "loop_labels"), + [ + (greedy_decode.GreedyRNNTInfer, None), + (greedy_decode.GreedyBatchedRNNTInfer, True), + (greedy_decode.GreedyBatchedRNNTInfer, False), + ], ) - def test_greedy_decoding_preserve_alignment(self, greedy_class): + def test_greedy_decoding_preserve_alignment(self, greedy_class, loop_labels: Optional[bool]): token_list = [" ", "a", "b", "c"] vocab_size = len(token_list) @@ -474,8 +497,14 @@ def test_greedy_decoding_preserve_alignment(self, greedy_class): decoder = RNNTDecoder(prednet_cfg, vocab_size) joint_net = RNNTJoint(jointnet_cfg, vocab_size, vocabulary=token_list) + additional_decoding_kwargs = {} if loop_labels is None else {"loop_labels": loop_labels} greedy = greedy_class( - decoder, joint_net, blank_index=len(token_list) - 1, preserve_alignments=True, max_symbols_per_step=5 + decoder, + joint_net, + blank_index=len(token_list) - 1, + preserve_alignments=True, + max_symbols_per_step=5, + **additional_decoding_kwargs, ) # (B, D, T) @@ -591,9 +620,14 @@ def test_beam_decoding_preserve_alignments(self, beam_config): ) @pytest.mark.unit @pytest.mark.parametrize( - "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], + ("greedy_class", "loop_labels"), + [ + (greedy_decode.GreedyRNNTInfer, None), + (greedy_decode.GreedyBatchedRNNTInfer, True), + (greedy_decode.GreedyBatchedRNNTInfer, False), + ], ) - def test_greedy_decoding_SampledRNNTJoint(self, greedy_class): + def test_greedy_decoding_SampledRNNTJoint(self, greedy_class, loop_labels: Optional[bool]): token_list = [" ", "a", "b", "c"] vocab_size = len(token_list) @@ -612,7 +646,10 @@ def test_greedy_decoding_SampledRNNTJoint(self, greedy_class): decoder = RNNTDecoder(prednet_cfg, vocab_size) joint_net = SampledRNNTJoint(jointnet_cfg, vocab_size, n_samples=2, vocabulary=token_list) - greedy = greedy_class(decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5) + additional_decoding_kwargs = {} if loop_labels is None else {"loop_labels": loop_labels} + greedy = greedy_class( + decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5, **additional_decoding_kwargs + ) # (B, D, T) enc_out = torch.randn(1, encoder_output_size, 30) diff --git a/tests/collections/asr/test_asr_rnnt_encdec_model.py b/tests/collections/asr/test_asr_rnnt_encdec_model.py index d7c47adce1ad..d5ab0054ff87 100644 --- a/tests/collections/asr/test_asr_rnnt_encdec_model.py +++ b/tests/collections/asr/test_asr_rnnt_encdec_model.py @@ -73,7 +73,7 @@ def predict( return ( output, [ - torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32)[None, None, :].exand( + torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32)[None, None, :].expand( [1, batch_size, -1] ) ], @@ -90,22 +90,25 @@ def predict( ], ) - def initialize_state(self, y: torch.Tensor) -> Optional[List[torch.Tensor]]: - return None + def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]: + batch_size = y.shape[0] + # NB: .clone is necessary after .expand, since the decoding algorithm manipulates the state + # (replacing elements), and this requires the state to be a real full tensor + # (not an expanded view, in which different elements can refer to the same memory location) + return [ + torch.tensor([0] * self.vocab_size + [1], dtype=torch.float32)[None, None, :] + .expand([1, batch_size, -1]) + .clone() + ] def score_hypothesis( self, hypothesis: Hypothesis, cache: Dict[Tuple[int], Any] ) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]: return torch.tensor(), [torch.tensor()], torch.tensor() - def batch_select_state( - self, batch_states: Optional[List[torch.Tensor]], idx: int - ) -> Optional[List[List[torch.Tensor]]]: - if batch_states is not None: - states = [batch_states[0][:, idx]] - return [states] - else: - return None + def batch_select_state(self, batch_states: List[torch.Tensor], idx: int) -> Optional[List[List[torch.Tensor]]]: + states = [batch_states[0][:, idx]] + return [states] def batch_copy_states( self, @@ -126,6 +129,22 @@ def mask_select_states( return None return [states[0][:, mask]] + @classmethod + def batch_replace_states_mask( + cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor], mask: torch.Tensor, + ): + """Replace states in dst_states with states from src_states using the mask""" + for src_substate, dst_substate in zip(src_states, dst_states): + torch.where(mask.unsqueeze(0).unsqueeze(-1), src_substate, dst_substate, out=dst_substate) + + @classmethod + def batch_split_states(cls, batch_states: list[torch.Tensor]) -> list[list[torch.Tensor]]: + """ + Split states into a list of states. + Useful for splitting the final state for converting results of the decoding algorithm to Hypothesis class. + """ + return [sub_state.split(1, dim=1) for sub_state in batch_states] + class DummyRNNTJoint(AbstractRNNTJoint): def __init__(self, num_outputs: int): super().__init__() @@ -621,9 +640,15 @@ def test_greedy_multi_decoding(self, greedy_class): ) @pytest.mark.unit @pytest.mark.parametrize( - "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], + ("greedy_class", "loop_labels"), + [ + (greedy_decode.GreedyRNNTInfer, None), + (greedy_decode.GreedyBatchedRNNTInfer, True), + (greedy_decode.GreedyBatchedRNNTInfer, False), + ], ) - def test_greedy_decoding_stateless_decoder(self, greedy_class): + @pytest.mark.parametrize("context_size", [1, 2]) + def test_greedy_decoding_stateless_decoder(self, greedy_class, loop_labels: Optional[bool], context_size: int): token_list = [" ", "a", "b", "c"] vocab_size = len(token_list) @@ -631,7 +656,7 @@ def test_greedy_decoding_stateless_decoder(self, greedy_class): decoder_output_size = 4 joint_output_shape = 4 - prednet_cfg = {'pred_hidden': decoder_output_size, 'pred_rnn_layers': 1} + prednet_cfg = {'pred_hidden': decoder_output_size, 'pred_rnn_layers': 1, 'context_size': context_size} jointnet_cfg = { 'encoder_hidden': encoder_output_size, 'pred_hidden': decoder_output_size, @@ -642,8 +667,14 @@ def test_greedy_decoding_stateless_decoder(self, greedy_class): decoder = StatelessTransducerDecoder(prednet_cfg, vocab_size) for joint_type in [RNNTJoint, HATJoint]: joint_net = joint_type(jointnet_cfg, vocab_size, vocabulary=token_list) - - greedy = greedy_class(decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5) + additional_decoding_kwargs = {} if loop_labels is None else {"loop_labels": loop_labels} + greedy = greedy_class( + decoder, + joint_net, + blank_index=len(token_list) - 1, + max_symbols_per_step=5, + **additional_decoding_kwargs, + ) # (B, D, T) enc_out = torch.randn(1, encoder_output_size, 30) @@ -696,9 +727,14 @@ def test_greedy_multi_decoding_stateless_decoder(self, greedy_class): ) @pytest.mark.unit @pytest.mark.parametrize( - "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], + ("greedy_class", "loop_labels"), + [ + (greedy_decode.GreedyRNNTInfer, None), + (greedy_decode.GreedyBatchedRNNTInfer, True), + (greedy_decode.GreedyBatchedRNNTInfer, False), + ], ) - def test_greedy_decoding_preserve_alignment(self, greedy_class): + def test_greedy_decoding_preserve_alignment(self, greedy_class, loop_labels: Optional[bool]): token_list = [" ", "a", "b", "c"] vocab_size = len(token_list) @@ -719,13 +755,14 @@ def test_greedy_decoding_preserve_alignment(self, greedy_class): max_symbols_per_step = 5 for joint_type in [RNNTJoint, HATJoint]: joint_net = joint_type(jointnet_cfg, vocab_size, vocabulary=token_list) - + additional_decoding_kwargs = {} if loop_labels is None else {"loop_labels": loop_labels} greedy = greedy_class( decoder, joint_net, blank_index=len(token_list), preserve_alignments=True, max_symbols_per_step=max_symbols_per_step, + **additional_decoding_kwargs, ) # (B, D, T) @@ -760,9 +797,14 @@ def test_greedy_decoding_preserve_alignment(self, greedy_class): ) @pytest.mark.unit @pytest.mark.parametrize( - "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], + ("greedy_class", "loop_labels"), + [ + (greedy_decode.GreedyRNNTInfer, None), + (greedy_decode.GreedyBatchedRNNTInfer, True), + (greedy_decode.GreedyBatchedRNNTInfer, False), + ], ) - def test_greedy_decoding_preserve_frame_confidence(self, greedy_class): + def test_greedy_decoding_preserve_frame_confidence(self, greedy_class, loop_labels: Optional[bool]): token_list = [" ", "a", "b", "c"] vocab_size = len(token_list) @@ -784,12 +826,14 @@ def test_greedy_decoding_preserve_frame_confidence(self, greedy_class): for joint_type in [RNNTJoint, HATJoint]: joint_net = joint_type(jointnet_cfg, vocab_size, vocabulary=token_list) + additional_decoding_kwargs = {} if loop_labels is None else {"loop_labels": loop_labels} greedy = greedy_class( decoder, joint_net, blank_index=len(token_list), preserve_frame_confidence=True, max_symbols_per_step=max_symbols_per_step, + **additional_decoding_kwargs, ) # (B, D, T) @@ -827,10 +871,17 @@ def test_greedy_decoding_preserve_frame_confidence(self, greedy_class): ) @pytest.mark.unit @pytest.mark.parametrize( - "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], + ("greedy_class", "loop_labels"), + [ + (greedy_decode.GreedyRNNTInfer, None), + (greedy_decode.GreedyBatchedRNNTInfer, True), + (greedy_decode.GreedyBatchedRNNTInfer, False), + ], ) @pytest.mark.parametrize("max_symbols_per_step", [1, 5]) - def test_greedy_decoding_max_symbols_alignment(self, max_symbols_setup, greedy_class, max_symbols_per_step): + def test_greedy_decoding_max_symbols_alignment( + self, max_symbols_setup, greedy_class, max_symbols_per_step: int, loop_labels: Optional[bool] + ): decoders = [max_symbols_setup["decoder"]] if greedy_class is greedy_decode.GreedyBatchedRNNTInfer: decoders.append(max_symbols_setup["decoder_masked"]) @@ -839,12 +890,14 @@ def test_greedy_decoding_max_symbols_alignment(self, max_symbols_setup, greedy_c encoded_lengths = max_symbols_setup["encoded_lengths"] for decoder in decoders: + additional_decoding_kwargs = {} if loop_labels is None else {"loop_labels": loop_labels} greedy = greedy_class( decoder_model=decoder, joint_model=joint, blank_index=decoder.blank_idx, max_symbols_per_step=max_symbols_per_step, preserve_alignments=True, + **additional_decoding_kwargs, ) with torch.no_grad(): @@ -869,10 +922,17 @@ def test_greedy_decoding_max_symbols_alignment(self, max_symbols_setup, greedy_c ) @pytest.mark.unit @pytest.mark.parametrize( - "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], + ("greedy_class", "loop_labels"), + [ + (greedy_decode.GreedyRNNTInfer, None), + (greedy_decode.GreedyBatchedRNNTInfer, True), + (greedy_decode.GreedyBatchedRNNTInfer, False), + ], ) @pytest.mark.parametrize("max_symbols_per_step", [-1, 0]) - def test_greedy_decoding_max_symbols_confidence(self, max_symbols_setup, greedy_class, max_symbols_per_step): + def test_greedy_decoding_max_symbols_confidence_incorrect_max_symbols( + self, max_symbols_setup, greedy_class, max_symbols_per_step: int, loop_labels: Optional[bool] + ): """Test ValueError for max_symbols_per_step <= 0""" decoders = [max_symbols_setup["decoder"]] if greedy_class is greedy_decode.GreedyBatchedRNNTInfer: @@ -880,6 +940,7 @@ def test_greedy_decoding_max_symbols_confidence(self, max_symbols_setup, greedy_ joint = max_symbols_setup["joint"] for decoder in decoders: + additional_decoding_kwargs = {} if loop_labels is None else {"loop_labels": loop_labels} with pytest.raises(ValueError): _ = greedy_class( decoder_model=decoder, @@ -887,6 +948,7 @@ def test_greedy_decoding_max_symbols_confidence(self, max_symbols_setup, greedy_ blank_index=decoder.blank_idx, max_symbols_per_step=max_symbols_per_step, preserve_frame_confidence=True, + **additional_decoding_kwargs, ) @pytest.mark.skipif( @@ -894,10 +956,17 @@ def test_greedy_decoding_max_symbols_confidence(self, max_symbols_setup, greedy_ ) @pytest.mark.unit @pytest.mark.parametrize( - "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], + ("greedy_class", "loop_labels"), + [ + (greedy_decode.GreedyRNNTInfer, None), + (greedy_decode.GreedyBatchedRNNTInfer, True), + (greedy_decode.GreedyBatchedRNNTInfer, False), + ], ) @pytest.mark.parametrize("max_symbols_per_step", [1, 5]) - def test_greedy_decoding_max_symbols_confidence(self, max_symbols_setup, greedy_class, max_symbols_per_step): + def test_greedy_decoding_max_symbols_confidence( + self, max_symbols_setup, greedy_class, max_symbols_per_step: int, loop_labels: Optional[bool] + ): decoders = [max_symbols_setup["decoder"]] if greedy_class is greedy_decode.GreedyBatchedRNNTInfer: decoders.append(max_symbols_setup["decoder_masked"]) @@ -906,12 +975,14 @@ def test_greedy_decoding_max_symbols_confidence(self, max_symbols_setup, greedy_ encoded_lengths = max_symbols_setup["encoded_lengths"] for decoder in decoders: + additional_decoding_kwargs = {} if loop_labels is None else {"loop_labels": loop_labels} greedy = greedy_class( decoder_model=decoder, joint_model=joint, blank_index=decoder.blank_idx, max_symbols_per_step=max_symbols_per_step, preserve_frame_confidence=True, + **additional_decoding_kwargs, ) with torch.no_grad(): @@ -1035,9 +1106,14 @@ def test_beam_decoding_preserve_alignments(self, beam_config): ) @pytest.mark.unit @pytest.mark.parametrize( - "greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer], + ("greedy_class", "loop_labels"), + [ + (greedy_decode.GreedyRNNTInfer, None), + (greedy_decode.GreedyBatchedRNNTInfer, True), + (greedy_decode.GreedyBatchedRNNTInfer, False), + ], ) - def test_greedy_decoding_SampledRNNTJoint(self, greedy_class): + def test_greedy_decoding_SampledRNNTJoint(self, greedy_class, loop_labels: Optional[bool]): token_list = [" ", "a", "b", "c"] vocab_size = len(token_list) @@ -1056,7 +1132,10 @@ def test_greedy_decoding_SampledRNNTJoint(self, greedy_class): decoder = RNNTDecoder(prednet_cfg, vocab_size) joint_net = SampledRNNTJoint(jointnet_cfg, vocab_size, n_samples=2, vocabulary=token_list) - greedy = greedy_class(decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5) + additional_decoding_kwargs = {} if loop_labels is None else {"loop_labels": loop_labels} + greedy = greedy_class( + decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5, **additional_decoding_kwargs + ) # (B, D, T) enc_out = torch.randn(1, encoder_output_size, 30)