Add faster-whisper (ctranslate2) as option for Whisper annotation workflow #1017

Open · wants to merge 11 commits into base: master

49 changes: 48 additions & 1 deletion in lhotse/bin/modes/workflows.py
@@ -1,3 +1,4 @@
from functools import partial
from typing import List, Optional, Union

import click

@@ -55,6 +56,36 @@ def workflows():
@click.option(
    "-d", "--device", default="cpu", help="Device on which to run the inference."
)
@click.option(
Collaborator:
Please change to:

@click.option(
    "--faster-whisper/--normal-whisper",
    default=True,
    help="If True, use faster-whisper's implementation based on CTranslate2.",
)

Otherwise it can't be turned off.
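For context, a minimal standalone sketch (hypothetical script, not part of the PR) of why the slash syntax matters: Click's paired on/off syntax generates both switches, whereas a bare is_flag=True option with default=True has no way to pass False.

import click

@click.command()
@click.option(
    "--faster-whisper/--normal-whisper",
    default=True,
    help="If True, use faster-whisper's implementation based on CTranslate2.",
)
def main(faster_whisper: bool):
    # `--normal-whisper` flips the value to False; with `is_flag=True`
    # and `default=True` there would be no off switch at all.
    click.echo(f"faster_whisper={faster_whisper}")

if __name__ == "__main__":
    main()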

"--faster-whisper",
is_flag=True,
default=True,
help="If True, use faster-whisper's implementation based on CTranslate2.",
)
@click.option(
    "--faster-whisper-use-vad",
    is_flag=True,
    default=True,
    help="If True, use faster-whisper's built-in voice activity detection (SileroVAD). "
    "Note: This requires onnxruntime to be installed.",
)
@click.option(
    "--faster-whisper-add-alignments",
    is_flag=True,
    default=False,
    help="If True, add word alignments using timestamps obtained using the cross-attention "
    "pattern and dynamic time warping (Note: Less accurate than forced alignment).",
)
@click.option(
    "--faster-whisper-compute-type",
    default="float16",
Collaborator:
Suggested change:

-    default="float16",
+    default="auto",

Otherwise it won't work on (some?) CPUs.
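For background, compute type support can be queried from CTranslate2 directly. A sketch, assuming the ctranslate2 package is installed and that its get_supported_compute_types helper is available in the installed version:

import ctranslate2

# "float16" generally requires a CUDA device; "auto" lets CTranslate2 pick
# the fastest type the target device actually supports.
print(ctranslate2.get_supported_compute_types("cpu"))
# e.g. {'float32', 'int8', ...} depending on the CPU's instruction sets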

help="Type to use for computation. See https://opennmt.net/CTranslate2/quantization.html.",
)
@click.option(
"--faster-whisper-num-workers",
default=1,
help="Number of workers for parallelization across multiple GPUs.",
)
@click.option("-j", "--jobs", default=1, help="Number of jobs for audio scanning.")
@click.option(
"--force-nonoverlapping/--keep-overlapping",
@@ -72,6 +103,11 @@ def annotate_with_whisper(
    device: str,
    jobs: int,
    force_nonoverlapping: bool,
    faster_whisper: bool,
    faster_whisper_use_vad: bool,
    faster_whisper_compute_type: str,
    faster_whisper_add_alignments: bool,
    faster_whisper_num_workers: int,
):
"""
Use OpenAI Whisper model to annotate either RECORDINGS_MANIFEST, RECORDINGS_DIR, or CUTS_MANIFEST.
Expand All @@ -83,7 +119,18 @@ def annotate_with_whisper(
Note: this is an experimental feature of Lhotse, and is not guaranteed to yield
high quality of data.
"""
from lhotse import annotate_with_whisper as annotate_with_whisper_
if faster_whisper:
from lhotse import annotate_with_faster_whisper

annotate_with_whisper_ = partial(
annotate_with_faster_whisper,
compute_type=faster_whisper_compute_type,
num_workers=faster_whisper_num_workers,
vad_filter=faster_whisper_use_vad,
add_alignments=faster_whisper_add_alignments,
)
else:
from lhotse import annotate_with_whisper as annotate_with_whisper_

assert exactly_one_not_null(recordings_manifest, recordings_dir, cuts_manifest), (
"Options RECORDINGS_MANIFEST, RECORDINGS_DIR, and CUTS_MANIFEST are mutually exclusive "
Expand Down
1 change: 1 addition & 0 deletions in lhotse/workflows/__init__.py

@@ -1,3 +1,4 @@
from .faster_whisper import annotate_with_faster_whisper
from .forced_alignment import align_with_torchaudio
from .meeting_simulation import *
from .whisper import annotate_with_whisper

244 changes: 244 additions & 0 deletions in lhotse/workflows/faster_whisper.py

@@ -0,0 +1,244 @@
import logging
import warnings
from concurrent.futures import as_completed
from concurrent.futures.thread import ThreadPoolExecutor
from typing import Any, Generator, List, Optional, Union

import numpy as np

from lhotse import (
CutSet,
MonoCut,
Recording,
RecordingSet,
SupervisionSegment,
add_durations,
)
from lhotse.qa import trim_supervisions_to_recordings
from lhotse.supervision import AlignmentItem
from lhotse.utils import fastcopy, is_module_available


def annotate_with_faster_whisper(
    manifest: Union[RecordingSet, CutSet],
    model_name: str = "base",
    device: str = "cpu",
    force_nonoverlapping: bool = False,
    download_root: Optional[str] = None,
    compute_type: str = "default",
    num_workers: int = 1,
    vad_filter: bool = True,
    add_alignments: bool = False,
    **decode_options,
) -> Generator[MonoCut, None, None]:
"""
Use OpenAI Whisper model via faster-whisper and CTranslate2 to annotate either
RECORDINGS_MANIFEST, RECORDINGS_DIR, or CUTS_MANIFEST. It will perform automatic segmentation,
transcription, and language identification. If the first argument is a CutSet, it will
overwrite the supervisions with the results of the inference.

Note: this is an experimental feature of Lhotse, and is not guaranteed to yield
high quality of data.

See the original repo for more details: https://github.com/guillaumekln/faster-whisper

:param manifest: a ``RecordingSet`` or ``CutSet`` object.
:param language: specify the language if known upfront, otherwise it will be auto-detected.
:param model_name: one of available Whisper variants (base, medium, large, etc.).
:param device: Where to run the inference (cpu, cuda, etc.).
:param force_nonoverlapping: if True, the Whisper segment time-stamps will be processed to make
sure they are non-overlapping.
:param download_root: Not supported by faster-whisper. Argument kept to maintain compatibility
with annotate_with_whisper. Faster-whisper uses
:param compute_type: Type to use for computation.
See https://opennmt.net/CTranslate2/quantization.html.
:param num_workers: Increasing the number of workers can improve the global throughput at the
cost of increased memory usage.
:param vad_filter: If True, use faster-whisper's built-in voice activity detection (SileroVAD).
:param add_alignments: if True, add word alignments using timestamps obtained using the cross-
attention pattern and dynamic time warping (Note: Less accurate than forced alignment).
:param decode_options: additional options to pass to the ``whisper.transcribe`` function.
:return: a generator of cuts (use ``CutSet.open_writer()`` to write them).
"""
assert is_module_available("faster_whisper"), (
"This function expects faster-whisper to be installed. "
"You can install it via 'pip install faster-whisper' "
"(see https://github.com/guillaumekln/faster-whisper/ for details)."
)
if not isinstance(manifest, RecordingSet) and not isinstance(manifest, CutSet):
raise ValueError("The ``manifest`` must be either a RecordingSet or a CutSet.")
assert not vad_filter or is_module_available("onnxruntime"), (
"Use of VAD requires onnxruntime to be installed. "
"You can install it via 'pip install onnxruntime' "
"(see https://github.com/guillaumekln/faster-whisper/ for details)."
)
if vad_filter and add_alignments:
warnings.warn(
"Word timestamps can be very inaccurate when using VAD. We don't recommend using both "
f"options together. See https://github.com/guillaumekln/faster-whisper/issues/125."
)

    model = _initialize_model(
        model_name, device, compute_type, num_workers, download_root
    )
    with ThreadPoolExecutor(num_workers) as ex:
        futures = []
        for item in manifest:
            futures.append(
                ex.submit(
                    _process_single_manifest,
                    item,
                    model,
                    force_nonoverlapping,
                    vad_filter,
                    add_alignments,
                    **decode_options,
                )
            )
        for item in as_completed(futures):
            yield item.result()


def _initialize_model(
    model_name: str,
    device: str,
    compute_type: str = "default",
    num_workers: int = 1,
    download_root: Optional[str] = None,
):
    import torch
    from faster_whisper import WhisperModel

    # Parse the device index, e.g. "cuda:1" -> device="cuda", device_index=1.
    device, _, idx = device.partition(":")
    if len(idx) > 0:
        device_index = int(idx)
    elif num_workers > 1 and device == "cuda":
        # Limit num_workers to available GPUs and spread the workers across them.
        num_workers = min(num_workers, torch.cuda.device_count())
        device_index = list(range(num_workers))
    else:
        device_index = 0
    model = WhisperModel(
        model_name,
        device=device,
        device_index=device_index,
        compute_type=compute_type,
        num_workers=num_workers,
        download_root=download_root,
Collaborator:
Since the change that enables this option has not been released on pip yet, I suggest a small workaround here; otherwise it cannot be run:

+    opt_kwargs = {}
+    if download_root is not None:
+        opt_kwargs["download_root"] = download_root
     model = WhisperModel(
         model_name,
         device=device,
         device_index=device_index,
         compute_type=compute_type,
         num_workers=num_workers,
-        download_root=download_root,
+        **opt_kwargs,
     )
-    model.logger.setLevel(logging.WARNING)
+    if hasattr(model, "logger"):
+        model.logger.setLevel(logging.WARNING)

Note that I also suggested a check for logger; on my installation, the model did not have the logger attribute defined.

    )
    model.logger.setLevel(logging.WARNING)
    return model


def _process_single_manifest(
    manifest: Union[Recording, MonoCut],
    model,
    force_nonoverlapping: bool,
    vad_filter: bool,
    add_alignments: bool = False,
    **decode_options,
) -> MonoCut:
    if isinstance(manifest, Recording):
        if manifest.num_channels > 1:
            logging.warning(
                f"Skipping recording '{manifest.id}'. It has {manifest.num_channels} channels, "
                f"but we currently only support mono input."
            )
            return []
        recording_id = manifest.id
    else:
        recording_id = manifest.recording_id
    audio = np.squeeze(manifest.resample(16000).load_audio())
    segments, info = model.transcribe(
        audio=audio,
        word_timestamps=add_alignments,
        vad_filter=vad_filter,
        **decode_options,
    )
    # Create supervisions from segments while filtering out those with negative duration.
    if add_alignments:
        supervisions = [
            SupervisionSegment(
                id=f"{manifest.id}-{segment_id:06d}",
                recording_id=recording_id,
                start=round(segment.start, ndigits=8),
                duration=add_durations(
                    segment.end, -segment.start, sampling_rate=16000
                ),
                text=segment.text.strip(),
                language=info.language,
            ).with_alignment(
                "word",
                [
                    AlignmentItem(
                        symbol=ws.word.strip(),
                        start=round(ws.start, ndigits=8),
                        duration=round(ws.end - ws.start, ndigits=8),
                        score=round(ws.probability, ndigits=3),
                    )
                    for ws in segment.words
                ],
            )
            for segment_id, segment in enumerate(segments)
            if segment.end - segment.start > 0
        ]
    else:
        supervisions = [
            SupervisionSegment(
                id=f"{manifest.id}-{segment_id:06d}",
                recording_id=recording_id,
                start=round(segment.start, ndigits=8),
                duration=add_durations(
                    segment.end, -segment.start, sampling_rate=16000
                ),
                text=segment.text.strip(),
                language=info.language,
            )
            for segment_id, segment in enumerate(segments)
            if segment.end - segment.start > 0
        ]

    if isinstance(manifest, Recording):
        cut = manifest.to_cut()
        if supervisions:
            supervisions = (
                _postprocess_timestamps(supervisions)
                if force_nonoverlapping
                else supervisions
            )
            cut.supervisions = list(
                trim_supervisions_to_recordings(
                    recordings=manifest, supervisions=supervisions, verbose=False
                )
            )
    else:
        cut = fastcopy(
            manifest,
            supervisions=_postprocess_timestamps(supervisions)
            if force_nonoverlapping
            else supervisions,
        )

    return cut


def _postprocess_timestamps(supervisions: List[SupervisionSegment]):
    """
    Whisper tends to have a lot of overlapping segments due to inaccurate end timestamps.
    Under a strong assumption that the input speech is non-overlapping, we can fix that
    by always truncating to the start timestamp of the next segment.
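
    For example, overlapping segments spanning [0.0, 5.2] and [5.0, 8.0] are
    truncated to [0.0, 5.0] and [5.0, 8.0].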
"""
from cytoolz import sliding_window

supervisions = sorted(supervisions, key=lambda s: s.start)

if len(supervisions) < 2:
return supervisions
out = []
for cur, nxt in sliding_window(2, supervisions):
if cur.end > nxt.start:
cur = cur.trim(end=nxt.start)
out.append(cur)
out.append(nxt)
return out
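
For reference, a minimal end-to-end sketch of how the new workflow is meant to be consumed from Python (the manifest paths are hypothetical; compute_type="auto" follows the review suggestion above):

from lhotse import CutSet, RecordingSet
from lhotse.workflows import annotate_with_faster_whisper

recordings = RecordingSet.from_file("recordings.jsonl.gz")  # hypothetical input manifest

# The workflow yields cuts as transcription finishes, so results can be
# streamed to disk instead of being accumulated in memory.
with CutSet.open_writer("cuts_whisper.jsonl.gz") as writer:  # hypothetical output path
    for cut in annotate_with_faster_whisper(
        recordings,
        model_name="base",
        device="cpu",
        compute_type="auto",
        vad_filter=True,
    ):
        writer.write(cut)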