Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accelerate transcribe_speech.py for short-form data: pre-sorting support #8564

Merged
merged 21 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions examples/asr/transcribe_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from nemo.collections.asr.parts.utils.transcribe_utils import (
compute_output_filename,
prepare_audio_data,
restore_transcription_order,
setup_model,
transcribe_partial_audio,
write_transcription,
Expand Down Expand Up @@ -121,6 +122,7 @@ class TranscriptionConfig:
] = None # Used to select a single channel from multichannel audio, or use average across channels
audio_key: str = 'audio_filepath' # Used to override the default audio key in dataset_manifest
eval_config_yaml: Optional[str] = None # Path to a yaml file of config of evaluation
presort_manifest: bool = True # Significant inference speedup on short-form data due to padding reduction

# General configs
output_filename: Optional[str] = None
Expand Down Expand Up @@ -183,6 +185,7 @@ class TranscriptionConfig:

# key for groundtruth text in manifest
gt_text_attr_name: str = "text"
gt_lang_attr_name: str = "lang"

# Use model's transcribe() function instead of transcribe_partial_audio() by default
# Only use transcribe_partial_audio() when the audio is too long to fit in memory
Expand Down Expand Up @@ -317,6 +320,12 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
if isinstance(asr_model, EncDecMultiTaskModel):
# Special case for EncDecMultiTaskModel, where the input manifest is directly passed into the model's transcribe() function
partial_audio = False
if cfg.presort_manifest:
logging.warning(
"Pre-sorting manifest for EncDecMultiTaskModel is currently not supported; "
"please do it manually. We'll proceed with an unsorted manifest."
)
cfg.presort_manifest = False
if cfg.audio_dir is not None and not cfg.append_pred:
filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True))
else:
Expand Down Expand Up @@ -369,17 +378,20 @@ def autocast(dtype=None):
decoder_type=cfg.decoder_type,
)
else:
transcriptions = asr_model.transcribe(
audio=filepaths,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
return_hypotheses=cfg.return_hypotheses,
channel_selector=cfg.channel_selector,
augmentor=augmentor,
)
override_cfg = asr_model.get_transcribe_config()
override_cfg.batch_size = cfg.batch_size
override_cfg.num_workers = cfg.num_workers
override_cfg.return_hypotheses = cfg.return_hypotheses
override_cfg.channel_selector = cfg.channel_selector
override_cfg.augmentor = augmentor
override_cfg.text_field = cfg.gt_text_attr_name
override_cfg.lang_field = cfg.gt_lang_attr_name
transcriptions = asr_model.transcribe(audio=filepaths, override_config=override_cfg,)

if cfg.dataset_manifest is not None:
logging.info(f"Finished transcribing from manifest file: {cfg.dataset_manifest}")
if cfg.presort_manifest:
titu1994 marked this conversation as resolved.
Show resolved Hide resolved
transcriptions = restore_transcription_order(cfg.dataset_manifest, transcriptions)
else:
titu1994 marked this conversation as resolved.
Show resolved Hide resolved
logging.info(f"Finished transcribing {len(filepaths)} files !")
logging.info(f"Writing transcriptions into file: {cfg.output_filename}")
Expand Down
6 changes: 4 additions & 2 deletions nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ class MultiTaskTranscriptionConfig(TranscribeConfig):
pnc: Optional[bool] = None
source_lang: Optional[str] = None
target_lang: Optional[str] = None
text_field: str = "text"
lang_field: str = "lang"

_internal: Optional[MultiTaskTranscriptionInternalConfig] = None

Expand Down Expand Up @@ -860,8 +862,8 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
'use_lhotse': True,
'use_bucketing': False,
'drop_last': False,
'text_field': 'answer',
'lang_field': 'target_lang',
'text_field': config.get('text_field', 'answer'),
'lang_field': config.get('lang_field', 'target_lang'),
}

temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), inference=True)
Expand Down
4 changes: 3 additions & 1 deletion nemo/collections/asr/parts/mixins/transcription.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class TranscribeConfig:
verbose: bool = True

# Utility
partial_hypothesis: Optional[List[Any]] = False
partial_hypothesis: Optional[List[Any]] = None

_internal: Optional[InternalTranscribeConfig] = None

Expand Down Expand Up @@ -728,6 +728,8 @@ def _transcribe_input_manifest_processing(
'temp_dir': temp_dir,
'num_workers': get_value_from_transcription_config(trcfg, 'num_workers', 0),
'channel_selector': get_value_from_transcription_config(trcfg, 'channel_selector', None),
'text_field': get_value_from_transcription_config(trcfg, 'text_field', 'text'),
'lang_field': get_value_from_transcription_config(trcfg, 'lang_field', 'lang'),
}

augmentor = get_value_from_transcription_config(trcfg, 'augmentor', None)
Expand Down
30 changes: 28 additions & 2 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,19 @@ def prepare_audio_data(cfg: DictConfig) -> Tuple[List[str], bool]:

with open(cfg.dataset_manifest, 'r', encoding='utf_8') as f:
has_two_fields = []
for line in f:
item = json.loads(line)

if cfg.presort_manifest:
# We pre-load the manifest into CPU RAM below to check whether duration data is available
# and sort it descendingly for a significant inference speedup.
# Descending sort is preferred to present the largest possible mini-batch to the model
# at the very beginning to fail fast on OOM and pre-allocate max GPU memory needed.
items = [json.loads(l) for l in f]
if all("duration" in item for item in items):
items = sorted(items, reverse=True, key=lambda item: item["duration"])
else:
items = (json.loads(l) for l in f)

for item in items:
if "offset" in item and "duration" in item:
has_two_fields.append(True)
else:
Expand All @@ -299,6 +310,21 @@ def prepare_audio_data(cfg: DictConfig) -> Tuple[List[str], bool]:
return filepaths, partial_audio


def restore_transcription_order(manifest_path: str, transcriptions: list) -> list:
    """Restore the original manifest order of transcriptions produced from a
    duration-sorted (descending) manifest.

    ``prepare_audio_data`` pre-sorts the manifest by descending ``duration``
    (only when every entry has a ``duration`` field) to speed up inference.
    This function inverts that permutation so the returned transcriptions line
    up with the original manifest lines.

    Args:
        manifest_path: Path to the original (unsorted) JSON-lines manifest.
        transcriptions: Transcriptions in duration-descending order. May be a
            flat list (one result per utterance) or a list of lists (e.g. one
            list per hypothesis group); in the latter case a tuple of
            reordered lists is returned, mirroring the input structure.

    Returns:
        The transcriptions reordered to the manifest's original line order.
        Returned unchanged when reordering is impossible or unnecessary:
        empty input, or some manifest entries lack ``duration`` (in which case
        no pre-sorting could have been applied).
    """
    if not transcriptions:
        # Nothing to reorder; also guards the transcriptions[0] access below.
        return transcriptions
    with open(manifest_path) as f:
        items = [(idx, json.loads(l)) for idx, l in enumerate(f)]
    if not all("duration" in item for _, item in items):
        # prepare_audio_data() skips sorting when any duration is missing,
        # so the transcriptions are already in the original order.
        return transcriptions
    # new2old[new_position] == original_position; must use the exact same
    # (stable) descending-duration sort as prepare_audio_data().
    new2old = [item[0] for item in sorted(items, reverse=True, key=lambda it: it[1]["duration"])]
    is_list = isinstance(transcriptions[0], list)
    if is_list:
        # Transpose so each element groups all hypotheses for one utterance.
        transcriptions = list(zip(*transcriptions))
    reordered = [None] * len(transcriptions)
    for new, old in enumerate(new2old):
        reordered[old] = transcriptions[new]
    if is_list:
        # Transpose back into per-hypothesis lists.
        reordered = tuple(map(list, zip(*reordered)))
    return reordered


def compute_output_filename(cfg: DictConfig, model_name: str) -> DictConfig:
""" Compute filename of output manifest and update cfg"""
if cfg.output_filename is None:
Expand Down
Loading