Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add always_use_initial_prompt #1040

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def transcribe(
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
always_use_initial_prompt: bool = False,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,,!!??::”)]}、",
Expand Down Expand Up @@ -102,6 +103,11 @@ def transcribe(
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those word correctly.

always_use_initial_prompt: bool
if True, the initial_prompt will be used to all windows, and condition_on_previous_text
will be ignored. Enabling this may make the text more consistent if the audio is long
and you set the initial_prompt properly.

decode_options: dict
Keyword arguments to construct `DecodingOptions` instances

Expand Down Expand Up @@ -275,7 +281,11 @@ def new_segment(
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

decode_options["prompt"] = all_tokens[prompt_reset_since:]
if always_use_initial_prompt:
decode_options["prompt"] = initial_prompt_tokens
else:
decode_options["prompt"] = all_tokens[prompt_reset_since:]

result: DecodingResult = decode_with_fallback(mel_segment)
tokens = torch.tensor(result.tokens)

Expand Down Expand Up @@ -530,6 +540,7 @@ def valid_model_name(name):
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--always_use_initial_prompt", type=str2bool, default=False, help="if True, the initial_prompt will be used to all windows, and condition_on_previous_text will be ignored. Enabling this may make the text more consistent if the audio is long and you set the initial_prompt properly.")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")

parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
Expand Down