diff --git a/examples/asr/conf/vad/frame_vad_infer_postprocess.yaml b/examples/asr/conf/vad/frame_vad_infer_postprocess.yaml
index d759a809ec37..30c082aff91f 100644
--- a/examples/asr/conf/vad/frame_vad_infer_postprocess.yaml
+++ b/examples/asr/conf/vad/frame_vad_infer_postprocess.yaml
@@ -1,6 +1,7 @@
 name: &name "vad_inference_postprocessing"
-dataset: null # Path of json file of evaluation data. Audio files should have unique names
+input_manifest: null # Path of json file of evaluation data. Audio files should have unique names
+output_dir: null # Path to output directory where results will be stored
 num_workers: 12
 sample_rate: 16000
 evaluate: False # whether to get AUROC and DERs, the manifest must contains groundtruth if enabled
diff --git a/examples/asr/speech_classification/frame_vad_infer.py b/examples/asr/speech_classification/frame_vad_infer.py
index f716eb45bb64..594cc9637d73 100644
--- a/examples/asr/speech_classification/frame_vad_infer.py
+++ b/examples/asr/speech_classification/frame_vad_infer.py
@@ -21,7 +21,8 @@
 ## Usage:
     python frame_vad_infer.py \
         --config-path="../conf/vad" --config-name="frame_vad_infer_postprocess" \
-        dataset=<path to manifest file>
+        input_manifest=<path to manifest file> \
+        output_dir=<path to output directory>
 
 The manifest json file should have the following format (each line is a Python dictionary):
 {"audio_filepath": "/path/to/audio_file1", "offset": 0, "duration": 10000}
@@ -58,15 +59,25 @@
 
 @hydra_runner(config_path="../conf/vad", config_name="frame_vad_infer_postprocess")
 def main(cfg):
-    if not cfg.dataset:
+    if not cfg.input_manifest:
         raise ValueError("You must input the path of json file of evaluation data")
 
+    output_dir = cfg.output_dir if cfg.output_dir else "frame_vad_outputs"
+    if os.path.exists(output_dir):
+        logging.warning(
+            f"Output directory {output_dir} already exists, use this only if you're tuning post-processing params."
+        )
+    Path(output_dir).mkdir(parents=True, exist_ok=True)
+
+    cfg.frame_out_dir = os.path.join(output_dir, "frame_preds")
+    cfg.smoothing_out_dir = os.path.join(output_dir, "smoothing_preds")
+    cfg.rttm_out_dir = os.path.join(output_dir, "rttm_preds")
 
-    # each line of dataset should be have different audio_filepath and unique name to simplify edge cases or conditions
-    logging.info(f"Loading manifest file {cfg.dataset}")
+    # each line of input_manifest should have a different audio_filepath and unique name to simplify edge cases or conditions
+    logging.info(f"Loading manifest file {cfg.input_manifest}")
     manifest_orig, key_labels_map, key_rttm_map = frame_vad_infer_load_manifest(cfg)
 
     # Prepare manifest for streaming VAD
-    manifest_vad_input = cfg.dataset
+    manifest_vad_input = cfg.input_manifest
     if cfg.prepare_manifest.auto_split:
         logging.info("Split long audio file to avoid CUDA memory issue")
         logging.debug("Try smaller split_duration if you still have CUDA memory issue")
@@ -76,6 +87,7 @@ def main(cfg):
             'split_duration': cfg.prepare_manifest.split_duration,
             'num_workers': cfg.num_workers,
             'prepared_manifest_vad_input': cfg.prepared_manifest_vad_input,
+            'out_dir': output_dir,
         }
         manifest_vad_input = prepare_manifest(config)
     else:
@@ -171,7 +183,7 @@ def main(cfg):
             key_pred_rttm_map[key] = entry['rttm_filepath']
 
     if not cfg.out_manifest_filepath:
-        out_manifest_filepath = "manifest_vad_output.json"
+        out_manifest_filepath = os.path.join(output_dir, "manifest_vad_output.json")
     else:
         out_manifest_filepath = cfg.out_manifest_filepath
     write_manifest(out_manifest_filepath, manifest_new)
diff --git a/nemo/collections/asr/models/classification_models.py b/nemo/collections/asr/models/classification_models.py
index 432674225f5a..264e9cef99f8 100644
--- a/nemo/collections/asr/models/classification_models.py
+++ b/nemo/collections/asr/models/classification_models.py
@@ -35,6 +35,7 @@
 from nemo.core.classes.common import PretrainedModelInfo, typecheck
 from nemo.core.neural_types import *
 from nemo.utils import logging, model_utils
+from nemo.utils.cast_utils import cast_all
 
 __all__ = ['EncDecClassificationModel', 'EncDecRegressionModel']
 
@@ -851,6 +852,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         self.eval_loop_cnt = 0
         self.ratio_threshold = cfg.get('ratio_threshold', 0.2)
         super().__init__(cfg=cfg, trainer=trainer)
+        self.decoder.output_types = self.output_types
+        self.decoder.output_types_for_export = self.output_types
 
     @classmethod
     def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
@@ -1148,3 +1151,43 @@ def get_metric_logits_labels(self, logits, labels, masks):
         labels = labels.gather(dim=0, index=idx.view(-1))
 
         return logits, labels
+
+    def forward_for_export(
+        self, input, length=None, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
+    ):
+        """
+        This forward is used when we need to export the model to ONNX format.
+        Inputs cache_last_channel and cache_last_time need to be passed for exporting streaming models.
+        Args:
+            input: Tensor that represents a batch of raw audio signals,
+                of shape [B, T]. T here represents timesteps.
+            length: Vector of length B that contains the individual lengths of the audio sequences.
+            cache_last_channel: Tensor of shape [N, B, T, H] which contains the cache for last channel layers
+            cache_last_time: Tensor of shape [N, B, H, T] which contains the cache for last time layers
+                N is the number of such layers which need caching, B is batch size, H is the hidden size of activations,
+                and T is the length of the cache
+
+        Returns:
+            the output of the model
+        """
+        enc_fun = getattr(self.input_module, 'forward_for_export', self.input_module.forward)
+        if cache_last_channel is None:
+            encoder_output = enc_fun(audio_signal=input, length=length)
+            if isinstance(encoder_output, tuple):
+                encoder_output = encoder_output[0]
+        else:
+            encoder_output, length, cache_last_channel, cache_last_time, cache_last_channel_len = enc_fun(
+                audio_signal=input,
+                length=length,
+                cache_last_channel=cache_last_channel,
+                cache_last_time=cache_last_time,
+                cache_last_channel_len=cache_last_channel_len,
+            )
+
+        dec_fun = getattr(self.output_module, 'forward_for_export', self.output_module.forward)
+        ret = dec_fun(hidden_states=encoder_output.transpose(1, 2))
+        if isinstance(ret, tuple):
+            ret = ret[0]
+        if cache_last_channel is not None:
+            ret = (ret, length, cache_last_channel, cache_last_time, cache_last_channel_len)
+        return cast_all(ret, from_dtype=torch.float16, to_dtype=torch.float32)
diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py
index e4f024d231ad..d8860a0c7cff 100644
--- a/nemo/collections/asr/parts/utils/vad_utils.py
+++ b/nemo/collections/asr/parts/utils/vad_utils.py
@@ -275,7 +275,9 @@ def generate_overlap_vad_seq(
     if out_dir:
         overlap_out_dir = out_dir
     else:
-        overlap_out_dir = frame_pred_dir + "/overlap_smoothing_output" + "_" + smoothing_method + "_" + str(overlap)
+        overlap_out_dir = os.path.join(
+            frame_pred_dir, "overlap_smoothing_output" + "_" + smoothing_method + "_" + str(overlap)
+        )
 
     if not os.path.exists(overlap_out_dir):
         os.mkdir(overlap_out_dir)
@@ -732,7 +734,7 @@ def generate_vad_segment_table(
     if not out_dir:
         out_dir_name = "seg_output_"
         for key in postprocessing_params:
-            out_dir_name = out_dir_name + str(key) + str(postprocessing_params[key]) + "-"
+            out_dir_name = out_dir_name + "-" + str(key) + str(postprocessing_params[key])
 
         out_dir = os.path.join(vad_pred_dir, out_dir_name)
diff --git a/tests/collections/asr/test_asr_classification_model.py b/tests/collections/asr/test_asr_classification_model.py
index 876bb6073a38..3888cb30204c 100644
--- a/tests/collections/asr/test_asr_classification_model.py
+++ b/tests/collections/asr/test_asr_classification_model.py
@@ -94,8 +94,8 @@ def frame_classification_model():
     }
 
     decoder = {
-        'cls': 'nemo.collections.asr.modules.ConvASRDecoderClassification',
-        'params': {'feat_in': 32, 'num_classes': 5,},
+        'cls': 'nemo.collections.common.parts.MultiLayerPerceptron',
+        'params': {'hidden_size': 32, 'num_classes': 5,},
     }
 
     modelConfig = DictConfig(
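The `forward_for_export` method added above is what makes the frame-VAD classification model exportable. Below is a minimal sketch of how it could be exercised, assuming the standard NeMo `Exportable.export()` API and the `vad_multilingual_frame_marblenet` checkpoint name; neither the checkpoint name nor the output filename comes from this diff.

```python
# Sketch only: exercises the forward_for_export path added in this diff.
# The checkpoint name and ONNX filename are assumptions, not part of the change.
import nemo.collections.asr as nemo_asr

vad_model = nemo_asr.models.EncDecFrameClassificationModel.from_pretrained(
    model_name="vad_multilingual_frame_marblenet"
)
vad_model.eval()

# Exportable.export() traces forward_for_export(); since no cache tensors are
# supplied, the offline (non-streaming) branch runs, and cast_all() converts any
# float16 outputs back to float32 before they are returned.
vad_model.export("frame_vad.onnx")
```

The `self.decoder.output_types` assignments in `__init__` appear to serve the same goal: the export wrapper derives output names and types from the decoder (output) module, so the exported graph exposes the frame-level logits under the model's declared output types.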