diff --git a/whisperx/asr.py b/whisperx/asr.py
index ac816ff9..fe3c7566 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -262,7 +262,7 @@ def load_model(whisper_arch,
                compute_type="float16",
                asr_options=None,
                language : Optional[str] = None,
-               vad_model=None,
+               vad_model_fp=None,
                vad_options=None,
                model : Optional[WhisperModel] = None,
                task="transcribe",
@@ -275,6 +275,7 @@ def load_model(whisper_arch,
         compute_type: str - The compute type to use for the model.
         options: dict - A dictionary of options to use for the model.
         language: str - The language of the model. (use English for now)
+        vad_model_fp: str - File path to the VAD model to use
         model: Optional[WhisperModel] - The WhisperModel instance to use.
         download_root: Optional[str] - The root directory to download the model to.
         threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
@@ -338,8 +339,8 @@ def load_model(whisper_arch,
     if vad_options is not None:
         default_vad_options.update(vad_options)

-    if vad_model is not None:
-        vad_model = vad_model
+    if vad_model_fp is not None:
+        vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options, model_fp=vad_model_fp)
     else:
         vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
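
For reference, a minimal usage sketch with this patch applied. The checkpoint and audio paths below are placeholders, not files shipped with whisperx; everything else uses the existing public API (`whisperx.load_model`, `whisperx.load_audio`) plus the `vad_model_fp` keyword introduced here.

```python
import whisperx

# Placeholder path to a locally stored VAD segmentation checkpoint
# (assumption: you have already downloaded one to this location).
vad_checkpoint = "/models/whisperx-vad-segmentation.bin"

# Point the VAD loader at the local checkpoint instead of letting
# load_vad_model fetch/resolve its default model file.
model = whisperx.load_model(
    "large-v2",
    device="cuda",
    compute_type="float16",
    vad_model_fp=vad_checkpoint,
)

audio = whisperx.load_audio("audio.wav")  # placeholder input file
result = model.transcribe(audio, batch_size=16)
print(result["segments"])
```

This keeps the default behaviour unchanged: when `vad_model_fp` is not given, the `else` branch still calls `load_vad_model` without `model_fp`, exactly as before.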