Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: nithinraok <nithinraok@users.noreply.github.com>
  • Loading branch information
nithinraok committed Oct 18, 2024
1 parent 6ccf431 commit 762a29a
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 26 deletions.
6 changes: 4 additions & 2 deletions nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def transcribe(
"""
if timestamps:
raise NotImplementedError("Timestamps are not supported for this model yet.")

if override_config is None:
trcfg = MultiTaskTranscriptionConfig(
batch_size=batch_size,
Expand All @@ -501,7 +501,9 @@ def transcribe(
f"but got {type(override_config)}"
)
trcfg = override_config
import pdb; pdb.set_trace()
import pdb

pdb.set_trace()
return super().transcribe(audio=audio, override_config=trcfg)

def _setup_dataloader_from_config(self, config: Optional[Dict]):
Expand Down
10 changes: 7 additions & 3 deletions nemo/collections/asr/models/ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def transcribe(
"""
if timestamps:
logging.info("Timestamps requested, setting decoding timestamps to True")
return_hypotheses=True
return_hypotheses = True
with open_dict(self.cfg.decoding):
self.cfg.decoding.compute_timestamps = True
self.cfg.decoding.preserve_alignments = True
Expand Down Expand Up @@ -719,8 +719,12 @@ def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> Gen
del logits, logits_len

if trcfg.timestamps:
current_hypotheses = process_timestamp_outputs(current_hypotheses, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'])
all_hyp = process_timestamp_outputs(all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'])
current_hypotheses = process_timestamp_outputs(
current_hypotheses, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
)
all_hyp = process_timestamp_outputs(
all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
)

hypotheses = []
if all_hyp is None:
Expand Down
10 changes: 7 additions & 3 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,9 @@ def change_vocabulary(

logging.info(f"Changed tokenizer of the CTC decoder to {self.ctc_decoder.vocabulary} vocabulary.")

def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type: str = None, verbose: bool = True):
def change_decoding_strategy(
self, decoding_cfg: DictConfig = None, decoder_type: str = None, verbose: bool = True
):
"""
Changes decoding strategy used during RNNT decoding process.
Args:
Expand Down Expand Up @@ -468,7 +470,9 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type

self.cur_decoder = "rnnt"
if verbose:
logging.info(f"Changed decoding strategy of the RNNT decoder to \n{OmegaConf.to_yaml(self.cfg.decoding)}")
logging.info(
f"Changed decoding strategy of the RNNT decoder to \n{OmegaConf.to_yaml(self.cfg.decoding)}"
)

elif decoder_type == 'ctc':
if not hasattr(self, 'ctc_decoding'):
Expand Down Expand Up @@ -501,7 +505,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type
self.cur_decoder = "ctc"
if verbose:
logging.info(
f"Changed decoding strategy of the CTC decoder to \n{OmegaConf.to_yaml(self.cfg.aux_ctc.decoding)}"
f"Changed decoding strategy of the CTC decoder to \n{OmegaConf.to_yaml(self.cfg.aux_ctc.decoding)}"
)
else:
raise ValueError(f"decoder_type={decoder_type} is not supported. Supported values: [ctc,rnnt]")
Expand Down
16 changes: 11 additions & 5 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,12 @@ def transcribe(

if timestamps:
logging.info("Timestamps requested, setting decoding timestamps to True")
return_hypotheses=True
return_hypotheses = True
with open_dict(decoding_cfg):
decoding_cfg.decoding.compute_timestamps = True
decoding_cfg.decoding.preserve_alignments = True
self.change_decoding_strategy(decoding_cfg, self.cur_decoder, verbose=False)
else: # This is done to ensure the timestamps are not computed if not requested
else: # This is done to ensure the timestamps are not computed if not requested
with open_dict(decoding_cfg):
decoding_cfg.compute_timestamps = decoding_cfg.get('compute_timestamps', False)
decoding_cfg.preserve_alignments = decoding_cfg.get('preserve_alignments', False)
Expand Down Expand Up @@ -221,8 +221,12 @@ def _transcribe_output_processing(
# logits_list.append(logit[:elen])

if trcfg.timestamps:
hypotheses = process_timestamp_outputs(best_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'])
all_hyp = process_timestamp_outputs(all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'])
hypotheses = process_timestamp_outputs(

Check warning

Code scanning / CodeQL

Variable defined multiple times Warning

This assignment to 'hypotheses' is unnecessary as it is redefined before this value is used.
best_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
)
all_hyp = process_timestamp_outputs(
all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
)

del logits, encoded_len

Expand Down Expand Up @@ -318,7 +322,9 @@ def change_vocabulary(

logging.info(f"Changed the tokenizer of the CTC decoder to {self.ctc_decoder.vocabulary} vocabulary.")

def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type: str = None, verbose: bool = True):
def change_decoding_strategy(
self, decoding_cfg: DictConfig = None, decoder_type: str = None, verbose: bool = True
):

Check notice

Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns Note

Mixing implicit and explicit returns may indicate an error as implicit returns always return None.
"""
Changes decoding strategy used during RNNT decoding process.
Expand Down
11 changes: 7 additions & 4 deletions nemo/collections/asr/models/rnnt_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def transcribe(
"""
if timestamps:
logging.info("Timestamps requested, setting decoding timestamps to True")
return_hypotheses=True
return_hypotheses = True
with open_dict(self.cfg.decoding):
self.cfg.decoding.compute_timestamps = True
self.cfg.decoding.preserve_alignments = True
Expand Down Expand Up @@ -941,8 +941,12 @@ def _transcribe_output_processing(
del encoded, encoded_len

if trcfg.timestamps:
hypotheses = process_timestamp_outputs(best_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'])
all_hyp = process_timestamp_outputs(all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'])
hypotheses = process_timestamp_outputs(

Check warning

Code scanning / CodeQL

Variable defined multiple times Warning

This assignment to 'hypotheses' is unnecessary as it is redefined before this value is used.
best_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
)
all_hyp = process_timestamp_outputs(
all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride']
)

hypotheses = []
all_hypotheses = []
Expand All @@ -955,7 +959,6 @@ def _transcribe_output_processing(

return (hypotheses, all_hypotheses)


def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
"""
Setup function for a temporary data loader which wraps the provided audio file.
Expand Down
7 changes: 5 additions & 2 deletions nemo/collections/asr/modules/conv_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def __init__(
residual_panes = []
encoder_layers = []
self.dense_residual = False
self._subsampling_factor = 1
self._subsampling_factor = 1
for layer_idx, lcfg in enumerate(jasper):
dense_res = []
if lcfg.get('residual_dense', False):
Expand Down Expand Up @@ -182,7 +182,9 @@ def __init__(
)
)
feat_in = lcfg['filters']
self._subsampling_factor *= int(lcfg['stride'][0]) if isinstance(lcfg['stride'], List) else int(lcfg['stride'])
self._subsampling_factor *= (
int(lcfg['stride'][0]) if isinstance(lcfg['stride'], List) else int(lcfg['stride'])
)

self._feat_out = feat_in

Expand Down Expand Up @@ -235,6 +237,7 @@ def update_max_sequence_length(self, seq_length: int, device):
def subsampling_factor(self) -> int:
return self._subsampling_factor


class ParallelConvASREncoder(NeuralModule, Exportable):
"""
Convolutional encoder for ASR models with parallel blocks. CarneliNet can be implemented with this class.
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/parts/mixins/transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ class TranscribeConfig:
num_workers: Optional[int] = None
channel_selector: ChannelSelectorType = None
augmentor: Optional[DictConfig] = None
timestamps: bool = False # returns timestamps for each word and segments if model supports punctuations
verbose: bool = True
timestamps: bool = False # returns timestamps for each word and segments if model supports punctuations
verbose: bool = True

# Utility
partial_hypothesis: Optional[List[Any]] = None
Expand Down
14 changes: 9 additions & 5 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def process_timestamp_outputs(outputs, subsampling_factor: int = 1, window_strid

if isinstance(outputs, rnnt_utils.Hypothesis):
outputs = [outputs]

if not isinstance(outputs[0], rnnt_utils.Hypothesis):
raise ValueError(f"Expected Hypothesis object, got {type(outputs[0])}")

Expand All @@ -604,19 +604,23 @@ def process_timestamp(timestamp, subsampling_factor, window_stride):
val['end'] = end

return timestamp

for idx, hyp in enumerate(outputs):
if not hasattr(hyp, 'timestep'):
raise ValueError(f"Expected Hypothesis object to have 'timestep' attribute, when compute_timestamps is enabled but got {hyp}")
raise ValueError(
f"Expected Hypothesis object to have 'timestep' attribute, when compute_timestamps is enabled but got {hyp}"
)
timestep = hyp.timestep
if 'word' in timestep:
outputs[idx].timestep['word'] = process_timestamp(timestep['word'], subsampling_factor, window_stride)
if 'char' in timestep:
outputs[idx].timestep['char'] = process_timestamp(timestep['char'], subsampling_factor, window_stride)
if 'segment' in timestep:
outputs[idx].timestep['segment'] = process_timestamp(timestep['segment'], subsampling_factor, window_stride)
outputs[idx].timestep['segment'] = process_timestamp(
timestep['segment'], subsampling_factor, window_stride
)
return outputs


class PunctuationCapitalization:
def __init__(self, punctuation_marks: str):
Expand Down

0 comments on commit 762a29a

Please sign in to comment.