Skip to content

Commit

Permalink
Accelerate transcribe_speech.py for short-form data: pre-sorting su…
Browse files Browse the repository at this point in the history
…pport (#8564)

* POC using bucketing in transcribe_speech.py

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* extend to multi task aed

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* fixes for aed multi task text/lang field selectors

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* remove assert

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* fix

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* expose option for bucket buffer size

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* fixes, ctc support

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* support pre-sorting manifests in transcribe_speech.py

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* cleanup

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* reorder transcriptions back to original manifest order

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* remove bucketing entirely

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* code review changes

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* code review changes--amend

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* refactor text_field/lang_field passing

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Fix reordering bug; disable presorting for multi task for now

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Add support for presort + multi task model

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Code reviews

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

* Fix jenkins tests, add user-friendly error msg for canary

Signed-off-by: Piotr Żelasko <petezor@gmail.com>

---------

Signed-off-by: Piotr Żelasko <petezor@gmail.com>
  • Loading branch information
pzelasko committed Mar 7, 2024
1 parent 537dfa2 commit 5ee9efb
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 25 deletions.
36 changes: 27 additions & 9 deletions examples/asr/transcribe_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

import contextlib
import glob
import json
import os
from dataclasses import dataclass, is_dataclass
from tempfile import NamedTemporaryFile
from typing import List, Optional, Union

import pytorch_lightning as pl
Expand All @@ -32,6 +34,8 @@
from nemo.collections.asr.parts.utils.transcribe_utils import (
compute_output_filename,
prepare_audio_data,
read_and_maybe_sort_manifest,
restore_transcription_order,
setup_model,
transcribe_partial_audio,
write_transcription,
Expand Down Expand Up @@ -121,6 +125,7 @@ class TranscriptionConfig:
] = None # Used to select a single channel from multichannel audio, or use average across channels
audio_key: str = 'audio_filepath' # Used to override the default audio key in dataset_manifest
eval_config_yaml: Optional[str] = None # Path to a yaml file of config of evaluation
presort_manifest: bool = True # Significant inference speedup on short-form data due to padding reduction

# General configs
output_filename: Optional[str] = None
Expand Down Expand Up @@ -183,6 +188,7 @@ class TranscriptionConfig:

# key for groundtruth text in manifest
gt_text_attr_name: str = "text"
gt_lang_attr_name: str = "lang"

# Use model's transcribe() function instead of transcribe_partial_audio() by default
# Only use transcribe_partial_audio() when the audio is too long to fit in memory
Expand Down Expand Up @@ -314,14 +320,21 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
else:
cfg.decoding = cfg.rnnt_decoding

remove_path_after_done = None
if isinstance(asr_model, EncDecMultiTaskModel):
# Special case for EncDecMultiTaskModel, where the input manifest is directly passed into the model's transcribe() function
partial_audio = False
if cfg.audio_dir is not None and not cfg.append_pred:
filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True))
else:
filepaths = cfg.dataset_manifest
assert cfg.dataset_manifest is not None
if cfg.presort_manifest:
with NamedTemporaryFile("w", suffix=".json", delete=False) as f:
for item in read_and_maybe_sort_manifest(cfg.dataset_manifest, try_sort=True):
print(json.dumps(item), file=f)
cfg.dataset_manifest = f.name
remove_path_after_done = f.name
filepaths = cfg.dataset_manifest
else:
# prepare audio filepaths and decide whether it's partial audio
filepaths, partial_audio = prepare_audio_data(cfg)
Expand Down Expand Up @@ -369,17 +382,22 @@ def autocast(dtype=None):
decoder_type=cfg.decoder_type,
)
else:
transcriptions = asr_model.transcribe(
audio=filepaths,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
return_hypotheses=cfg.return_hypotheses,
channel_selector=cfg.channel_selector,
augmentor=augmentor,
)
override_cfg = asr_model.get_transcribe_config()
override_cfg.batch_size = cfg.batch_size
override_cfg.num_workers = cfg.num_workers
override_cfg.return_hypotheses = cfg.return_hypotheses
override_cfg.channel_selector = cfg.channel_selector
override_cfg.augmentor = augmentor
override_cfg.text_field = cfg.gt_text_attr_name
override_cfg.lang_field = cfg.gt_lang_attr_name
transcriptions = asr_model.transcribe(audio=filepaths, override_config=override_cfg,)

if cfg.dataset_manifest is not None:
logging.info(f"Finished transcribing from manifest file: {cfg.dataset_manifest}")
if cfg.presort_manifest:
transcriptions = restore_transcription_order(cfg.dataset_manifest, transcriptions)
if remove_path_after_done is not None:
os.unlink(remove_path_after_done)
else:
logging.info(f"Finished transcribing {len(filepaths)} files !")
logging.info(f"Writing transcriptions into file: {cfg.output_filename}")
Expand Down
17 changes: 16 additions & 1 deletion nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Sequence

import omegaconf
import torch.utils.data
from lhotse import CutSet
from lhotse.cut import MixedCut, MonoCut
Expand Down Expand Up @@ -175,7 +177,16 @@ def canary_prompt(
language = [language]

if text is not None:
tokens = sum((tokenizer.text_to_ids(text_, lang_) for text_, lang_ in zip(text, language)), start=[])
try:
tokens = sum((tokenizer.text_to_ids(text_, lang_) for text_, lang_ in zip(text, language)), start=[])
except omegaconf.errors.KeyValidationError as e:
raise ProbablyIncorrectLanguageKeyError(
"We couldn't select the right tokenizer, which could be due to issues with reading "
"the language from the manifest. "
"If you're training, try setting lang_field='' to a different value (probably 'target_lang' or 'lang'). "
"If you're using model.transcribe() directly, please use override_config kwarg to set this. "
"If you're using transcribe_speech.py, use option gt_lang_attr_name='...' "
) from e
else:
tokens = None # create prompt for inference

Expand Down Expand Up @@ -231,3 +242,7 @@ def canary_prompt(
if tokens is not None:
prompted_tokens.append(tokenizer.eos_id)
return prompted_tokens


class ProbablyIncorrectLanguageKeyError(RuntimeError):
    """Raised when tokenizer selection fails while building a Canary prompt.

    The most likely cause is that the language was read from the wrong manifest
    field; the raise site suggests adjusting ``lang_field`` (training /
    ``override_config``) or ``gt_lang_attr_name`` (``transcribe_speech.py``).
    """

    pass
6 changes: 4 additions & 2 deletions nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ class MultiTaskTranscriptionConfig(TranscribeConfig):
pnc: Optional[bool] = None
source_lang: Optional[str] = None
target_lang: Optional[str] = None
text_field: str = "answer"
lang_field: str = "target_lang"

_internal: Optional[MultiTaskTranscriptionInternalConfig] = None

Expand Down Expand Up @@ -860,8 +862,8 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
'use_lhotse': True,
'use_bucketing': False,
'drop_last': False,
'text_field': 'answer',
'lang_field': 'target_lang',
'text_field': config.get('text_field', 'answer'),
'lang_field': config.get('lang_field', 'target_lang'),
}

temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), inference=True)
Expand Down
4 changes: 3 additions & 1 deletion nemo/collections/asr/parts/mixins/transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class TranscribeConfig:
verbose: bool = True

# Utility
partial_hypothesis: Optional[List[Any]] = False
partial_hypothesis: Optional[List[Any]] = None

_internal: Optional[InternalTranscribeConfig] = None

Expand Down Expand Up @@ -728,6 +728,8 @@ def _transcribe_input_manifest_processing(
'temp_dir': temp_dir,
'num_workers': get_value_from_transcription_config(trcfg, 'num_workers', 0),
'channel_selector': get_value_from_transcription_config(trcfg, 'channel_selector', None),
'text_field': get_value_from_transcription_config(trcfg, 'text_field', 'text'),
'lang_field': get_value_from_transcription_config(trcfg, 'lang_field', 'lang'),
}

augmentor = get_value_from_transcription_config(trcfg, 'augmentor', None)
Expand Down
47 changes: 35 additions & 12 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,23 +282,46 @@ def prepare_audio_data(cfg: DictConfig) -> Tuple[List[str], bool]:
logging.error(f"The input dataset_manifest {cfg.dataset_manifest} is empty. Exiting!")
return None

with open(cfg.dataset_manifest, 'r', encoding='utf_8') as f:
has_two_fields = []
for line in f:
item = json.loads(line)
if "offset" in item and "duration" in item:
has_two_fields.append(True)
else:
has_two_fields.append(False)
audio_key = cfg.get('audio_key', 'audio_filepath')
audio_file = get_full_path(audio_file=item[audio_key], manifest_file=cfg.dataset_manifest)
filepaths.append(audio_file)
partial_audio = all(has_two_fields)
all_entries_have_offset_and_duration = True
for item in read_and_maybe_sort_manifest(cfg.dataset_manifest, try_sort=cfg.presort_manifest):
if not ("offset" in item and "duration" in item):
all_entries_have_offset_and_duration = False
audio_key = cfg.get('audio_key', 'audio_filepath')
audio_file = get_full_path(audio_file=item[audio_key], manifest_file=cfg.dataset_manifest)
filepaths.append(audio_file)
partial_audio = all_entries_have_offset_and_duration
logging.info(f"\nTranscribing {len(filepaths)} files...\n")

return filepaths, partial_audio


def read_and_maybe_sort_manifest(path: str, try_sort: bool = False) -> list[dict]:
    """Read a JSON-lines manifest and optionally sort it by descending duration.

    Args:
        path: Path to a manifest file containing one JSON object per line.
        try_sort: When True, sort entries longest-first — but only if every
            entry carries a ``"duration"`` key; otherwise the original order
            is preserved.

    Returns:
        The list of manifest entries (dicts), possibly sorted.
    """
    # Explicit encoding for platform independence (matches the file's other
    # manifest reads, which use encoding='utf_8').
    with open(path, encoding="utf-8") as f:
        items = [json.loads(line) for line in f]
    if try_sort and all("duration" in item for item in items):
        # Longest-first ordering reduces padding when batching short-form audio.
        items = sorted(items, reverse=True, key=lambda item: item["duration"])
    return items


def restore_transcription_order(manifest_path: str, transcriptions: list) -> list:
    """Undo the duration-descending pre-sort applied before transcription.

    Args:
        manifest_path: Path to the ORIGINAL (unsorted) manifest. Reordering is
            applied only if every entry has a ``"duration"`` key (the same
            condition under which ``read_and_maybe_sort_manifest`` sorts).
        transcriptions: Transcription outputs in sorted (longest-first) order.
            May also be a list of parallel lists (e.g. one list per decoding
            variant), each aligned with the sorted manifest.

    Returns:
        Transcriptions in the original manifest order. If any entry lacks
        ``"duration"`` (so no sorting happened), the input is returned as-is.
        NOTE: when the input was a list of parallel lists, a *tuple* of lists
        is returned (artifact of the transpose round-trip, preserved for
        backward compatibility with existing callers).
    """
    if not transcriptions:
        # Nothing to reorder; also avoids IndexError on transcriptions[0] below.
        return transcriptions
    with open(manifest_path, encoding="utf-8") as f:
        items = [(idx, json.loads(line)) for idx, line in enumerate(f)]
    if not all("duration" in item[1] for item in items):
        return transcriptions
    # Replay the exact (stable, descending) sort used when pre-sorting, keeping
    # only the original indices: new2old[new_pos] == old_pos.
    new2old = [item[0] for item in sorted(items, reverse=True, key=lambda it: it[1]["duration"])]
    del items  # free up some memory
    is_list = isinstance(transcriptions[0], list)
    if is_list:
        # Transpose parallel lists into per-utterance tuples so each utterance
        # moves as a unit during reordering.
        transcriptions = list(zip(*transcriptions))
    reordered = [None] * len(transcriptions)
    for new, old in enumerate(new2old):
        reordered[old] = transcriptions[new]
    if is_list:
        # Transpose back into parallel lists.
        reordered = tuple(map(list, zip(*reordered)))
    return reordered


def compute_output_filename(cfg: DictConfig, model_name: str) -> DictConfig:
""" Compute filename of output manifest and update cfg"""
if cfg.output_filename is None:
Expand Down

0 comments on commit 5ee9efb

Please sign in to comment.