Skip to content

Commit

Permalink
Merge pull request m-bain#584 from DougTrajano/patch-1
Browse files Browse the repository at this point in the history
Move load_model after WhisperModel
  • Loading branch information
m-bain authored Nov 16, 2023
2 parents f5c544f + bd3aa03 commit ba30365
Showing 1 changed file with 93 additions and 93 deletions.
186 changes: 93 additions & 93 deletions whisperx/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,99 +22,6 @@ def find_numeral_symbol_tokens(tokenizer):
numeral_symbol_tokens.append(i)
return numeral_symbol_tokens

def load_model(whisper_arch,
               device,
               device_index=0,
               compute_type="float16",
               asr_options=None,
               language : Optional[str] = None,
               vad_options=None,
               model : Optional[WhisperModel] = None,
               task="transcribe",
               download_root=None,
               threads=4):
    '''Load a Whisper model for inference.

    Args:
        whisper_arch: str - The name of the Whisper model to load.
        device: str - The device to load the model on.
        device_index: int - Index of the device to load the model on.
        compute_type: str - The compute type to use for the model.
        asr_options: dict - Overrides merged into the default ASR options.
        language: str - The language of the model; forced to "en" for ".en"
            checkpoints. If None, the language is detected per audio file.
        vad_options: dict - Overrides merged into the default VAD options.
        model: Optional[WhisperModel] - An existing WhisperModel instance to
            reuse instead of loading a new one.
        task: str - The task to run, e.g. "transcribe".
        download_root: Optional[str] - The root directory to download the model to.
        threads: int - The number of cpu threads to use per worker, e.g. will be
            multiplied by num workers.
    Returns:
        A Whisper pipeline (FasterWhisperPipeline).
    '''

    # English-only checkpoints (names ending in ".en") only support English.
    if whisper_arch.endswith(".en"):
        language = "en"

    model = model or WhisperModel(whisper_arch,
                                  device=device,
                                  device_index=device_index,
                                  compute_type=compute_type,
                                  download_root=download_root,
                                  cpu_threads=threads)
    if language is not None:
        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
    else:
        # Fixed wording of the warning ("will be first be" -> "will first be").
        print("No language specified, language will first be detected for each audio file (increases inference time).")
        tokenizer = None

    default_asr_options = {
        "beam_size": 5,
        "best_of": 5,
        "patience": 1,
        "length_penalty": 1,
        "repetition_penalty": 1,
        "no_repeat_ngram_size": 0,
        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        "compression_ratio_threshold": 2.4,
        "log_prob_threshold": -1.0,
        "no_speech_threshold": 0.6,
        "condition_on_previous_text": False,
        "prompt_reset_on_temperature": 0.5,
        "initial_prompt": None,
        "prefix": None,
        "suppress_blank": True,
        "suppress_tokens": [-1],
        "without_timestamps": True,
        "max_initial_timestamp": 0.0,
        "word_timestamps": False,
        "prepend_punctuations": "\"'“¿([{-",
        "append_punctuations": "\"'.。,,!!??::”)]}、",
        "suppress_numerals": False,
    }

    if asr_options is not None:
        default_asr_options.update(asr_options)

    # "suppress_numerals" is a whisperx extension, not a faster-whisper
    # TranscriptionOptions field, so pop it off before constructing one.
    suppress_numerals = default_asr_options.pop("suppress_numerals")

    default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)

    default_vad_options = {
        "vad_onset": 0.500,
        "vad_offset": 0.363
    }

    if vad_options is not None:
        default_vad_options.update(vad_options)

    vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)

    return FasterWhisperPipeline(
        model=model,
        vad=vad_model,
        options=default_asr_options,
        tokenizer=tokenizer,
        language=language,
        suppress_numerals=suppress_numerals,
        vad_params=default_vad_options,
    )

class WhisperModel(faster_whisper.WhisperModel):
'''
FasterWhisperModel provides batched inference for faster-whisper.
Expand Down Expand Up @@ -341,3 +248,96 @@ def detect_language(self, audio: np.ndarray):
language = language_token[2:-2]
print(f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio...")
return language

def load_model(whisper_arch,
               device,
               device_index=0,
               compute_type="float16",
               asr_options=None,
               language : Optional[str] = None,
               vad_options=None,
               model : Optional[WhisperModel] = None,
               task="transcribe",
               download_root=None,
               threads=4):
    '''Load a Whisper model for inference.

    Args:
        whisper_arch: str - The name of the Whisper model to load.
        device: str - The device to load the model on.
        device_index: int - Index of the device to load the model on.
        compute_type: str - The compute type to use for the model.
        asr_options: dict - Overrides merged into the default ASR options.
        language: str - The language of the model; forced to "en" for ".en"
            checkpoints. If None, the language is detected per audio file.
        vad_options: dict - Overrides merged into the default VAD options.
        model: Optional[WhisperModel] - An existing WhisperModel instance to
            reuse instead of loading a new one.
        task: str - The task to run, e.g. "transcribe".
        download_root: Optional[str] - The root directory to download the model to.
        threads: int - The number of cpu threads to use per worker, e.g. will be
            multiplied by num workers.
    Returns:
        A Whisper pipeline (FasterWhisperPipeline).
    '''

    # English-only checkpoints (names ending in ".en") only support English.
    if whisper_arch.endswith(".en"):
        language = "en"

    model = model or WhisperModel(whisper_arch,
                                  device=device,
                                  device_index=device_index,
                                  compute_type=compute_type,
                                  download_root=download_root,
                                  cpu_threads=threads)
    if language is not None:
        tokenizer = faster_whisper.tokenizer.Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
    else:
        # Fixed wording of the warning ("will be first be" -> "will first be").
        print("No language specified, language will first be detected for each audio file (increases inference time).")
        tokenizer = None

    default_asr_options = {
        "beam_size": 5,
        "best_of": 5,
        "patience": 1,
        "length_penalty": 1,
        "repetition_penalty": 1,
        "no_repeat_ngram_size": 0,
        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
        "compression_ratio_threshold": 2.4,
        "log_prob_threshold": -1.0,
        "no_speech_threshold": 0.6,
        "condition_on_previous_text": False,
        "prompt_reset_on_temperature": 0.5,
        "initial_prompt": None,
        "prefix": None,
        "suppress_blank": True,
        "suppress_tokens": [-1],
        "without_timestamps": True,
        "max_initial_timestamp": 0.0,
        "word_timestamps": False,
        "prepend_punctuations": "\"'“¿([{-",
        "append_punctuations": "\"'.。,,!!??::”)]}、",
        "suppress_numerals": False,
    }

    if asr_options is not None:
        default_asr_options.update(asr_options)

    # "suppress_numerals" is a whisperx extension, not a faster-whisper
    # TranscriptionOptions field, so pop it off before constructing one.
    suppress_numerals = default_asr_options.pop("suppress_numerals")

    default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options)

    default_vad_options = {
        "vad_onset": 0.500,
        "vad_offset": 0.363
    }

    if vad_options is not None:
        default_vad_options.update(vad_options)

    vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)

    return FasterWhisperPipeline(
        model=model,
        vad=vad_model,
        options=default_asr_options,
        tokenizer=tokenizer,
        language=language,
        suppress_numerals=suppress_numerals,
        vad_params=default_vad_options,
    )

0 comments on commit ba30365

Please sign in to comment.