diff --git a/.gitignore b/.gitignore index a1adabaf40b..b30a1d19f01 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ build/ build-em/ build-debug/ build-release/ +build-rwdi/ build-static/ build-cublas/ build-no-accel/ diff --git a/examples/command/command.cpp b/examples/command/command.cpp index 85789d35de6..d5612bd0a00 100644 --- a/examples/command/command.cpp +++ b/examples/command/command.cpp @@ -22,6 +22,11 @@ #include #include +bool file_exists(const std::string & fname) { + std::ifstream f(fname.c_str()); + return f.good(); +} + // command-line parameters struct whisper_params { int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); @@ -31,10 +36,13 @@ struct whisper_params { int32_t max_tokens = 32; int32_t audio_ctx = 0; - float vad_thold = 0.6f; - float freq_thold = 100.0f; + float vad_thold = 0.6f; + float freq_thold = 100.0f; + float grammar_penalty = 100.0f; + grammar_parser::parse_state grammar_parsed; + bool speed_up = false; bool translate = false; bool print_special = false; @@ -46,6 +54,7 @@ struct whisper_params { std::string fname_out; std::string commands; std::string prompt; + std::string context; std::string grammar; }; @@ -76,6 +85,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; } else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; } + else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; } else if ( arg == "--grammar") { params.grammar = argv[++i]; } else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); } else { @@ -111,21 +121,30 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str()); fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str()); + fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str()); fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty); fprintf(stderr, "\n"); } -std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector & pcmf32, float & prob, int64_t & t_ms) { +std::string transcribe( + whisper_context * ctx, + const whisper_params & params, + const std::vector & pcmf32, + const std::string & grammar_rule, + float & logprob_min, + float & logprob_sum, + int & n_tokens, + int64_t & t_ms) { const auto t_start = std::chrono::high_resolution_clock::now(); - prob = 0.0f; + logprob_min = 0.0f; + logprob_sum = 0.0f; + n_tokens = 0; t_ms = 0; - grammar_parser::parse_state parsed_grammar; - std::vector grammar_rules; - - whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + //whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH); wparams.print_progress = false; wparams.print_special = params.print_special; @@ -133,20 +152,30 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con wparams.print_timestamps = !params.no_timestamps; wparams.translate = params.translate; wparams.no_context = true; + wparams.no_timestamps = params.no_timestamps; wparams.single_segment = true; wparams.max_tokens = params.max_tokens; wparams.language = params.language.c_str(); wparams.n_threads = params.n_threads; - wparams.audio_ctx = params.audio_ctx; - wparams.speed_up = params.speed_up; + wparams.audio_ctx = params.audio_ctx; + wparams.speed_up = params.speed_up; - if (!params.grammar.empty()) { - parsed_grammar = grammar_parser::parse(params.grammar.c_str()); - grammar_rules = parsed_grammar.c_rules(); + wparams.temperature = 0.4f; + wparams.temperature_inc = 1.0f; + wparams.greedy.best_of = 5; + + wparams.beam_search.beam_size = 5; + + wparams.initial_prompt = params.context.data(); + + const auto & grammar_parsed = params.grammar_parsed; + auto grammar_rules = grammar_parsed.c_rules(); + + if (!params.grammar_parsed.rules.empty() && !grammar_rule.empty()) { wparams.grammar_rules = grammar_rules.data(); wparams.n_grammar_rules = grammar_rules.size(); - wparams.i_start_rule = parsed_grammar.symbol_ids.at("root"); + wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule); wparams.grammar_penalty = params.grammar_penalty; } @@ -154,7 +183,6 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con return ""; } - int prob_n = 0; std::string result; const int n_segments = whisper_full_n_segments(ctx); @@ -163,19 +191,17 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con result += text; - const int n_tokens = whisper_full_n_tokens(ctx, i); - for (int j = 0; j < n_tokens; ++j) { + const int n = whisper_full_n_tokens(ctx, i); + for (int j = 0; j < n; ++j) { const auto token = whisper_full_get_token_data(ctx, i, j); - prob += token.p; - ++prob_n; + if(token.plog > 0.0f) exit(0); + logprob_min = std::min(logprob_min, token.plog); + logprob_sum += token.plog; + ++n_tokens; } } - if (prob_n > 0) { - prob /= prob_n; - } - const auto t_end = std::chrono::high_resolution_clock::now(); t_ms = std::chrono::duration_cast(t_end - t_start).count(); @@ -266,7 +292,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const fprintf(stderr, " ]\n"); } - std::string k_prompt = "select one from the available words: "; + std::string k_prompt = "select one from the available words: "; for (int i = 0; i < (int) allowed_commands.size(); ++i) { if (i > 0) { k_prompt += ", "; @@ -434,7 +460,9 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi bool is_running = true; bool ask_prompt = true; - float prob = 0.0f; + float logprob_min = 0.0f; + float logprob_sum = 0.0f; + int n_tokens = 0; std::vector pcmf32_cur; @@ -472,7 +500,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi // detect the commands audio.get(params.command_ms, pcmf32_cur); - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", logprob_min, logprob_sum, n_tokens, t_ms)); const auto words = get_words(txt); @@ -508,18 +536,27 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi // general-purpose mode // freely transcribe the voice into text -int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) { +int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) { bool is_running = true; bool have_prompt = false; bool ask_prompt = true; - float prob0 = 0.0f; - float prob = 0.0f; + float logprob_min0 = 0.0f; + float logprob_min = 0.0f; + + float logprob_sum0 = 0.0f; + float logprob_sum = 0.0f; + + int n_tokens0 = 0; + int n_tokens = 0; std::vector pcmf32_cur; std::vector pcmf32_prompt; - const std::string k_prompt = "Ok Whisper, start listening for commands."; + std::string k_prompt = "Ok Whisper, start listening for commands."; + if (!params.prompt.empty()) { + k_prompt = params.prompt; + } fprintf(stderr, "\n"); fprintf(stderr, "%s: general-purpose mode\n", __func__); @@ -552,9 +589,11 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud // wait for activation phrase audio.get(params.prompt_ms, pcmf32_cur); - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms)); + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "prompt", logprob_min0, logprob_sum0, n_tokens0, t_ms)); + + const float p = 100.0f * std::exp(logprob_min0); - fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms); + fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms, p = %.2f%%)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms, p); const float sim = similarity(txt, k_prompt); @@ -575,12 +614,19 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud // we have heard the activation phrase, now detect the commands audio.get(params.command_ms, pcmf32_cur); + //printf("len prompt: %.4f\n", pcmf32_prompt.size() / (float) WHISPER_SAMPLE_RATE); + //printf("len command: %.4f\n", pcmf32_cur.size() / (float) WHISPER_SAMPLE_RATE); + + // prepend 3 second of silence + pcmf32_cur.insert(pcmf32_cur.begin(), 3.0f*WHISPER_SAMPLE_RATE, 0.0f); + // prepend the prompt audio pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); - const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms)); + const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", logprob_min, logprob_sum, n_tokens, t_ms)); - prob = 100.0f*(prob - prob0); + //const float p = 100.0f * std::exp((logprob - logprob0) / (n_tokens - n_tokens0)); + const float p = 100.0f * std::exp(logprob_min); //fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str()); @@ -604,6 +650,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud } } + fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p); if (best_len == 0) { fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__); } else { @@ -678,13 +725,23 @@ int main(int argc, char ** argv) { int ret_val = 0; if (!params.grammar.empty()) { - auto parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + auto & grammar = params.grammar_parsed; + if (file_exists(params.grammar.c_str())) { + // read grammar from file + std::ifstream ifs(params.grammar.c_str()); + const std::string txt = std::string((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); + grammar = grammar_parser::parse(txt.c_str()); + } else { + // read grammar from string + grammar = grammar_parser::parse(params.grammar.c_str()); + } + // will be empty (default) if there are parse errors - if (parsed_grammar.rules.empty()) { + if (grammar.rules.empty()) { ret_val = 1; } else { fprintf(stderr, "%s: grammar:\n", __func__); - grammar_parser::print_grammar(stderr, parsed_grammar); + grammar_parser::print_grammar(stderr, grammar); fprintf(stderr, "\n"); } } @@ -692,7 +749,7 @@ int main(int argc, char ** argv) { if (ret_val == 0) { if (!params.commands.empty()) { ret_val = process_command_list(ctx, audio, params); - } else if (!params.prompt.empty()) { + } else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) { ret_val = always_prompt_transcription(ctx, audio, params); } else { ret_val = process_general_transcription(ctx, audio, params); diff --git a/examples/grammar-parser.cpp b/examples/grammar-parser.cpp index b5b607fa9d0..2daaaef4504 100644 --- a/examples/grammar-parser.cpp +++ b/examples/grammar-parser.cpp @@ -413,7 +413,7 @@ namespace grammar_parser { } } - std::vector parse_state::c_rules() { + std::vector parse_state::c_rules() const{ std::vector ret; for (const auto & rule : rules) { ret.push_back(rule.data()); diff --git a/examples/grammar-parser.h b/examples/grammar-parser.h index ef0ec44174f..47d019c33e1 100644 --- a/examples/grammar-parser.h +++ b/examples/grammar-parser.h @@ -21,7 +21,7 @@ namespace grammar_parser { std::map symbol_ids; std::vector> rules; - std::vector c_rules(); + std::vector c_rules() const; }; parse_state parse(const char * src); diff --git a/grammars/assistant.gbnf b/grammars/assistant.gbnf new file mode 100644 index 00000000000..c445778a11d --- /dev/null +++ b/grammars/assistant.gbnf @@ -0,0 +1,57 @@ +# - "turn on lights." +# - "set thermostat to 22." +# - "increase TV by 10." +# - "decrease oven by 50." +# - "play music." +# - "stop podcast." +# - "schedule cleaning at 3pm." +# - "cancel cleaning." +# - "remind me to buy milk at 5pm." +# - "show me security system." +# - "hide washing machine." +# - "what is the lights status?" +# - "what is the current thermostat value?" +# - "what is the security system status?" +# - "what is the door lock status?" +# - "what is the camera battery level?" +# - "what is the weather like today?" +# - "what is the forecast for tomorrow?" +# - "what is the time?" +# - "what is my schedule for today?" +# - "what tasks do I have?" +# - "what reminders do I have?" +# +# example: +# +# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/assistant.gbnf --prompt "Ok Whisper, start listening for commands." --context "Whisper is a home assistant. It recognizes voice commands. Time is 11pm." --grammar-penalty 10 +# + +root ::= init " " (command | question) "." +prompt ::= init + +# leading space is very important! +init ::= " Ok Whisper, start listening for commands." + +command ::= "Turn " ("on" | "off") " " device | "Set " device " to " value | + "Increase " device " by " value | "Decrease " device " by " value | + "Play " media | "Stop " media | "Schedule " task " at " time | "Cancel " task | + "Remind me to " task " at " time | "Show me " device | "Hide " device + +question ::= "What is the " device " status?" | "What is the current " device " value?" | + "What is the " device " temperature?" | "What is the " device " humidity?" | + "What is the " device " power consumption?" | "What is the " device " battery level?" | + "What is the weather like today?" | "What is the forecast for tomorrow?" | + "What is the time?" | "What is my schedule for today?" | "What tasks do I have?" | + "What reminders do I have?" + +device ::= "lights" | "thermostat" | "security system" | "door lock" | "camera" | "speaker" | "TV" | + "music player" | "coffee machine" | "oven" | "refrigerator" | "washing machine" | + "vacuum cleaner" + +value ::= [0-9]+ + +media ::= "music" | "radio" | "podcast" | "audiobook" | "TV show" | "movie" + +task ::= [a-zA-Z]+ (" " [a-zA-Z]+)? + +time ::= [0-9] [0-9]? ("am" | "pm")? diff --git a/grammars/chess.gbnf b/grammars/chess.gbnf new file mode 100644 index 00000000000..ec8c8423c85 --- /dev/null +++ b/grammars/chess.gbnf @@ -0,0 +1,29 @@ +# - bishop to c3 +# - rook to d4 +# - knight to e5 +# - d4 d5 knight to c3 +# - c3 queen to d4 king b1 +# - pawn to a1 bishop to b2 knight to c3 +# +# The prompt (--prompt) is the initial phrase that the user has to say. +# This is used to prime Whisper with how the user is expected to speak. +# +# Provide long context (--context) with sample moves to help Whisper decode the correct sequence. +# Longer context is better, but it slightly increases the processing time. +# +# example: +# +# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/chess.gbnf --prompt "rook to b4, f3," --context "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8," --grammar-penalty 100 +# + +root ::= init move move? move? "." +prompt ::= init "." + +# leading space is very important! +init ::= " rook to b4, f3" + +move ::= ", " ((piece | pawn | king) " " "to "?)? [a-h] [1-8] + +piece ::= "bishop" | "rook" | "knight" | "queen" +king ::= "king" +pawn ::= "pawn" diff --git a/grammars/colors.gbnf b/grammars/colors.gbnf new file mode 100644 index 00000000000..1d9945054b0 --- /dev/null +++ b/grammars/colors.gbnf @@ -0,0 +1,16 @@ +# - red +# - green +# - blue +# +# example: +# +# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/colors.gbnf --prompt "red, green, blue," --context "green, red, blue," +# + +root ::= init color "." +prompt ::= init "." + +# leading space is very important! +init ::= " red, green, blue" + +color ::= ", " ("red" | "green" | "blue") diff --git a/whisper.cpp b/whisper.cpp index 078841b391c..4753232b6ad 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -3872,35 +3872,38 @@ static void whisper_suppress_invalid_grammar( return; } - // bool allow_eot = false; - // for (const auto & stack : grammar.stacks) { - // if (stack.empty()) { - // allow_eot = true; - // break; - // } - // } + //bool allow_eot = false; + //for (const auto & stack : grammar.stacks) { + // if (stack.empty()) { + // allow_eot = true; + // break; + // } + //} + + const whisper_token eot = whisper_token_eot(&ctx); std::vector, whisper_partial_utf8>> candidates_decoded; std::vector candidates_grammar; - size_t size = logits.size(); - for (whisper_token id = 0; id < size; ++id) { + for (whisper_token id = 0; id < eot; ++id) { const std::string & text = ctx.vocab.id_to_token[id]; - if (!text.empty() && text.rfind("[_", 0) != 0) { + if (!text.empty()) { candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8)); candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), candidates_decoded.back().second }); } } const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); + for (const auto & reject : rejects) { - if (logits[reject.id] > 0) { - logits[reject.id] /= params.grammar_penalty; - } else { - logits[reject.id] *= params.grammar_penalty; - } + logits[reject.id] -= params.grammar_penalty; } - // fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size()); + + // when the grammar allows a continuation, we penalize the end-of-text token + //if (!allow_eot) { + // logits[eot] -= params.grammar_penalty; + //} + //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size()); } static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) { @@ -3908,10 +3911,10 @@ static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar return; } - // fprintf(stderr, "Accept: '%s'", ctx.vocab.id_to_token[token].c_str()); + //fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str()); const std::string & text = ctx.vocab.id_to_token[token]; - + if (text.rfind("[_", 0) == 0) { // fprintf(stderr, " (skipped)\n"); return; @@ -3952,6 +3955,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.translate =*/ false, /*.no_context =*/ true, + /*.no_timestamps =*/ false, /*.single_segment =*/ false, /*.print_special =*/ false, /*.print_progress =*/ true, @@ -4015,7 +4019,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.grammar_rules =*/ nullptr, /*.n_grammar_rules =*/ 0, /*.i_start_rule =*/ 0, - /*.grammar_penalty =*/ 1000.0f, + /*.grammar_penalty =*/ 100.0f, }; switch (strategy) { @@ -4167,6 +4171,11 @@ static void whisper_process_logits( // suppress <|notimestamps|> token // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412 logits[vocab.token_not] = -INFINITY; + if (params.no_timestamps) { + for (int i = vocab.token_beg; i < n_logits; ++i) { + logits[i] = -INFINITY; + } + } // suppress sot and nosp tokens logits[vocab.token_sot] = -INFINITY; @@ -4181,12 +4190,18 @@ static void whisper_process_logits( logits[vocab.token_translate] = -INFINITY; logits[vocab.token_transcribe] = -INFINITY; + // suppress lang tokens + for (size_t i = 0; i < g_lang.size(); ++i) { + logits[whisper_token_lang(&ctx, i)] = -INFINITY; + } + + // suppress prev token + logits[vocab.token_prev] = -INFINITY; + if (params.logits_filter_callback) { params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data); } - whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar); - // suppress non-speech tokens // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 if (params.suppress_non_speech_tokens) { @@ -4293,10 +4308,33 @@ static void whisper_process_logits( //log("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); if (timestamp_logprob > max_text_token_logprob) { + //printf("sampling timestamp\n"); for (int i = 0; i < vocab.token_beg; ++i) { logits[i] = -INFINITY; logprobs[i] = -INFINITY; } + } else if (params.n_grammar_rules > 0) { + whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar); + + // populate the logprobs array (log_softmax) + { + const float logit_max = *std::max_element(logits.begin(), logits.end()); + float logsumexp = 0.0f; + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logsumexp += expf(logits[i] - logit_max); + } + } + logsumexp = logf(logsumexp) + logit_max; + + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logprobs[i] = logits[i] - logsumexp; + } else { + logprobs[i] = -INFINITY; + } + } + } } } } @@ -4314,32 +4352,55 @@ static void whisper_process_logits( #if 0 // print first 100 logits - token string : logit - for (int i = 0; i < 100; i++) { - const auto token = vocab.id_to_token.at(i); - const auto prob = probs[i]; - const auto logit = logits[i]; - const auto logprob = logprobs[i]; - printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob); + //for (int i = 0; i < 10; i++) { + // const auto token = vocab.id_to_token.at(i); + // const auto prob = probs[i]; + // const auto logit = logits[i]; + // const auto logprob = logprobs[i]; + // printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob); + //} + + // print sorted + { + std::vector> pairs; + + for (int i = 0; i < n_logits; ++i) { + pairs.push_back(std::make_pair(probs[i], i)); + } + + std::sort(pairs.begin(), pairs.end(), [](const std::pair& a, const std::pair& b) { + return a.first > b.first; + }); + + for (int i = 0; i < 10; i++) { + const auto token = vocab.id_to_token.at(pairs[i].second); + const auto prob = pairs[i].first; + const auto logit = logits[pairs[i].second]; + const auto logprob = logprobs[pairs[i].second]; + printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str()); + } + + printf("----------------\n"); } // "And", "and", " And", " and" - printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]); - printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]); - printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]); - printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]); - printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]); - - printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]); - printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]); - printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); - printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); - printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); - - printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]); - printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]); - printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]); - printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]); - printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]); + //printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]); + //printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]); + //printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]); + //printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]); + //printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]); + + //printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]); + //printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]); + //printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); + //printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); + //printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); + + //printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]); + //printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]); + //printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]); + //printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]); + //printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]); #endif } @@ -4460,8 +4521,11 @@ static std::vector whisper_sample_token_topk( ptsum = sum_ts; } + std::discrete_distribution<> dist(probs.begin(), probs.end()); + for (int i = 0; i < k; ++i) { - const auto id = logits_id[i].second; + const auto id = dist(state.rng); + //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum); result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, }); @@ -4671,7 +4735,7 @@ int whisper_full_with_state( state->exp_n_audio_ctx = params.audio_ctx; // these tokens determine the task that will be performed - std::vector prompt_init = { whisper_token_sot(ctx) }; + std::vector prompt_init = { whisper_token_sot(ctx), }; if (whisper_is_multilingual(ctx)) { const int lang_id = whisper_lang_id(params.language); state->lang_id = lang_id; @@ -4682,6 +4746,9 @@ int whisper_full_with_state( prompt_init.push_back(whisper_token_transcribe(ctx)); } } + if (params.no_timestamps) { + prompt_init.push_back(whisper_token_not(ctx)); + } int seek = seek_start; @@ -4766,7 +4833,7 @@ int whisper_full_with_state( n_decoders_cur = std::max(1, n_decoders_cur); - WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur); + WHISPER_PRINT_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur); // TAGS: WHISPER_DECODER_INIT for (int j = 0; j < n_decoders_cur; ++j) { @@ -4923,6 +4990,10 @@ int whisper_full_with_state( continue; } + if (cur_c >= beam_candidates.size()) { + cur_c = 0; + } + auto & cur = beam_candidates[cur_c++]; while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) { diff --git a/whisper.h b/whisper.h index 23f61ed5b06..fe50e73fb2b 100644 --- a/whisper.h +++ b/whisper.h @@ -389,6 +389,7 @@ extern "C" { bool translate; bool no_context; // do not use past transcription (if any) as initial prompt for the decoder + bool no_timestamps; // do not generate timestamps bool single_segment; // force single segment output (useful for streaming) bool print_special; // print special tokens (e.g. , , , etc.) bool print_progress; // print progress information