Skip to content

Commit f19e23f

Browse files
committed
whisper : restore decoder temperature fallbacks
I disabled this because there were many complaints about slow decoding. The current implementation does not allow batching the decoders when using the "best of" or "beam size" parameters, so the decoding time is proportional to the number of decoders, which is obviously not great. However, there are now even more complaints about wrong decodings and repetition. So, as a compromise, the fallbacks are re-enabled, but the defaults are reduced to just 2 "best of" / "beam size" decoders. Also, the temperature step is increased from 0.2 to 0.4, i.e. from a maximum of 5 fallbacks to a maximum of 2. Also, the stream example now has fallbacks enabled by default (pass -nf / --no-fallback to disable them). close #471 #477 #508 #612 #719 #731
1 parent ea1f8a5 commit f19e23f

File tree

3 files changed

+25
-21
lines changed

3 files changed

+25
-21
lines changed

examples/main/main.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ struct whisper_params {
5757
int32_t duration_ms = 0;
5858
int32_t max_context = -1;
5959
int32_t max_len = 0;
60-
int32_t best_of = 5;
60+
int32_t best_of = 2;
6161
int32_t beam_size = -1;
6262

6363
float word_thold = 0.01f;

examples/stream/stream.cpp

+21-17
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ struct whisper_params {
4343

4444
bool speed_up = false;
4545
bool translate = false;
46+
bool no_fallback = false;
4647
bool print_special = false;
4748
bool no_context = true;
4849
bool no_timestamps = false;
@@ -73,6 +74,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
7374
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
7475
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
7576
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
77+
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
7678
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
7779
else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; }
7880
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
@@ -94,22 +96,23 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
9496
fprintf(stderr, "\n");
9597
fprintf(stderr, "options:\n");
9698
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
97-
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
98-
fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
99-
fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
100-
fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
101-
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
102-
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
103-
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
104-
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
105-
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
106-
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
107-
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
108-
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
109-
fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
110-
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
111-
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
112-
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
99+
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
100+
fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
101+
fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
102+
fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
103+
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
104+
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
105+
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
106+
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
107+
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
108+
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
109+
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
110+
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
111+
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
112+
fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
113+
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
114+
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
115+
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
113116
fprintf(stderr, "\n");
114117
}
115118

@@ -297,7 +300,8 @@ int main(int argc, char ** argv) {
297300
wparams.speed_up = params.speed_up;
298301

299302
// disable temperature fallback
300-
wparams.temperature_inc = -1.0f;
303+
//wparams.temperature_inc = -1.0f;
304+
wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
301305

302306
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
303307
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();

whisper.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -3220,7 +3220,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
32203220
/*.max_initial_ts =*/ 1.0f,
32213221
/*.length_penalty =*/ -1.0f,
32223222

3223-
/*.temperature_inc =*/ 0.0f, // TODO: temporary disabled until improve performance
3223+
/*.temperature_inc =*/ 0.4f,
32243224
/*.entropy_thold =*/ 2.4f,
32253225
/*.logprob_thold =*/ -1.0f,
32263226
/*.no_speech_thold =*/ 0.6f,
@@ -3252,13 +3252,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
32523252
case WHISPER_SAMPLING_GREEDY:
32533253
{
32543254
result.greedy = {
3255-
/*.best_of =*/ 1,
3255+
/*.best_of =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
32563256
};
32573257
} break;
32583258
case WHISPER_SAMPLING_BEAM_SEARCH:
32593259
{
32603260
result.beam_search = {
3261-
/*.beam_size =*/ 5,
3261+
/*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
32623262

32633263
/*.patience =*/ -1.0f,
32643264
};

0 commit comments

Comments
 (0)