diff --git a/whisper/timing.py b/whisper/timing.py index b695ead0..e7604fa7 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -2,7 +2,7 @@ import subprocess import warnings from dataclasses import dataclass -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional, Callable import numba import numpy as np @@ -284,6 +284,7 @@ def add_word_timestamps( prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", last_speech_timestamp: float, + word_stream_callback: Optional[Callable] = None, **kwargs, ): if len(segments) == 0: @@ -327,6 +328,8 @@ def add_word_timestamps( timing = alignment[word_index] if timing.word: + if word_stream_callback is not None: + word_stream_callback(timing) words.append( dict( word=timing.word, diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 1c075a20..df063cb5 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -2,7 +2,7 @@ import os import traceback import warnings -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Callable import numpy as np import torch @@ -40,6 +40,7 @@ def transcribe( audio: Union[str, np.ndarray, torch.Tensor], *, verbose: Optional[bool] = None, + word_stream_callback: Optional[Callable] = None, temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), compression_ratio_threshold: Optional[float] = 2.4, logprob_threshold: Optional[float] = -1.0, @@ -68,6 +69,9 @@ def transcribe( Whether to display the text being decoded to the console. If True, displays all the details, If False, displays minimal details. If None, does not display anything + word_stream_callback: Callable + Function that receives ready words as the other voice chunks are in progress. + temperature: Union[float, Tuple[float, ...]] Temperature for sampling. It can be a tuple of temperatures, which will be successively used upon failures according to either `compression_ratio_threshold` or `logprob_threshold`. @@ -392,6 +396,7 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, last_speech_timestamp=last_speech_timestamp, + word_stream_callback=word_stream_callback ) if not single_timestamp_ending: