Skip to content

Commit

Permalink
add always_use_initial_prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
mercury233 committed Mar 7, 2023
1 parent 8180fde commit bd54b68
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def transcribe(
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
always_use_initial_prompt: bool = False,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,,!!??::”)]}、",
Expand Down Expand Up @@ -97,6 +98,11 @@ def transcribe(
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those words correctly.
always_use_initial_prompt: bool
if True, the initial_prompt will be used for all windows, and condition_on_previous_text
will be ignored. Enabling this may make the text more consistent if the audio is long
and you set the initial_prompt properly.
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
Expand Down Expand Up @@ -223,7 +229,11 @@ def new_segment(
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

decode_options["prompt"] = all_tokens[prompt_reset_since:]
if always_use_initial_prompt:
decode_options["prompt"] = initial_prompt_tokens
else:
decode_options["prompt"] = all_tokens[prompt_reset_since:]

result: DecodingResult = decode_with_fallback(mel_segment)
tokens = torch.tensor(result.tokens)

Expand Down Expand Up @@ -390,6 +400,7 @@ def cli():
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--always_use_initial_prompt", type=str2bool, default=False, help="if True, the initial_prompt will be used for all windows, and condition_on_previous_text will be ignored. Enabling this may make the text more consistent if the audio is long and you set the initial_prompt properly.")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")

parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
Expand Down

0 comments on commit bd54b68

Please sign in to comment.