
Commit

Dgalvez/fix greedy batch strategy name r2.0.0rc0 (#9243) (#9253)
* Lazily warn about using greedy strategy instead of greedy_batch
strategy.

Previously, the warning would often run spuriously, since several
existing code paths simply call "change_decoding_strategy()" after
having first initialized a Module, rather than changing the config
before initializing the Module. This can be confusing.

The only problem I can see with this is that calling logging inside a
forward() method might interfere with compiler toolkits like TorchScript
or thunder.compile. Presumably it would be easy to guard the warning with
a conditional so it is skipped in a compiler context if necessary.

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
Co-authored-by: Daniel Galvez <galv@users.noreply.github.com>
Signed-off-by: Jan Lasek <janek.lasek@gmail.com>
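
For context, a minimal sketch (not part of this commit) of switching an already-initialized model's CTC decoding to the renamed "greedy_batch" strategy via change_decoding_strategy(), the code path the commit message refers to. The checkpoint name is illustrative.

```python
# Sketch only (not part of this commit): select the renamed "greedy_batch"
# CTC decoding strategy on a model that has already been initialized.
from omegaconf import OmegaConf

from nemo.collections.asr.models import ASRModel

model = ASRModel.from_pretrained("stt_en_conformer_ctc_small")  # illustrative checkpoint

# change_decoding_strategy() merges the given config with the model's decoding
# defaults, so only the changed field needs to be supplied. "greedy_batch"
# (formerly spelled "greedy_batched") selects the vectorized
# GreedyBatchedCTCInfer backend; plain "greedy" decodes sample by sample.
model.change_decoding_strategy(OmegaConf.create({"strategy": "greedy_batch"}))
```

Both strategies implement the same interface and produce the same hypotheses; the batched variant is simply faster, which is why it is the default in CTCDecodingConfig below.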
2 people authored and janekl committed Jun 12, 2024
1 parent 007a64e commit 3c3c592
Showing 6 changed files with 96 additions and 50 deletions.
23 changes: 10 additions & 13 deletions nemo/collections/asr/parts/submodules/ctc_decoding.py
@@ -213,31 +213,31 @@ def __init__(self, decoding_cfg, blank_id: int):
self.batch_dim_index = self.cfg.get('batch_dim_index', 0)
self.word_seperator = self.cfg.get('word_seperator', ' ')

- possible_strategies = ['greedy', 'greedy_batched', 'beam', 'pyctcdecode', 'flashlight']
+ possible_strategies = ['greedy', 'greedy_batch', 'beam', 'pyctcdecode', 'flashlight']
if self.cfg.strategy not in possible_strategies:
raise ValueError(f"Decoding strategy must be one of {possible_strategies}. Given {self.cfg.strategy}")

# Update preserve alignments
if self.preserve_alignments is None:
- if self.cfg.strategy in ['greedy', 'greedy_batched']:
+ if self.cfg.strategy in ['greedy', 'greedy_batch']:
self.preserve_alignments = self.cfg.greedy.get('preserve_alignments', False)
else:
self.preserve_alignments = self.cfg.beam.get('preserve_alignments', False)

# Update compute timestamps
if self.compute_timestamps is None:
- if self.cfg.strategy in ['greedy', 'greedy_batched']:
+ if self.cfg.strategy in ['greedy', 'greedy_batch']:
self.compute_timestamps = self.cfg.greedy.get('compute_timestamps', False)
elif self.cfg.strategy in ['beam']:
self.compute_timestamps = self.cfg.beam.get('compute_timestamps', False)

# initialize confidence-related fields
self._init_confidence(self.cfg.get('confidence_cfg', None))

- # Confidence estimation is not implemented for strategies other than `greedy` and `greedy_batched`
+ # Confidence estimation is not implemented for strategies other than `greedy` and `greedy_batch`
if (
not self.preserve_frame_confidence
- and self.cfg.strategy not in ('greedy', 'greedy_batched')
+ and self.cfg.strategy not in ('greedy', 'greedy_batch')
and self.cfg.beam.get('preserve_frame_confidence', False)
):
raise NotImplementedError(f"Confidence calculation is not supported for strategy `{self.cfg.strategy}`")
@@ -247,11 +247,6 @@ def __init__(self, decoding_cfg, blank_id: int):
self.compute_timestamps |= self.preserve_frame_confidence

if self.cfg.strategy == 'greedy':
- logging.warning(
-     "CTC decoding strategy 'greedy' is slower than 'greedy_batched', which implements the same exact interface. Consider changing your strategy to 'greedy_batched' for a free performance improvement.",
-     mode=logging_mode.ONCE,
- )

self.decoding = ctc_greedy_decoding.GreedyCTCInfer(
blank_id=self.blank_id,
preserve_alignments=self.preserve_alignments,
@@ -260,7 +255,7 @@ def __init__(self, decoding_cfg, blank_id: int):
confidence_method_cfg=self.confidence_method_cfg,
)

- elif self.cfg.strategy == "greedy_batched":
+ elif self.cfg.strategy == "greedy_batch":
self.decoding = ctc_greedy_decoding.GreedyBatchedCTCInfer(
blank_id=self.blank_id,
preserve_alignments=self.preserve_alignments,
@@ -1023,7 +1018,9 @@ class CTCDecoding(AbstractCTCDecoding):
"""

def __init__(
- self, decoding_cfg, vocabulary,
+ self,
+ decoding_cfg,
+ vocabulary,
):
blank_id = len(vocabulary)
self.vocabulary = vocabulary
@@ -1300,7 +1297,7 @@ def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]:

@dataclass
class CTCDecodingConfig:
- strategy: str = "greedy_batched"
+ strategy: str = "greedy_batch"

# preserve decoding alignments
preserve_alignments: Optional[bool] = None
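
The warning block removed from __init__ above is not dropped outright: per the commit message it is emitted lazily from inside the greedy decoder's forward(), so it only fires when the slow path actually runs. The file that adds that call is among the six changed files but is not expanded on this page, so the snippet below is only a sketch of the lazy-warning pattern, not the committed code; the class name and method body are illustrative.

```python
# Sketch of the lazy-warning pattern described in the commit message; not the
# committed implementation.
from nemo.utils import logging, logging_mode


class GreedyCTCInferSketch:
    def forward(self, decoder_output, decoder_lengths):
        # Emitted at call time (and only once per process), so merely
        # constructing a module with strategy='greedy' no longer triggers a
        # spurious warning.
        logging.warning(
            "CTC decoding strategy 'greedy' is slower than 'greedy_batch', "
            "which implements the same exact interface. Consider changing your "
            "strategy to 'greedy_batch' for a free performance improvement.",
            mode=logging_mode.ONCE,
        )
        ...  # actual greedy decoding would happen here
```

As the commit message notes, a warning issued from forward() may still need to be guarded when compiling with TorchScript or thunder.compile.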
4 changes: 2 additions & 2 deletions tests/collections/asr/decoding/test_ctc_decoding.py
@@ -244,7 +244,7 @@ def test_batched_decoding_logprobs(
)
unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)

- cfg.strategy = 'greedy_batched'
+ cfg.strategy = 'greedy_batch'
batched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)

torch.manual_seed(1)
@@ -311,7 +311,7 @@ def test_batched_decoding_logprobs(
def test_batched_decoding_labels(self, tmp_tokenizer, timestamps, length_is_none, labels_device, length_device):
cfg = CTCBPEDecodingConfig(strategy='greedy', compute_timestamps=timestamps)
unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
- cfg.strategy = 'greedy_batched'
+ cfg.strategy = 'greedy_batch'
batched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)

torch.manual_seed(1)
7 changes: 5 additions & 2 deletions tests/collections/asr/test_asr_ctc_encoder_model_bpe.py
@@ -269,7 +269,7 @@ def test_vocab_change(self, test_data_dir, asr_model):
def test_decoding_change(self, asr_model):
assert asr_model.decoding is not None
assert isinstance(asr_model.decoding, CTCBPEDecoding)
- assert asr_model.decoding.cfg.strategy == "greedy_batched"
+ assert asr_model.decoding.cfg.strategy == "greedy_batch"
assert asr_model.decoding.preserve_alignments is False
assert asr_model.decoding.compute_timestamps is False

@@ -309,7 +309,7 @@ def test_ASRDatasetConfig_for_AudioToBPEDataset(self):
REMAP_ARGS = {'trim_silence': 'trim', 'labels': 'tokenizer'}

result = assert_dataclass_signature_match(
- audio_to_text.AudioToBPEDataset, configs.ASRDatasetConfig, ignore_args=IGNORE_ARGS, remap_args=REMAP_ARGS,
+ audio_to_text.AudioToBPEDataset,
+ configs.ASRDatasetConfig,
+ ignore_args=IGNORE_ARGS,
+ remap_args=REMAP_ARGS,
)
signatures_match, cls_subset, dataclass_subset = result

7 changes: 5 additions & 2 deletions tests/collections/asr/test_asr_ctcencdec_model.py
@@ -150,7 +150,7 @@ def test_vocab_change(self, asr_model):
def test_decoding_change(self, asr_model):
assert asr_model.decoding is not None
assert isinstance(asr_model.decoding, CTCDecoding)
- assert asr_model.decoding.cfg.strategy == "greedy_batched"
+ assert asr_model.decoding.cfg.strategy == "greedy_batch"
assert asr_model.decoding.preserve_alignments is False
assert asr_model.decoding.compute_timestamps is False

@@ -279,7 +279,7 @@ def test_ASRDatasetConfig_for_AudioToCharDataset(self):
REMAP_ARGS = {'trim_silence': 'trim'}

result = assert_dataclass_signature_match(
- audio_to_text.AudioToCharDataset, configs.ASRDatasetConfig, ignore_args=IGNORE_ARGS, remap_args=REMAP_ARGS,
+ audio_to_text.AudioToCharDataset,
+ configs.ASRDatasetConfig,
+ ignore_args=IGNORE_ARGS,
+ remap_args=REMAP_ARGS,
)
signatures_match, cls_subset, dataclass_subset = result

33 changes: 23 additions & 10 deletions tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py
@@ -64,12 +64,18 @@ def hybrid_asr_model(test_data_dir):

decoder = {
'_target_': 'nemo.collections.asr.modules.RNNTDecoder',
- 'prednet': {'pred_hidden': model_defaults['pred_hidden'], 'pred_rnn_layers': 1,},
+ 'prednet': {
+ 'pred_hidden': model_defaults['pred_hidden'],
+ 'pred_rnn_layers': 1,
+ },
}

joint = {
'_target_': 'nemo.collections.asr.modules.RNNTJoint',
- 'jointnet': {'joint_hidden': 32, 'activation': 'relu',},
+ 'jointnet': {
+ 'joint_hidden': 32,
+ 'activation': 'relu',
+ },
}

decoding = {'strategy': 'greedy_batch', 'greedy': {'max_symbols': 30}}
@@ -111,7 +117,8 @@ def hybrid_asr_model(test_data_dir):

class TestEncDecHybridRNNTCTCBPEModel:
@pytest.mark.skipif(
- not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
+ not NUMBA_RNNT_LOSS_AVAILABLE,
+ reason='RNNTLoss has not been compiled with appropriate numba version.',
)
@pytest.mark.with_downloads()
@pytest.mark.unit
@@ -125,7 +132,8 @@ def test_constructor(self, hybrid_asr_model):

@pytest.mark.with_downloads()
@pytest.mark.skipif(
- not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
+ not NUMBA_RNNT_LOSS_AVAILABLE,
+ reason='RNNTLoss has not been compiled with appropriate numba version.',
)
@pytest.mark.unit
def test_forward(self, hybrid_asr_model):
@@ -160,7 +168,8 @@ def test_forward(self, hybrid_asr_model):

@pytest.mark.with_downloads()
@pytest.mark.skipif(
- not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
+ not NUMBA_RNNT_LOSS_AVAILABLE,
+ reason='RNNTLoss has not been compiled with appropriate numba version.',
)
@pytest.mark.unit
def test_save_restore_artifact(self, hybrid_asr_model):
@@ -178,7 +187,8 @@ def test_save_restore_artifact(self, hybrid_asr_model):

@pytest.mark.with_downloads()
@pytest.mark.skipif(
- not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
+ not NUMBA_RNNT_LOSS_AVAILABLE,
+ reason='RNNTLoss has not been compiled with appropriate numba version.',
)
@pytest.mark.unit
def test_save_restore_artifact_spe(self, hybrid_asr_model, test_data_dir):
@@ -224,7 +234,8 @@ def test_save_restore_artifact_agg(self, hybrid_asr_model, test_data_dir):

@pytest.mark.with_downloads()
@pytest.mark.skipif(
- not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
+ not NUMBA_RNNT_LOSS_AVAILABLE,
+ reason='RNNTLoss has not been compiled with appropriate numba version.',
)
@pytest.mark.unit
def test_vocab_change(self, test_data_dir, hybrid_asr_model):
@@ -255,7 +266,8 @@ def test_vocab_change(self, test_data_dir, hybrid_asr_model):

@pytest.mark.with_downloads()
@pytest.mark.skipif(
- not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
+ not NUMBA_RNNT_LOSS_AVAILABLE,
+ reason='RNNTLoss has not been compiled with appropriate numba version.',
)
@pytest.mark.unit
def test_decoding_change(self, hybrid_asr_model):
@@ -297,7 +309,7 @@ def test_decoding_change(self, hybrid_asr_model):

assert hybrid_asr_model.ctc_decoding is not None
assert isinstance(hybrid_asr_model.ctc_decoding, CTCBPEDecoding)
- assert hybrid_asr_model.ctc_decoding.cfg.strategy == "greedy_batched"
+ assert hybrid_asr_model.ctc_decoding.cfg.strategy == "greedy_batch"
assert hybrid_asr_model.ctc_decoding.preserve_alignments is False
assert hybrid_asr_model.ctc_decoding.compute_timestamps is False

@@ -309,7 +321,8 @@ def test_decoding_change(self, hybrid_asr_model):
assert hybrid_asr_model.cur_decoder == "ctc"

@pytest.mark.skipif(
- not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
+ not NUMBA_RNNT_LOSS_AVAILABLE,
+ reason='RNNTLoss has not been compiled with appropriate numba version.',
)
@pytest.mark.unit
def test_decoding_type_change(self, hybrid_asr_model):
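
The hybrid RNNT/CTC test changes above are mostly formatting plus the same default rename on the model's CTC branch. As a usage note, a rough sketch of pointing a hybrid model at its CTC head with the renamed strategy follows; the decoder_type keyword and the checkpoint name are assumptions for illustration, not taken from this diff.

```python
# Rough sketch, not from this commit: reconfigure the CTC branch of a hybrid
# RNNT/CTC model to use the renamed "greedy_batch" strategy.
from omegaconf import OmegaConf

from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModel

model = EncDecHybridRNNTCTCBPEModel.from_pretrained(
    "stt_en_fastconformer_hybrid_large_pc"  # illustrative checkpoint
)

ctc_cfg = OmegaConf.create({"strategy": "greedy_batch"})
# decoder_type="ctc" is assumed here to select the CTC head rather than RNNT.
model.change_decoding_strategy(decoding_cfg=ctc_cfg, decoder_type="ctc")
assert model.cur_decoder == "ctc"
```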
