Change of language detection functions #20

Merged
merged 12 commits on Jul 1, 2024
Binary file modified: faster_whisper/assets/silero_vad.onnx
152 changes: 79 additions & 73 deletions faster_whisper/transcribe.py
@@ -136,6 +136,7 @@ def __init__(
self.vad_onset = 0.500
self.vad_offset = 0.363
self.vad_model_path = os.path.join(get_assets_path(), "pyannote_vad_model.bin")
self.vad_model = None

(
self._preprocess_params,
@@ -149,7 +150,7 @@ def __init__(
else:
self.device = device

if self.use_vad_model:
if self.use_vad_model and self.vad_model is None:
self.vad_device = self.get_device(vad_device)

# load vad model and perform VAD preprocessing if needed
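The two hunks above make the VAD model a cached instance attribute: `vad_model` starts as `None`, and the `use_vad_model` branch now also checks `self.vad_model is None`, so a model that was already attached is reused instead of being loaded again. A minimal sketch of the guard; everything except the `use_vad_model`/`vad_model` names is an illustrative assumption, not the library's internals:

```python
def load_vad_model():
    # stand-in for the real ONNX session load; not the library API
    return object()

class PipelineSketch:
    def __init__(self, use_vad_model: bool = True):
        self.use_vad_model = use_vad_model
        self.vad_model = None  # populated on first load, then reused

    def ensure_vad(self):
        # load at most once; later calls return the cached instance
        if self.use_vad_model and self.vad_model is None:
            self.vad_model = load_vad_model()
        return self.vad_model
```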
@@ -200,6 +201,7 @@ def preprocess(self, inputs):
]

inputs["features"] = features
del features
return inputs

def _forward(self, model_inputs, **forward_params):
@@ -210,16 +212,18 @@ def _forward(self, model_inputs, **forward_params):
segment_size = encoder_output.shape[1] * 2
segmented_outputs = []
for segment_metadata, output in zip(model_inputs["seg_metadata"], outputs):
subsegments, seek, single_timestamp_ending = (
self.model._split_segments_by_timestamps(
tokenizer=self.tokenizer,
tokens=output["tokens"],
time_offset=segment_metadata["start_time"],
segment_size=segment_size,
segment_duration=segment_metadata["end_time"]
- segment_metadata["start_time"],
seek=0,
)
(
subsegments,
seek,
single_timestamp_ending,
) = self.model._split_segments_by_timestamps(
tokenizer=self.tokenizer,
tokens=output["tokens"],
time_offset=segment_metadata["start_time"],
segment_size=segment_size,
segment_duration=segment_metadata["end_time"]
- segment_metadata["start_time"],
seek=0,
)
segmented_outputs.append(
[
@@ -251,7 +255,6 @@ def _forward(self, model_inputs, **forward_params):
return {"output": segmented_outputs}

def __call__(self, inputs, options, batch_size=None, **kwargs):

if batch_size is None:
if self._batch_size is None:
batch_size = 1
@@ -302,7 +305,6 @@ def get_iterator(
forward_params=None,
postprocess_params=None,
):

def stack(items):
return {
"inputs": [x["inputs"] for x in items],
@@ -333,11 +335,14 @@ def get_language_and_tokenizer(
):
all_language_probs = None
language_probability = 1.0

if self.tokenizer is None:
if not language:
language, language_probability, all_language_probs = (
self.model.detect_language(audio)
)
(
language,
language_probability,
all_language_probs,
) = self.model.detect_language_function(audio)
task = task or "transcribe"
self.tokenizer = Tokenizer(
self.model.hf_tokenizer,
@@ -546,9 +551,12 @@ def transcribe(
"No vad segments found. Set 'use_vad_model' to True while loading the model"
)

language, language_probability, task, all_language_probs = (
self.get_language_and_tokenizer(audio, task, language)
)
(
language,
language_probability,
task,
all_language_probs,
) = self.get_language_and_tokenizer(audio, task, language)
batch_size = batch_size or self._batch_size

duration_after_vad = sum(
@@ -976,16 +984,27 @@ def transcribe(
or language_detection_segments < 1
):
language_detection_segments = 1
seek = 0
detected_language_info = {}
start_timestamp = (
float(clip_timestamps.split(",")[0])
if isinstance(clip_timestamps, str)
else clip_timestamps[0]
)
content_frames = (
features.shape[-1] - self.feature_extractor.nb_max_frames
)
while (
seek <= content_frames
and seek
< self.feature_extractor.nb_max_frames * language_detection_segments
):
seek = (
int(start_timestamp * self.frames_per_second)
if start_timestamp * self.frames_per_second < content_frames
else 0
)
end_frames = min(
seek
+ self.feature_extractor.nb_max_frames
* language_detection_segments,
content_frames,
)
detected_language_info = {}
while seek < end_frames:
segment = features[
:, seek : seek + self.feature_extractor.nb_max_frames
]
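The rewritten block derives the detection window from `clip_timestamps` instead of always scanning from frame 0: `seek` jumps to the requested start (falling back to 0 when it lies past `content_frames`), and the loop stops at `end_frames` rather than at a multiple of `nb_max_frames` from the file start. A self-contained sketch of the arithmetic, assuming Whisper's usual constants (100 feature frames per second, 3000 frames per 30 s segment):

```python
FRAMES_PER_SECOND = 100   # assumed Whisper feature rate
NB_MAX_FRAMES = 3000      # assumed frames per 30 s segment

def detection_window(total_frames: int, start_timestamp: float, segments: int):
    # frames available once the final padded segment is excluded
    content_frames = total_frames - NB_MAX_FRAMES
    seek = (
        int(start_timestamp * FRAMES_PER_SECOND)
        if start_timestamp * FRAMES_PER_SECOND < content_frames
        else 0
    )
    end_frames = min(seek + NB_MAX_FRAMES * segments, content_frames)
    return seek, end_frames

# 60 s of features (6000 frames), detection requested from t=10 s, one segment:
print(detection_window(6000, 10.0, 1))  # -> (1000, 3000)
```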
@@ -1354,15 +1373,17 @@ def is_segment_anomaly(segment: Optional[dict]) -> bool:
def next_words_segment(segments: List[dict]) -> Optional[dict]:
return next((s for s in segments if s["words"]), None)

current_segments, seek, single_timestamp_ending = (
self._split_segments_by_timestamps(
tokenizer=tokenizer,
tokens=tokens,
time_offset=time_offset,
segment_size=segment_size,
segment_duration=segment_duration,
seek=seek,
)
(
current_segments,
seek,
single_timestamp_ending,
) = self._split_segments_by_timestamps(
tokenizer=tokenizer,
tokens=tokens,
time_offset=time_offset,
segment_size=segment_size,
segment_duration=segment_duration,
seek=seek,
)

if options.word_timestamps:
@@ -1375,7 +1396,6 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
options.append_punctuations,
last_speech_timestamp=last_speech_timestamp,
)

if not single_timestamp_ending:
last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
@@ -1432,7 +1452,6 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
last_word_end = get_end(current_segments)
if last_word_end is not None:
last_speech_timestamp = last_word_end

for segment in current_segments:
tokens = segment["tokens"]
text = tokenizer.decode(tokens)
@@ -1900,6 +1919,21 @@ def generate_segment_batched(

return encoder_output, output

def detect_language_function(self, audio: torch.Tensor):
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
segment = self.feature_extractor(audio, padding=True, to_cpu=to_cpu)[
:, : self.feature_extractor.nb_max_frames
]
encoder_output = self.encode(segment)
results = self.model.detect_language(encoder_output)
language_token, language_probability = results[0][0]
language = language_token[2:-2]
self.logger.info(
f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio..."
)
all_language_probs = [(token[2:-2], prob) for (token, prob) in results[0]]
return language, language_probability, all_language_probs
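`detect_language_function` keeps ctranslate2's raw result shape, a list of `("<|xx|>", prob)` pairs per batch item, and strips the `<|...|>` token markup via `token[2:-2]`. That post-processing is plain Python and can be checked in isolation (the `results` literal below is made up for illustration):

```python
# mimics ctranslate2's detect_language output for one audio segment
results = [[("<|en|>", 0.93), ("<|de|>", 0.04), ("<|fr|>", 0.02)]]

all_language_probs = [(token[2:-2], prob) for (token, prob) in results[0]]
language, language_probability = max(all_language_probs, key=lambda p: p[1])
print(language, language_probability)  # en 0.93
```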

def detect_language(self, audio: torch.Tensor):
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
segment = self.feature_extractor(audio, padding=True, to_cpu=to_cpu)[
@@ -2123,38 +2157,6 @@ def key_func(language):
return {"language_code": None, "language_confidence": 1.0}


default_batched_asr_options = {
"beam_size": 5,
"best_of": 5,
"patience": 1,
"length_penalty": 1,
"repetition_penalty": 1,
"no_repeat_ngram_size": 0,
"temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
"compression_ratio_threshold": 2.4,
"log_prob_threshold": -1.0,
"no_speech_threshold": 0.6,
"condition_on_previous_text": False,
"prompt_reset_on_temperature": 0.5,
"initial_prompt": None,
"prefix": None,
"suppress_blank": True,
"suppress_tokens": [-1],
"max_new_tokens": None,
"clip_timestamps": "0",
"hallucination_silence_threshold": None,
"without_timestamps": True, # False for timings
"max_initial_timestamp": 0.0,
"word_timestamps": False,
"prepend_punctuations": "\"'“¿([{-",
"append_punctuations": "\"'.。,,!!??::”)]}、",
"log_prob_low_threshold": None,
"multilingual": False,
"output_language": "en",
"hotwords": None,
}


def restore_speech_timestamps(
segments: Iterable[Segment],
speech_chunks: List[dict],
@@ -2238,9 +2240,11 @@ def merge_punctuations(alignment: List[dict], prepended: str, appended: str) ->
if previous["word"].startswith(" ") and previous["word"].strip() in prepended:
# prepend it to the following word
following["word"] = previous["word"] + following["word"]
following["tokens"] = previous["tokens"] + following["tokens"]
if "tokens" in alignment[0].keys():
following["tokens"] = previous["tokens"] + following["tokens"]
previous["tokens"] = []
previous["word"] = ""
previous["tokens"] = []

else:
j = i
i -= 1
@@ -2254,9 +2258,11 @@ def merge_punctuations(alignment: List[dict], prepended: str, appended: str) ->
if not previous["word"].endswith(" ") and following["word"] in appended:
# append it to the previous word
previous["word"] = previous["word"] + following["word"]
previous["tokens"] = previous["tokens"] + following["tokens"]
if "tokens" in alignment[0].keys():
previous["tokens"] = previous["tokens"] + following["tokens"]
following["tokens"] = []
following["word"] = ""
following["tokens"] = []

else:
i = j
j += 1
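The `"tokens" in alignment[0].keys()` guard lets `merge_punctuations` accept word-only alignments that carry no token ids. A reduced, runnable sketch of the prepend branch with such an input:

```python
prepended = "\"'“¿([{-"
alignment = [{"word": " ¿"}, {"word": " Hola"}]  # no "tokens" key
previous, following = alignment[0], alignment[1]

if previous["word"].startswith(" ") and previous["word"].strip() in prepended:
    following["word"] = previous["word"] + following["word"]
    if "tokens" in alignment[0]:  # skipped for word-only alignments
        following["tokens"] = previous["tokens"] + following["tokens"]
        previous["tokens"] = []
    previous["word"] = ""

print(following["word"])  # " ¿ Hola"
```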
39 changes: 13 additions & 26 deletions faster_whisper/vad.py
@@ -1,7 +1,6 @@
import bisect
import functools
import os
import warnings

from collections.abc import Callable
from typing import List, NamedTuple, Optional, Union
@@ -32,17 +31,13 @@ class VadOptions(NamedTuple):
split aggressively just before max_speech_duration_s.
min_silence_duration_ms: At the end of each speech chunk, wait for min_silence_duration_ms
before separating it.
window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
Values other than these may affect model performance!!
speech_pad_ms: Final speech chunks are padded by speech_pad_ms on each side.
"""

threshold: float = 0.5
min_speech_duration_ms: int = 250
max_speech_duration_s: float = float("inf")
min_silence_duration_ms: int = 2000
window_size_samples: int = 1024
speech_pad_ms: int = 400


@@ -68,15 +63,8 @@ def get_speech_timestamps(
min_speech_duration_ms = vad_options.min_speech_duration_ms
max_speech_duration_s = vad_options.max_speech_duration_s
min_silence_duration_ms = vad_options.min_silence_duration_ms
window_size_samples = vad_options.window_size_samples
window_size_samples = 512
speech_pad_ms = vad_options.speech_pad_ms

if window_size_samples not in [512, 1024, 1536]:
warnings.warn(
"Unusual window_size_samples! Supported window_size_samples:\n"
" - [512, 1024, 1536] for 16000 sampling_rate"
)

sampling_rate = 16000
min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
speech_pad_samples = sampling_rate * speech_pad_ms / 1000
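Pinning the window to 512 samples means each probability covers exactly 32 ms at the fixed 16 kHz rate, presumably the only window size the updated silero_vad.onnx was exported for, which is why the option and its warning are removed. A quick check of the arithmetic:

```python
import math

SAMPLING_RATE = 16000
WINDOW_SIZE_SAMPLES = 512

print(WINDOW_SIZE_SAMPLES / SAMPLING_RATE * 1000)  # 32.0 ms per VAD window
# a 10 s clip yields ceil(160000 / 512) = 313 speech probabilities
print(math.ceil(10 * SAMPLING_RATE / WINDOW_SIZE_SAMPLES))  # 313
```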
@@ -91,14 +79,14 @@ def __init__(self, path):
audio_length_samples = len(audio)

model = get_vad_model()
state = model.get_initial_state(batch_size=1)
state, context = model.get_initial_states(batch_size=1)

speech_probs = []
for current_start_sample in range(0, audio_length_samples, window_size_samples):
chunk = audio[current_start_sample : current_start_sample + window_size_samples]
if len(chunk) < window_size_samples:
chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
speech_prob, state = model(chunk, state, sampling_rate)
speech_prob, state, context = model(chunk, state, context, sampling_rate)
speech_probs.append(speech_prob)

triggered = False
@@ -268,12 +256,12 @@ def __init__(self, path):
sess_options=opts,
)

def get_initial_state(self, batch_size: int):
h = np.zeros((2, batch_size, 64), dtype=np.float32)
c = np.zeros((2, batch_size, 64), dtype=np.float32)
return h, c
def get_initial_states(self, batch_size: int):
state = np.zeros((2, batch_size, 128), dtype=np.float32)
context = np.zeros((batch_size, 64), dtype=np.float32)
return state, context

def __call__(self, x, state, sr: int):
def __call__(self, x, state, context, sr: int):
if len(x.shape) == 1:
x = np.expand_dims(x, 0)
if len(x.shape) > 2:
@@ -283,19 +271,18 @@ def __call__(self, x, state, sr: int):
if sr / x.shape[1] > 31.25:
raise ValueError("Input audio chunk is too short")

h, c = state
x = np.concatenate([context, x], axis=1)

ort_inputs = {
"input": x,
"h": h,
"c": c,
"state": state,
"sr": np.array(sr, dtype="int64"),
}

out, h, c = self.session.run(None, ort_inputs)
state = (h, c)
out, state = self.session.run(None, ort_inputs)
context = x[..., -64:]

return out, state
return out, state, context
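Taken together, the vad.py changes track the new ONNX signature: one merged recurrent `state` of shape `(2, batch, 128)` replaces the old `h`/`c` pair, and a 64-sample `context` tail from the previous chunk is prepended to each 512-sample window. Driving the model directly would look roughly like this sketch, which mirrors the `get_speech_timestamps` loop above (assumes the bundled VAD asset is present):

```python
import numpy as np
from faster_whisper.vad import get_vad_model

audio = np.zeros(16000, dtype=np.float32)  # 1 s of 16 kHz silence
model = get_vad_model()
state, context = model.get_initial_states(batch_size=1)

speech_probs = []
for start in range(0, len(audio), 512):
    chunk = audio[start : start + 512]
    if len(chunk) < 512:
        chunk = np.pad(chunk, (0, 512 - len(chunk)))  # pad the final window
    prob, state, context = model(chunk, state, context, 16000)
    speech_probs.append(prob)
```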


# The code below is copied from whisper-x (https://github.com/m-bain/whisperX)
2 changes: 1 addition & 1 deletion faster_whisper/version.py
@@ -1,3 +1,3 @@
"""Version information."""

__version__ = "1.0.2"
__version__ = "1.0.3"