Merge pull request #9 from mobiusml/fw_pr
added 'use_vad_model' to better handle vad segments
Jiltseb authored Apr 12, 2024
2 parents 538366b + 0e8fa00 commit 0d6c62e
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions faster_whisper/transcribe.py
@@ -125,6 +125,7 @@ class BatchedInferencePipeline(Pipeline):
     def __init__(
         self,
         model,
+        use_vad_model: bool = True,
         options: Optional[NamedTuple] = None,
         tokenizer=None,
         device: Union[int, str, "torch.device"] = -1,
@@ -138,6 +139,7 @@ def __init__(
         self.preset_language = language
         self._batch_size = kwargs.pop("batch_size", None)
         self._num_workers = 1
+        self.use_vad_model = use_vad_model
         self.vad_onset = 0.500
         self.vad_offset = 0.363
         self.vad_model_url = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
@@ -161,10 +163,11 @@ def __init__(
         else:
             self.device = device
 
-        # load vad model and perform VAD preprocessing if needed
-        self.vad_model = self.load_vad_model(
-            vad_onset=self.vad_onset, vad_offset=self.vad_offset
-        )
+        if self.use_vad_model:
+            # load vad model and perform VAD preprocessing if needed
+            self.vad_model = self.load_vad_model(
+                vad_onset=self.vad_onset, vad_offset=self.vad_offset
+            )
         self.chunk_size = 30  # VAD merging size
 
         super(Pipeline, self).__init__()
@@ -483,15 +486,18 @@ def transcribe(
 
         # if no segment split is provided, use vad_model and generate segments
        if not vad_segments:
-            vad_segments = self.vad_model(
-                {"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000}
-            )
-            vad_segments = merge_chunks(
-                vad_segments,
-                self.chunk_size,
-                onset=self.vad_onset,
-                offset=self.vad_offset,
-            )
+            if self.use_vad_model:
+                vad_segments = self.vad_model(
+                    {"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000}
+                )
+                vad_segments = merge_chunks(
+                    vad_segments,
+                    self.chunk_size,
+                    onset=self.vad_onset,
+                    offset=self.vad_offset,
+                )
+            else:
+                raise RuntimeError("No vad segments found. Set 'use_vad_model' to True while loading the model")
 
         language, language_probability, task = self.get_language_and_tokenizer(
             audio, task, language
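Usage note (not part of the commit): a minimal sketch of how the new use_vad_model flag might be exercised. Only use_vad_model and vad_segments appear in the diff above; the WhisperModel arguments, the transcribe() call signature, and the vad_segments dictionary format are assumptions for illustration.

from faster_whisper import WhisperModel
from faster_whisper.transcribe import BatchedInferencePipeline

model = WhisperModel("large-v2", device="cuda", compute_type="float16")  # model choice is an assumption

# Default path: the pipeline loads its own VAD model and segments the audio internally.
batched = BatchedInferencePipeline(model, use_vad_model=True)

# Alternative path: skip the built-in VAD model and pass precomputed segments instead.
# With use_vad_model=False and no vad_segments, transcribe() now raises the
# RuntimeError added in this commit.
batched_no_vad = BatchedInferencePipeline(model, use_vad_model=False)
precomputed = [{"start": 0.0, "end": 30.0, "segments": [(0.0, 30.0)]}]  # segment format is an assumption
result = batched_no_vad.transcribe("audio.wav", vad_segments=precomputed)  # call signature is an assumption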
