Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix asr finetune #10508

Merged
merged 2 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ model:
# configs for huggingface load_dataset function
data_path: "librispeech_asr"
data_name: null # name for the specific dataset to load, e.g., 'en' for MCV datasets, but some datasets don't require this field.
streaming: false # set True to use streaming mode, which doesn't wait for data downloading but each training step takes longer in the first epoch. If True, you'll need to specify trainer.max_steps instead of trainer.max_epochs.
streaming: false # set True to use streaming mode, which doesn't wait for data downloading but each training step takes longer in the first epoch. If True, you'll need to specify trainer.max_steps and trainer.limit_train_batches, instead of trainer.max_epochs.

# keys for audio, sample_rate and transcription in the huggingface dataset, keys seperated by `.` for nested fields. See example at the bottom of this file.
audio_key: "audio.array"
Expand Down
6 changes: 3 additions & 3 deletions examples/asr/conf/vad/frame_vad_infer_postprocess.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ input_manifest: null # Path of json file of evaluation data. Audio files should
output_dir: null # Path to output directory where results will be stored
num_workers: 12
sample_rate: 16000
evaluate: False # whether to get AUROC and DERs, the manifest must contains groundtruth if enabled
evaluate: false # whether to get AUROC and DERs, the manifest must contains groundtruth if enabled

prepare_manifest:
auto_split: True # whether to automatically split manifest entry by split_duration to avoid potential CUDA out of memory issue.
split_duration: 400 # try smaller number if you still have CUDA memory issue
auto_split: true # whether to automatically split manifest entry by split_duration to avoid potential CUDA out of memory issue.
split_duration: 400 # max length in seconds, try smaller number if you still have CUDA memory issue

vad:
model_path: "vad_multilingual_frame_marblenet" #.nemo local model path or pretrained model name or none
Expand Down
1 change: 1 addition & 0 deletions examples/asr/speech_to_text_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def get_base_model(trainer, cfg):
# restore model from cached model dir
asr_model = ASRModel.from_pretrained(model_name=pretrained_name)

asr_model.set_trainer(trainer)
return asr_model


Expand Down
103 changes: 53 additions & 50 deletions nemo/collections/asr/parts/utils/vad_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def autocast(enabled=None):

def prepare_manifest(config: dict) -> str:
"""
Perform VAD on long audio snippet might cause CUDA out of memory issue.
Perform VAD on long audio snippet might cause CUDA out of memory issue.
Automatically split manifest entry by split_duration to avoid the potential memory issue.
"""
if 'prepared_manifest_vad_input' in config and config['prepared_manifest_vad_input']:
Expand Down Expand Up @@ -132,7 +132,7 @@ def write_vad_infer_manifest(file: dict, args_func: dict) -> list:
args_func:
label (str): label for audio snippet.y
split_duration (float): max duration of each audio clip (each line in json)
window_length_in_sec (float) : length of window for generating the frame. Used for taking care of joint.
window_length_in_sec (float) : length of window for generating the frame. Used for taking care of joint.
Returns:
res (list) : list of generated metadata line of json for file
"""
Expand Down Expand Up @@ -205,7 +205,7 @@ def write_vad_infer_manifest(file: dict, args_func: dict) -> list:

def get_vad_stream_status(data: list) -> list:
"""
Generate a list of status for each snippet in manifest. A snippet should be in single, start, next or end status.
Generate a list of status for each snippet in manifest. A snippet should be in single, start, next or end status.
Used for concatenating to full audio file.
Args:
data (list): list of filepath of audio snippet
Expand Down Expand Up @@ -256,9 +256,9 @@ def generate_overlap_vad_seq(
out_dir: str = None,
) -> str:
"""
Generate predictions with overlapping input windows/segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple windows.
Generate predictions with overlapping input windows/segments. Then a smoothing filter is applied to decide the label for a frame spanned by multiple windows.
Two common smoothing filters are supported: majority vote (median) and average (mean).
This function uses multiprocessing to speed up.
This function uses multiprocessing to speed up.
Args:
frame_pred_dir (str): Directory of frame prediction file to be processed.
smoothing_method (str): median or mean smoothing filter.
Expand Down Expand Up @@ -322,7 +322,7 @@ def generate_overlap_vad_seq_per_tensor(
"""
Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) to generate prediction with overlapping input window/segments
See description in generate_overlap_vad_seq.
Use this for single instance pipeline.
Use this for single instance pipeline.
"""
# This function will be refactor for vectorization but this is okay for now

Expand Down Expand Up @@ -441,7 +441,7 @@ def filter_short_segments(segments: torch.Tensor, threshold: float) -> torch.Ten
Remove segments which duration is smaller than a threshold.
For example,
torch.Tensor([[0, 1.5], [1, 3.5], [4, 7]]) and threshold = 2.0
->
->
torch.Tensor([[1, 3.5], [4, 7]])
"""
return segments[segments[:, 1] - segments[:, 0] >= threshold]
Expand Down Expand Up @@ -482,20 +482,20 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te
Binarize predictions to speech and non-speech
Reference
Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015.
Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py
Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015.
Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py
Args:
sequence (torch.Tensor) : A tensor of frame level predictions.
per_args:
onset (float): onset threshold for detecting the beginning and end of a speech
offset (float): offset threshold for detecting the end of a speech.
onset (float): onset threshold for detecting the beginning and end of a speech
offset (float): offset threshold for detecting the end of a speech.
pad_onset (float): adding durations before each speech segment
pad_offset (float): adding durations after each speech segment;
frame_length_in_sec (float): length of frame.
Returns:
speech_segments(torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format.
speech_segments(torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format.
"""
frame_length_in_sec = per_args.get('frame_length_in_sec', 0.01)

Expand Down Expand Up @@ -545,9 +545,9 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te
def remove_segments(original_segments: torch.Tensor, to_be_removed_segments: torch.Tensor) -> torch.Tensor:
"""
Remove speech segments list in to_be_removed_segments from original_segments.
For example,
For example,
remove torch.Tensor([[start2, end2],[start4, end4]]) from torch.Tensor([[start1, end1],[start2, end2],[start3, end3], [start4, end4]]),
->
->
torch.Tensor([[start1, end1],[start3, end3]])
"""
for y in to_be_removed_segments:
Expand All @@ -558,7 +558,7 @@ def remove_segments(original_segments: torch.Tensor, to_be_removed_segments: tor
@torch.jit.script
def get_gap_segments(segments: torch.Tensor) -> torch.Tensor:
"""
Get the gap segments.
Get the gap segments.
For example,
torch.Tensor([[start1, end1], [start2, end2], [start3, end3]]) -> torch.Tensor([[end1, start2], [end2, start3]])
"""
Expand All @@ -568,22 +568,21 @@ def get_gap_segments(segments: torch.Tensor) -> torch.Tensor:

@torch.jit.script
def filtering(speech_segments: torch.Tensor, per_args: Dict[str, float]) -> torch.Tensor:

"""
Filter out short non_speech and speech segments.
Reference
Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015.
Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py
Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", InterSpeech 2015.
Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py
Args:
speech_segments (torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format.
speech_segments (torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format.
per_args:
min_duration_on (float): threshold for small non_speech deletion
min_duration_off (float): threshold for short speech segment deletion
filter_speech_first (float): Whether to perform short speech segment deletion first. Use 1.0 to represent True.
filter_speech_first (float): Whether to perform short speech segment deletion first. Use 1.0 to represent True.
Returns:
speech_segments(torch.Tensor): A tensor of filtered speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format.
speech_segments(torch.Tensor): A tensor of filtered speech segment in torch.Tensor([[start1, end1], [start2, end2]]) format.
"""
if speech_segments.shape == torch.Size([0]):
return speech_segments
Expand Down Expand Up @@ -630,7 +629,7 @@ def filtering(speech_segments: torch.Tensor, per_args: Dict[str, float]) -> torc

def prepare_gen_segment_table(sequence: torch.Tensor, per_args: dict) -> Tuple[str, dict]:
"""
Preparing for generating segment table.
Preparing for generating segment table.
"""
out_dir = per_args.get('out_dir', None)

Expand Down Expand Up @@ -658,7 +657,7 @@ def prepare_gen_segment_table(sequence: torch.Tensor, per_args: dict) -> Tuple[s
def generate_vad_segment_table_per_tensor(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Tensor:
"""
See description in generate_overlap_vad_seq.
Use this for single instance pipeline.
Use this for single instance pipeline.
"""
UNIT_FRAME_LEN = 0.01

Expand Down Expand Up @@ -721,7 +720,7 @@ def generate_vad_segment_table(
Args:
vad_pred_dir (str): directory of prediction files to be processed.
postprocessing_params (dict): dictionary of thresholds for prediction score. See details in binarization and filtering.
frame_length_in_sec (float): frame length.
frame_length_in_sec (float): frame length.
out_dir (str): output dir of generated table/csv file.
num_workers(float): number of process for multiprocessing
Returns:
Expand Down Expand Up @@ -1070,7 +1069,7 @@ def gen_pred_from_speech_segments(
speech_segments: torch.Tensor, prob: float, shift_length_in_sec: float = 0.01
) -> np.array:
"""
Generate prediction arrays like 000111000... from speech segments {[0,1][2,4]}
Generate prediction arrays like 000111000... from speech segments {[0,1][2,4]}
"""
pred = np.zeros(prob.shape)
speech_segments = [list(i) for i in speech_segments]
Expand All @@ -1086,7 +1085,7 @@ def gen_pred_from_speech_segments(
def extract_labels(path2ground_truth_label: str, time: list) -> list:
"""
Extract ground-truth label for given time period.
path2ground_truth_label (str): path of groundtruth RTTM file
path2ground_truth_label (str): path of groundtruth RTTM file
time (list) : a list of array representing time period.
"""

Expand Down Expand Up @@ -1273,7 +1272,6 @@ def stitch_segmented_asr_output(
def construct_manifest_eval(
input_manifest: str, stitched_output_manifest: str, aligned_vad_asr_output_manifest: str = "vad_asr_out.json"
) -> str:

"""
Generate aligned manifest for evaluation.
Because some pure noise samples might not appear in stitched_output_manifest.
Expand Down Expand Up @@ -1393,7 +1391,7 @@ def get_nonspeech_segments(
Args:
speech_segments (List[List[float]]): speech segment intervals loaded by load_speech_segments()
max_duration (Optional[float]): maximum duration of the audio, used to calculate the last silence segment
Returns:
nonspeech_segments (List[List[float]]): intervals of non-speech segments
"""
Expand Down Expand Up @@ -1483,8 +1481,8 @@ def plot_sample_from_rttm(

def align_labels_to_frames(probs, labels, threshold=0.2):
"""
Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length (e.g., 20ms).
The threshold 0.2 is not important, since the actual ratio will always be close to an integer unless using frame/label
Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length (e.g., 20ms).
The threshold 0.2 is not important, since the actual ratio will always be close to an integer unless using frame/label
lengths that are not multiples of each other (e.g., 15ms frame length and 20ms label length), which is not valid.
The value 0.2 here is just for easier unit testing.
Args:
Expand Down Expand Up @@ -1624,17 +1622,17 @@ def frame_vad_infer_load_manifest(cfg: DictConfig):
"""
Load manifest file and prepare label/rttm mapping
Args:
cfg: config file
cfg: DictConfig object
Returns:
manifest_orig (List[Dict]): original manifest data
manifest_orig (List[Dict]): original manifest data
key_labels_map (Dict): mapping from unique_audio_name to its labels
key_rttm_map (Dict): mapping from unique_audio_name to its rttm file
"""
unique_audio_names = set()
key_labels_map = {}
key_rttm_map = {}
manifest_orig = []
manifest_file = Path(cfg.dataset).absolute().as_posix()
manifest_file = Path(cfg.input_manifest).absolute().as_posix()
with open(manifest_file, 'r') as fin:
for line in fin.readlines():
entry = json.loads(line.strip())
Expand All @@ -1649,22 +1647,25 @@ def frame_vad_infer_load_manifest(cfg: DictConfig):

manifest_orig.append(entry)

# always prefer RTTM labels if exist
if "label" not in entry and ("rttm_filepath" in entry or "rttm_file" in entry):
if cfg.evaluate:
# always prefer RTTM labels if exist
rttm_key = "rttm_filepath" if "rttm_filepath" in entry else "rttm_file"
segments = load_speech_segments_from_rttm(entry[rttm_key])
label_str = get_frame_labels(
segments=segments,
frame_length=cfg.vad.parameters.shift_length_in_sec,
duration=entry['duration'],
offset=entry['offset'],
)
key_rttm_map[uniq_audio_name] = entry[rttm_key]
key_labels_map[uniq_audio_name] = [float(x) for x in label_str.split()]
elif entry.get("label", None) is not None:
key_labels_map[uniq_audio_name] = [float(x) for x in entry["label"].split()]
elif cfg.evaluate:
raise ValueError("Must have either `label` or `rttm_filepath` in manifest when evaluate=True")
rttm_file = entry.get(rttm_key, None)
if rttm_file:
rttm_file = get_full_path(audio_file=rttm_file, manifest_file=manifest_file)
segments = load_speech_segments_from_rttm(rttm_file)
label_str = get_frame_labels(
segments=segments,
frame_length=cfg.vad.parameters.shift_length_in_sec,
duration=entry['duration'],
offset=entry['offset'],
)
key_rttm_map[uniq_audio_name] = entry[rttm_key]
key_labels_map[uniq_audio_name] = [float(x) for x in label_str.split()]
elif entry.get("label", None) is not None:
key_labels_map[uniq_audio_name] = [float(x) for x in entry["label"].split()]
else:
raise ValueError("Must have either `label` or `rttm_filepath` in manifest when evaluate=True")

return manifest_orig, key_labels_map, key_rttm_map

Expand Down Expand Up @@ -1709,7 +1710,9 @@ def frame_vad_eval_detection_error(
groundtruth = key_labels_map[key]

reference, hypothesis = frame_vad_construct_pyannote_object_per_file(
prediction=key_pred_rttm_map[key], groundtruth=groundtruth, frame_length_in_sec=frame_length_in_sec,
prediction=key_pred_rttm_map[key],
groundtruth=groundtruth,
frame_length_in_sec=frame_length_in_sec,
)
metric(reference, hypothesis)

Expand Down
10 changes: 10 additions & 0 deletions nemo/core/classes/modelPT.py
Original file line number Diff line number Diff line change
Expand Up @@ -1639,6 +1639,16 @@ def cfg(self, cfg):
if hasattr(self, '_hparams_initial') and 'cfg' in self._hparams_initial:
self._hparams_initial['cfg'] = OmegaConf.to_object(self._cfg)

@property
def hparams(self):
"""
Overwrite default hparams property to return the lastest model config.
Without this change, the hparams property would return the old config if there was a direct change to
self._cfg (e.g., in self.setup_optimization()) that was not done via `self.cfg = new_cfg`.
"""
self._set_hparams(OmegaConf.create({'cfg': self._cfg}))
return super().hparams

@property
def validation_step_outputs(self):
"""
Expand Down
Loading