Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Label-Looping algorithm for RNN-T decoding by default #8831

Merged
merged 7 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion nemo/collections/asr/modules/rnnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,9 @@ def score_hypothesis(

def initialize_state(self, y: torch.Tensor) -> List[torch.Tensor]:
batch = y.size(0)
state = [torch.ones([batch, self.context_size], dtype=torch.long, device=y.device) * self.blank_idx]
# state contains context_size - 1 elements for each utterance in batch,
# consistent with the state returned from StatelessNet.forward
state = [torch.ones([batch, self.context_size - 1], dtype=torch.long, device=y.device) * self.blank_idx]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hainan-xv, please confirm that I broke nothing when fixing the state for the Stateless decoder.
We need a state with a constant size (to allow replacements when we find the end of an utterance), and forward returns a state of size [batch_size, context_size - 1]

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, you can also use torch.full instead of torch.ones followed by multiplication. No need to change it though.

return state

def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]):
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/parts/submodules/rnnt_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
preserve_alignments=self.preserve_alignments,
preserve_frame_confidence=self.preserve_frame_confidence,
confidence_method_cfg=self.confidence_method_cfg,
loop_labels=self.cfg.greedy.get('loop_labels', False),
loop_labels=self.cfg.greedy.get('loop_labels', True),
use_cuda_graph_decoder=self.cfg.greedy.get('use_cuda_graph_decoder', False),
)
else:
Expand Down
8 changes: 4 additions & 4 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,9 +568,9 @@ class GreedyBatchedRNNTInfer(_GreedyRNNTInfer):
- 'lin' for using the linear mapping.
- 'exp' for using exponential mapping with linear shift.
loop_labels: Switching between decoding algorithms. Both algorithms produce equivalent results.
loop_labels=True algorithm is faster (especially for large batches) but can use a bit more memory
loop_labels=True (default) algorithm is faster (especially for large batches) but can use a bit more memory
(negligible overhead compared to the amount of memory used by the encoder).
loop_labels=False (default) is an implementation of a traditional decoding algorithm, which iterates over
loop_labels=False is an implementation of a traditional decoding algorithm, which iterates over
frames (encoder output vectors), and in the inner loop, decodes labels for the current frame one by one,
stopping when <blank> is found.
loop_labels=True iterates over labels, on each step finding the next non-blank label
Expand All @@ -588,7 +588,7 @@ def __init__(
preserve_alignments: bool = False,
preserve_frame_confidence: bool = False,
confidence_method_cfg: Optional[DictConfig] = None,
loop_labels: bool = False,
loop_labels: bool = True,
use_cuda_graph_decoder: bool = False,
):
super().__init__(
Expand Down Expand Up @@ -2299,7 +2299,7 @@ class GreedyBatchedRNNTInferConfig:
preserve_alignments: bool = False
preserve_frame_confidence: bool = False
confidence_method_cfg: Optional[ConfidenceMethodConfig] = field(default_factory=lambda: ConfidenceMethodConfig())
loop_labels: bool = False
loop_labels: bool = True
use_cuda_graph_decoder: bool = False

def __post_init__(self):
Expand Down
15 changes: 3 additions & 12 deletions nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,21 +283,12 @@ def loop_labels_torch(
became_inactive_mask = torch.empty_like(active_mask)

# loop while there are active utterances
first_step = True
while active_mask.any():
active_mask_prev.copy_(active_mask, non_blocking=True)
# stage 1: get decoder (prediction network) output
if first_step:
# start of the loop, SOS symbol is passed into prediction network, state is None
# we need to separate this for torch.jit
decoder_output, state, *_ = self.decoder.predict(
labels.unsqueeze(1), None, add_sos=False, batch_size=batch_size
)
first_step = False
else:
decoder_output, state, *_ = self.decoder.predict(
labels.unsqueeze(1), state, add_sos=False, batch_size=batch_size
)
decoder_output, state, *_ = self.decoder.predict(
labels.unsqueeze(1), state, add_sos=False, batch_size=batch_size
)
decoder_output = self.joint.project_prednet(decoder_output) # do not recalculate joint projection

# stage 2: get joint output, iteratively seeking for non-blank labels
Expand Down
15 changes: 3 additions & 12 deletions nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,21 +294,12 @@ def loop_labels_torch(
became_inactive_mask = torch.empty_like(active_mask)

# loop while there are active utterances
first_step = True
while active_mask.any():
active_mask_prev.copy_(active_mask, non_blocking=True)
# stage 1: get decoder (prediction network) output
if first_step:
# start of the loop, SOS symbol is passed into prediction network, state is None
# we need to separate this for torch.jit
decoder_output, state, *_ = self.decoder.predict(
labels.unsqueeze(1), None, add_sos=False, batch_size=batch_size
)
first_step = False
else:
decoder_output, state, *_ = self.decoder.predict(
labels.unsqueeze(1), state, add_sos=False, batch_size=batch_size
)
decoder_output, state, *_ = self.decoder.predict(
labels.unsqueeze(1), state, add_sos=False, batch_size=batch_size
)
decoder_output = self.joint.project_prednet(decoder_output) # do not recalculate joint projection

# stage 2: get joint output, iteratively seeking for non-blank labels
Expand Down
63 changes: 50 additions & 13 deletions tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import Optional

import pytest
import torch
Expand Down Expand Up @@ -309,9 +310,14 @@ def test_BeamRNNTInferConfig(self):
)
@pytest.mark.unit
@pytest.mark.parametrize(
"greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer],
("greedy_class", "loop_labels"),
[
(greedy_decode.GreedyRNNTInfer, None),
(greedy_decode.GreedyBatchedRNNTInfer, True),
(greedy_decode.GreedyBatchedRNNTInfer, False),
],
)
def test_greedy_decoding(self, greedy_class):
def test_greedy_decoding(self, greedy_class, loop_labels: Optional[bool]):
token_list = [" ", "a", "b", "c"]
vocab_size = len(token_list)

Expand All @@ -330,7 +336,10 @@ def test_greedy_decoding(self, greedy_class):
decoder = RNNTDecoder(prednet_cfg, vocab_size)
joint_net = RNNTJoint(jointnet_cfg, vocab_size, vocabulary=token_list)

greedy = greedy_class(decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5)
additional_decoding_kwargs = {} if loop_labels is None else {"loop_labels": loop_labels}
greedy = greedy_class(
decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5, **additional_decoding_kwargs
)

# (B, D, T)
enc_out = torch.randn(1, encoder_output_size, 30)
Expand Down Expand Up @@ -381,17 +390,23 @@ def test_greedy_multi_decoding(self, greedy_class):
)
@pytest.mark.unit
@pytest.mark.parametrize(
"greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer],
("greedy_class", "loop_labels"),
[
(greedy_decode.GreedyRNNTInfer, None),
(greedy_decode.GreedyBatchedRNNTInfer, True),
(greedy_decode.GreedyBatchedRNNTInfer, False),
],
)
def test_greedy_decoding_stateless_decoder(self, greedy_class):
@pytest.mark.parametrize("context_size", [1, 2])
def test_greedy_decoding_stateless_decoder(self, greedy_class, loop_labels: Optional[bool], context_size: int):
token_list = [" ", "a", "b", "c"]
vocab_size = len(token_list)

encoder_output_size = 4
decoder_output_size = 4
joint_output_shape = 4

prednet_cfg = {'pred_hidden': decoder_output_size, 'pred_rnn_layers': 1}
prednet_cfg = {'pred_hidden': decoder_output_size, 'pred_rnn_layers': 1, 'context_size': context_size}
jointnet_cfg = {
'encoder_hidden': encoder_output_size,
'pred_hidden': decoder_output_size,
Expand All @@ -402,7 +417,10 @@ def test_greedy_decoding_stateless_decoder(self, greedy_class):
decoder = StatelessTransducerDecoder(prednet_cfg, vocab_size)
joint_net = RNNTJoint(jointnet_cfg, vocab_size, vocabulary=token_list)

greedy = greedy_class(decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5)
additional_decoding_kwargs = {} if loop_labels is None else {"loop_labels": loop_labels}
greedy = greedy_class(
decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5, **additional_decoding_kwargs
)

# (B, D, T)
enc_out = torch.randn(1, encoder_output_size, 30)
Expand Down Expand Up @@ -453,9 +471,14 @@ def test_greedy_multi_decoding_stateless_decoder(self, greedy_class):
)
@pytest.mark.unit
@pytest.mark.parametrize(
"greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer],
("greedy_class", "loop_labels"),
[
(greedy_decode.GreedyRNNTInfer, None),
(greedy_decode.GreedyBatchedRNNTInfer, True),
(greedy_decode.GreedyBatchedRNNTInfer, False),
],
)
def test_greedy_decoding_preserve_alignment(self, greedy_class):
def test_greedy_decoding_preserve_alignment(self, greedy_class, loop_labels: Optional[bool]):
token_list = [" ", "a", "b", "c"]
vocab_size = len(token_list)

Expand All @@ -474,8 +497,14 @@ def test_greedy_decoding_preserve_alignment(self, greedy_class):
decoder = RNNTDecoder(prednet_cfg, vocab_size)
joint_net = RNNTJoint(jointnet_cfg, vocab_size, vocabulary=token_list)

additional_decoding_kwargs = {} if loop_labels is None else {"loop_labels": loop_labels}
greedy = greedy_class(
decoder, joint_net, blank_index=len(token_list) - 1, preserve_alignments=True, max_symbols_per_step=5
decoder,
joint_net,
blank_index=len(token_list) - 1,
preserve_alignments=True,
max_symbols_per_step=5,
**additional_decoding_kwargs,
)

# (B, D, T)
Expand Down Expand Up @@ -591,9 +620,14 @@ def test_beam_decoding_preserve_alignments(self, beam_config):
)
@pytest.mark.unit
@pytest.mark.parametrize(
"greedy_class", [greedy_decode.GreedyRNNTInfer, greedy_decode.GreedyBatchedRNNTInfer],
("greedy_class", "loop_labels"),
[
(greedy_decode.GreedyRNNTInfer, None),
(greedy_decode.GreedyBatchedRNNTInfer, True),
(greedy_decode.GreedyBatchedRNNTInfer, False),
],
)
def test_greedy_decoding_SampledRNNTJoint(self, greedy_class):
def test_greedy_decoding_SampledRNNTJoint(self, greedy_class, loop_labels: Optional[bool]):
token_list = [" ", "a", "b", "c"]
vocab_size = len(token_list)

Expand All @@ -612,7 +646,10 @@ def test_greedy_decoding_SampledRNNTJoint(self, greedy_class):
decoder = RNNTDecoder(prednet_cfg, vocab_size)
joint_net = SampledRNNTJoint(jointnet_cfg, vocab_size, n_samples=2, vocabulary=token_list)

greedy = greedy_class(decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5)
additional_decoding_kwargs = {} if loop_labels is None else {"loop_labels": loop_labels}
greedy = greedy_class(
decoder, joint_net, blank_index=len(token_list) - 1, max_symbols_per_step=5, **additional_decoding_kwargs
)

# (B, D, T)
enc_out = torch.randn(1, encoder_output_size, 30)
Expand Down
Loading
Loading