diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 20f01477e..034e86aa0 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -43,6 +43,7 @@ def transcribe( no_speech_threshold: Optional[float] = 0.6, condition_on_previous_text: bool = True, initial_prompt: Optional[str] = None, + always_use_initial_prompt: bool = False, word_timestamps: bool = False, prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", @@ -97,6 +98,11 @@ def transcribe( "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those word correctly. + always_use_initial_prompt: bool + if True, the initial_prompt will be used to all windows, and condition_on_previous_text + will be ignored. Enabling this may make the text more consistent if the audio is long + and you set the initial_prompt properly. + decode_options: dict Keyword arguments to construct `DecodingOptions` instances @@ -223,7 +229,11 @@ def new_segment( segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) - decode_options["prompt"] = all_tokens[prompt_reset_since:] + if always_use_initial_prompt: + decode_options["prompt"] = initial_prompt_tokens + else: + decode_options["prompt"] = all_tokens[prompt_reset_since:] + result: DecodingResult = decode_with_fallback(mel_segment) tokens = torch.tensor(result.tokens) @@ -390,6 +400,7 @@ def cli(): parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") + parser.add_argument("--always_use_initial_prompt", type=str2bool, default=False, help="if True, the initial_prompt will be used to all windows, and condition_on_previous_text will be ignored. Enabling this may make the text more consistent if if the audio is long and you set the initial_prompt properly.") parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")