diff --git a/.gitignore b/.gitignore
index 540c1326..f32bb39e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,8 @@
 whisperx.egg-info/
 **/__pycache__/
 .ipynb_checkpoints
+
+# Virtual environment
+venv/
+env/
+.venv/
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 865abd1f..433adfc8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 torch>=2
 torchaudio>=2
-faster-whisper==1.0.0
+faster-whisper==1.0.3
 transformers
 pandas
 setuptools>=65
diff --git a/whisperx/asr.py b/whisperx/asr.py
index 0ccaf92b..43456129 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -13,6 +13,7 @@
 from .vad import load_vad_model, merge_chunks
 from .types import TranscriptionResult, SingleSegment
 
+
 def find_numeral_symbol_tokens(tokenizer):
     numeral_symbol_tokens = []
     for i in range(tokenizer.eot):
@@ -22,13 +23,20 @@ def find_numeral_symbol_tokens(tokenizer):
             numeral_symbol_tokens.append(i)
     return numeral_symbol_tokens
 
+
 class WhisperModel(faster_whisper.WhisperModel):
-    '''
+    """
     FasterWhisperModel provides batched inference for faster-whisper.
     Currently only works in non-timestamp mode and fixed prompt for all samples in batch.
-    '''
+    """
 
-    def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisper.tokenizer.Tokenizer, options: faster_whisper.transcribe.TranscriptionOptions, encoder_output = None):
+    def generate_segment_batched(
+        self,
+        features: np.ndarray,
+        tokenizer: faster_whisper.tokenizer.Tokenizer,
+        options: faster_whisper.transcribe.TranscriptionOptions,
+        encoder_output=None,
+    ):
         batch_size = features.shape[0]
         all_tokens = []
         prompt_reset_since = 0
@@ -51,15 +59,15 @@ def generate_segment_batched(self, features: np.ndarray, tokenizer: faster_whisp
             )
 
         result = self.model.generate(
-                encoder_output,
-                [prompt] * batch_size,
-                beam_size=options.beam_size,
-                patience=options.patience,
-                length_penalty=options.length_penalty,
-                max_length=self.max_length,
-                suppress_blank=options.suppress_blank,
-                suppress_tokens=options.suppress_tokens,
-            )
+            encoder_output,
+            [prompt] * batch_size,
+            beam_size=options.beam_size,
+            patience=options.patience,
+            length_penalty=options.length_penalty,
+            max_length=self.max_length,
+            suppress_blank=options.suppress_blank,
+            suppress_tokens=options.suppress_tokens,
+        )
 
         tokens_batch = [x.sequences_ids[0] for x in result]
 
@@ -85,26 +93,28 @@ def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
         return self.model.encode(features, to_cpu=to_cpu)
 
 
+
 class FasterWhisperPipeline(Pipeline):
     """
     Huggingface Pipeline wrapper for FasterWhisperModel.
     """
+
     # TODO:
     # - add support for timestamp mode
     # - add support for custom inference kwargs
 
     def __init__(
-            self,
-            model,
-            vad,
-            vad_params: dict,
-            options : NamedTuple,
-            tokenizer=None,
-            device: Union[int, str, "torch.device"] = -1,
-            framework = "pt",
-            language : Optional[str] = None,
-            suppress_numerals: bool = False,
-            **kwargs
+        self,
+        model,
+        vad,
+        vad_params: dict,
+        options: NamedTuple,
+        tokenizer=None,
+        device: Union[int, str, "torch.device"] = -1,
+        framework="pt",
+        language: Optional[str] = None,
+        suppress_numerals: bool = False,
+        **kwargs,
     ):
         self.model = model
         self.tokenizer = tokenizer
@@ -113,7 +123,9 @@ def __init__(
         self.suppress_numerals = suppress_numerals
         self._batch_size = kwargs.pop("batch_size", None)
         self._num_workers = 1
-        self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
+        self._preprocess_params, self._forward_params, self._postprocess_params = (
+            self._sanitize_parameters(**kwargs)
+        )
         self.call_count = 0
         self.framework = framework
         if self.framework == "pt":
@@ -139,24 +151,32 @@ def _sanitize_parameters(self, **kwargs):
         return preprocess_kwargs, {}, {}
 
     def preprocess(self, audio):
-        audio = audio['inputs']
+        audio = audio["inputs"]
         model_n_mels = self.model.feat_kwargs.get("feature_size")
         features = log_mel_spectrogram(
             audio,
             n_mels=model_n_mels if model_n_mels is not None else 80,
             padding=N_SAMPLES - audio.shape[0],
         )
-        return {'inputs': features}
+        return {"inputs": features}
 
     def _forward(self, model_inputs):
-        outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
-        return {'text': outputs}
+        outputs = self.model.generate_segment_batched(
+            model_inputs["inputs"], self.tokenizer, self.options
+        )
+        return {"text": outputs}
 
     def postprocess(self, model_outputs):
         return model_outputs
 
     def get_iterator(
-        self, inputs, num_workers: int, batch_size: int, preprocess_params, forward_params, postprocess_params
+        self,
+        inputs,
+        num_workers: int,
+        batch_size: int,
+        preprocess_params,
+        forward_params,
+        postprocess_params,
     ):
         dataset = PipelineIterator(inputs, self.preprocess, preprocess_params)
         if "TOKENIZERS_PARALLELISM" not in os.environ:
@@ -164,26 +184,46 @@ def get_iterator(
 
         # TODO hack by collating feature_extractor and image_processor
         def stack(items):
-            return {'inputs': torch.stack([x['inputs'] for x in items])}
-        dataloader = torch.utils.data.DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack)
-        model_iterator = PipelineIterator(dataloader, self.forward, forward_params, loader_batch_size=batch_size)
-        final_iterator = PipelineIterator(model_iterator, self.postprocess, postprocess_params)
+            return {"inputs": torch.stack([x["inputs"] for x in items])}
+
+        dataloader = torch.utils.data.DataLoader(
+            dataset, num_workers=num_workers, batch_size=batch_size, collate_fn=stack
+        )
+        model_iterator = PipelineIterator(
+            dataloader, self.forward, forward_params, loader_batch_size=batch_size
+        )
+        final_iterator = PipelineIterator(
+            model_iterator, self.postprocess, postprocess_params
+        )
         return final_iterator
 
     def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False
+        self,
+        audio: Union[str, np.ndarray],
+        batch_size=None,
+        num_workers=0,
+        language=None,
+        task=None,
+        chunk_size=30,
+        print_progress=False,
+        combined_progress=False,
     ) -> TranscriptionResult:
         if isinstance(audio, str):
             audio = load_audio(audio)
 
         def data(audio, segments):
             for seg in segments:
-                f1 = int(seg['start'] * SAMPLE_RATE)
-                f2 = int(seg['end'] * SAMPLE_RATE)
+                f1 = int(seg["start"] * SAMPLE_RATE)
+                f2 = int(seg["end"] * SAMPLE_RATE)
                 # print(f2-f1)
-                yield {'inputs': audio[f1:f2]}
+                yield {"inputs": audio[f1:f2]}
 
-        vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
+        vad_segments = self.vad_model(
+            {
+                "waveform": torch.from_numpy(audio).unsqueeze(0),
+                "sample_rate": SAMPLE_RATE,
+            }
+        )
         vad_segments = merge_chunks(
             vad_segments,
             chunk_size,
@@ -193,17 +233,23 @@ def data(audio, segments):
         if self.tokenizer is None:
             language = language or self.detect_language(audio)
             task = task or "transcribe"
-            self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
-                                                                self.model.model.is_multilingual, task=task,
-                                                                language=language)
+            self.tokenizer = faster_whisper.tokenizer.Tokenizer(
+                self.model.hf_tokenizer,
+                self.model.model.is_multilingual,
+                task=task,
+                language=language,
+            )
         else:
             language = language or self.tokenizer.language_code
             task = task or self.tokenizer.task
             if task != self.tokenizer.task or language != self.tokenizer.language_code:
-                self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
-                                                                    self.model.model.is_multilingual, task=task,
-                                                                    language=language)
-
+                self.tokenizer = faster_whisper.tokenizer.Tokenizer(
+                    self.model.hf_tokenizer,
+                    self.model.model.is_multilingual,
+                    task=task,
+                    language=language,
+                )
+
         if self.suppress_numerals:
             previous_suppress_tokens = self.options.suppress_tokens
             numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
@@ -215,19 +261,27 @@ def data(audio, segments):
         segments: List[SingleSegment] = []
         batch_size = batch_size or self._batch_size
         total_segments = len(vad_segments)
-        for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
+        for idx, out in enumerate(
+            self.__call__(
+                data(audio, vad_segments),
+                batch_size=batch_size,
+                num_workers=num_workers,
+            )
+        ):
             if print_progress:
                 base_progress = ((idx + 1) / total_segments) * 100
-                percent_complete = base_progress / 2 if combined_progress else base_progress
+                percent_complete = (
+                    base_progress / 2 if combined_progress else base_progress
+                )
                 print(f"Progress: {percent_complete:.2f}%...")
-            text = out['text']
+            text = out["text"]
             if batch_size in [0, 1, None]:
                 text = text[0]
             segments.append(
                 {
                     "text": text,
-                    "start": round(vad_segments[idx]['start'], 3),
-                    "end": round(vad_segments[idx]['end'], 3)
+                    "start": round(vad_segments[idx]["start"], 3),
+                    "end": round(vad_segments[idx]["end"], 3),
                 }
             )
 
@@ -237,38 +291,48 @@ def data(audio, segments):
 
         # revert suppressed tokens if suppress_numerals is enabled
         if self.suppress_numerals:
-            self.options = self.options._replace(suppress_tokens=previous_suppress_tokens)
+            self.options = self.options._replace(
+                suppress_tokens=previous_suppress_tokens
+            )
 
         return {"segments": segments, "language": language}
-
 
     def detect_language(self, audio: np.ndarray):
         if audio.shape[0] < N_SAMPLES:
-            print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
+            print(
+                "Warning: audio is shorter than 30s, language detection may be inaccurate."
+            )
         model_n_mels = self.model.feat_kwargs.get("feature_size")
-        segment = log_mel_spectrogram(audio[: N_SAMPLES],
-                                      n_mels=model_n_mels if model_n_mels is not None else 80,
-                                      padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
+        segment = log_mel_spectrogram(
+            audio[:N_SAMPLES],
+            n_mels=model_n_mels if model_n_mels is not None else 80,
+            padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0],
+        )
         encoder_output = self.model.encode(segment)
         results = self.model.model.detect_language(encoder_output)
         language_token, language_probability = results[0][0]
         language = language_token[2:-2]
-        print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
+        print(
+            f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio..."
+        )
         return language
 
-def load_model(whisper_arch,
-               device,
-               device_index=0,
-               compute_type="float16",
-               asr_options=None,
-               language : Optional[str] = None,
-               vad_model=None,
-               vad_options=None,
-               model : Optional[WhisperModel] = None,
-               task="transcribe",
-               download_root=None,
-               threads=4):
-    '''Load a Whisper model for inference.
+
+def load_model(
+    whisper_arch,
+    device,
+    device_index=0,
+    compute_type="float16",
+    asr_options=None,
+    language: Optional[str] = None,
+    vad_model=None,
+    vad_options=None,
+    model: Optional[WhisperModel] = None,
+    task="transcribe",
+    download_root=None,
+    threads=4,
+):
+    """Load a Whisper model for inference.
     Args:
         whisper_arch: str - The name of the Whisper model to load.
         device: str - The device to load the model on.
@@ -280,24 +344,33 @@ def load_model(whisper_arch,
         threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
     Returns:
         A Whisper pipeline.
-    '''
+    """
 
     if whisper_arch.endswith(".en"):
         language = "en"
 
-    model = model or WhisperModel(whisper_arch,
-                                  device=device,
-                                  device_index=device_index,
-                                  compute_type=compute_type,
-                                  download_root=download_root,
-                                  cpu_threads=threads)
+    model = model or WhisperModel(
+        whisper_arch,
+        device=device,
+        device_index=device_index,
+        compute_type=compute_type,
+        download_root=download_root,
+        cpu_threads=threads,
+    )
     if language is not None:
-        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
+        tokenizer = faster_whisper.tokenizer.Tokenizer(
+            model.hf_tokenizer,
+            model.model.is_multilingual,
+            task=task,
+            language=language,
+        )
     else:
-        print("No language specified, language will be first be detected for each audio file (increases inference time).")
+        print(
+            "No language specified, language will be first be detected for each audio file (increases inference time)."
+        )
         tokenizer = None
 
-    default_asr_options =  {
+    default_asr_options = {
        "beam_size": 5,
        "best_of": 5,
        "patience": 1,
@@ -323,6 +396,7 @@ def load_model(whisper_arch,
         "max_new_tokens": None,
         "clip_timestamps": None,
         "hallucination_silence_threshold": None,
+        "hotwords": None,
     }
 
     if asr_options is not None:
@@ -331,12 +405,11 @@ def load_model(whisper_arch,
     suppress_numerals = default_asr_options["suppress_numerals"]
     del default_asr_options["suppress_numerals"]
 
-    default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)
+    default_asr_options = faster_whisper.transcribe.TranscriptionOptions(
+        **default_asr_options
+    )
 
-    default_vad_options = {
-        "vad_onset": 0.500,
-        "vad_offset": 0.363
-    }
+    default_vad_options = {"vad_onset": 0.500, "vad_offset": 0.363}
 
     if vad_options is not None:
         default_vad_options.update(vad_options)
@@ -344,7 +417,9 @@ def load_model(whisper_arch,
     if vad_model is not None:
         vad_model = vad_model
     else:
-        vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
+        vad_model = load_vad_model(
+            torch.device(device), use_auth_token=None, **default_vad_options
+        )
 
     return FasterWhisperPipeline(
         model=model,