Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

whisper : fine-tuning grammar functionality #1

Merged
merged 5 commits into from
Sep 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ build/
build-em/
build-debug/
build-release/
build-rwdi/
build-static/
build-cublas/
build-no-accel/
Expand Down
133 changes: 95 additions & 38 deletions examples/command/command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
#include <vector>
#include <map>

bool file_exists(const std::string & fname) {
std::ifstream f(fname.c_str());
return f.good();
}

// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
Expand All @@ -31,10 +36,13 @@ struct whisper_params {
int32_t max_tokens = 32;
int32_t audio_ctx = 0;

float vad_thold = 0.6f;
float freq_thold = 100.0f;
float vad_thold = 0.6f;
float freq_thold = 100.0f;

float grammar_penalty = 100.0f;

grammar_parser::parse_state grammar_parsed;

bool speed_up = false;
bool translate = false;
bool print_special = false;
Expand All @@ -46,6 +54,7 @@ struct whisper_params {
std::string fname_out;
std::string commands;
std::string prompt;
std::string context;
std::string grammar;
};

Expand Down Expand Up @@ -76,6 +85,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; }
else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; }
else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; }
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
else {
Expand Down Expand Up @@ -111,50 +121,68 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str());
fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str());
fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str());
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty);
fprintf(stderr, "\n");
}

std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
std::string transcribe(
whisper_context * ctx,
const whisper_params & params,
const std::vector<float> & pcmf32,
const std::string & grammar_rule,
float & logprob_min,
float & logprob_sum,
int & n_tokens,
int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();

prob = 0.0f;
logprob_min = 0.0f;
logprob_sum = 0.0f;
n_tokens = 0;
t_ms = 0;

grammar_parser::parse_state parsed_grammar;
std::vector<const whisper_grammar_element *> grammar_rules;

whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
//whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);

wparams.print_progress = false;
wparams.print_special = params.print_special;
wparams.print_realtime = false;
wparams.print_timestamps = !params.no_timestamps;
wparams.translate = params.translate;
wparams.no_context = true;
wparams.no_timestamps = params.no_timestamps;
wparams.single_segment = true;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;

wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;
wparams.audio_ctx = params.audio_ctx;
wparams.speed_up = params.speed_up;

if (!params.grammar.empty()) {
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
grammar_rules = parsed_grammar.c_rules();
wparams.temperature = 0.4f;
wparams.temperature_inc = 1.0f;
wparams.greedy.best_of = 5;

wparams.beam_search.beam_size = 5;

wparams.initial_prompt = params.context.data();

const auto & grammar_parsed = params.grammar_parsed;
auto grammar_rules = grammar_parsed.c_rules();

if (!params.grammar_parsed.rules.empty() && !grammar_rule.empty()) {
wparams.grammar_rules = grammar_rules.data();
wparams.n_grammar_rules = grammar_rules.size();
wparams.i_start_rule = parsed_grammar.symbol_ids.at("root");
wparams.i_start_rule = grammar_parsed.symbol_ids.at(grammar_rule);
wparams.grammar_penalty = params.grammar_penalty;
}

if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
return "";
}

int prob_n = 0;
std::string result;

const int n_segments = whisper_full_n_segments(ctx);
Expand All @@ -163,19 +191,17 @@ std::string transcribe(whisper_context * ctx, const whisper_params & params, con

result += text;

const int n_tokens = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < n_tokens; ++j) {
const int n = whisper_full_n_tokens(ctx, i);
for (int j = 0; j < n; ++j) {
const auto token = whisper_full_get_token_data(ctx, i, j);

prob += token.p;
++prob_n;
if(token.plog > 0.0f) exit(0);
logprob_min = std::min(logprob_min, token.plog);
logprob_sum += token.plog;
++n_tokens;
}
}

if (prob_n > 0) {
prob /= prob_n;
}

const auto t_end = std::chrono::high_resolution_clock::now();
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();

Expand Down Expand Up @@ -266,7 +292,7 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
fprintf(stderr, " ]\n");
}

std::string k_prompt = "select one from the available words: ";
std::string k_prompt = "select one from the available words: ";
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
if (i > 0) {
k_prompt += ", ";
Expand Down Expand Up @@ -434,7 +460,9 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
bool is_running = true;
bool ask_prompt = true;

float prob = 0.0f;
float logprob_min = 0.0f;
float logprob_sum = 0.0f;
int n_tokens = 0;

std::vector<float> pcmf32_cur;

Expand Down Expand Up @@ -472,7 +500,7 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi
// detect the commands
audio.get(params.command_ms, pcmf32_cur);

const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "", logprob_min, logprob_sum, n_tokens, t_ms));

const auto words = get_words(txt);

Expand Down Expand Up @@ -508,18 +536,27 @@ int always_prompt_transcription(struct whisper_context * ctx, audio_async & audi

// general-purpose mode
// freely transcribe the voice into text
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params &params) {
int process_general_transcription(struct whisper_context * ctx, audio_async & audio, const whisper_params & params) {
bool is_running = true;
bool have_prompt = false;
bool ask_prompt = true;

float prob0 = 0.0f;
float prob = 0.0f;
float logprob_min0 = 0.0f;
float logprob_min = 0.0f;

float logprob_sum0 = 0.0f;
float logprob_sum = 0.0f;

int n_tokens0 = 0;
int n_tokens = 0;

std::vector<float> pcmf32_cur;
std::vector<float> pcmf32_prompt;

const std::string k_prompt = "Ok Whisper, start listening for commands.";
std::string k_prompt = "Ok Whisper, start listening for commands.";
if (!params.prompt.empty()) {
k_prompt = params.prompt;
}

fprintf(stderr, "\n");
fprintf(stderr, "%s: general-purpose mode\n", __func__);
Expand Down Expand Up @@ -552,9 +589,11 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
// wait for activation phrase
audio.get(params.prompt_ms, pcmf32_cur);

const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "prompt", logprob_min0, logprob_sum0, n_tokens0, t_ms));

const float p = 100.0f * std::exp(logprob_min0);

fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms, p = %.2f%%)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms, p);

const float sim = similarity(txt, k_prompt);

Expand All @@ -575,12 +614,19 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
// we have heard the activation phrase, now detect the commands
audio.get(params.command_ms, pcmf32_cur);

//printf("len prompt: %.4f\n", pcmf32_prompt.size() / (float) WHISPER_SAMPLE_RATE);
//printf("len command: %.4f\n", pcmf32_cur.size() / (float) WHISPER_SAMPLE_RATE);

// prepend 3 second of silence
pcmf32_cur.insert(pcmf32_cur.begin(), 3.0f*WHISPER_SAMPLE_RATE, 0.0f);

// prepend the prompt audio
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());

const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, "root", logprob_min, logprob_sum, n_tokens, t_ms));

prob = 100.0f*(prob - prob0);
//const float p = 100.0f * std::exp((logprob - logprob0) / (n_tokens - n_tokens0));
const float p = 100.0f * std::exp(logprob_min);

//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());

Expand All @@ -604,6 +650,7 @@ int process_general_transcription(struct whisper_context * ctx, audio_async &aud
}
}

fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
if (best_len == 0) {
fprintf(stdout, "%s: WARNING: command not recognized, try again\n", __func__);
} else {
Expand Down Expand Up @@ -678,21 +725,31 @@ int main(int argc, char ** argv) {
int ret_val = 0;

if (!params.grammar.empty()) {
auto parsed_grammar = grammar_parser::parse(params.grammar.c_str());
auto & grammar = params.grammar_parsed;
if (file_exists(params.grammar.c_str())) {
// read grammar from file
std::ifstream ifs(params.grammar.c_str());
const std::string txt = std::string((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
grammar = grammar_parser::parse(txt.c_str());
} else {
// read grammar from string
grammar = grammar_parser::parse(params.grammar.c_str());
}

// will be empty (default) if there are parse errors
if (parsed_grammar.rules.empty()) {
if (grammar.rules.empty()) {
ret_val = 1;
} else {
fprintf(stderr, "%s: grammar:\n", __func__);
grammar_parser::print_grammar(stderr, parsed_grammar);
grammar_parser::print_grammar(stderr, grammar);
fprintf(stderr, "\n");
}
}

if (ret_val == 0) {
if (!params.commands.empty()) {
ret_val = process_command_list(ctx, audio, params);
} else if (!params.prompt.empty()) {
} else if (!params.prompt.empty() && params.grammar_parsed.rules.empty()) {
ret_val = always_prompt_transcription(ctx, audio, params);
} else {
ret_val = process_general_transcription(ctx, audio, params);
Expand Down
2 changes: 1 addition & 1 deletion examples/grammar-parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ namespace grammar_parser {
}
}

std::vector<const whisper_grammar_element *> parse_state::c_rules() {
std::vector<const whisper_grammar_element *> parse_state::c_rules() const{
std::vector<const whisper_grammar_element *> ret;
for (const auto & rule : rules) {
ret.push_back(rule.data());
Expand Down
2 changes: 1 addition & 1 deletion examples/grammar-parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace grammar_parser {
std::map<std::string, uint32_t> symbol_ids;
std::vector<std::vector<whisper_grammar_element>> rules;

std::vector<const whisper_grammar_element *> c_rules();
std::vector<const whisper_grammar_element *> c_rules() const;
};

parse_state parse(const char * src);
Expand Down
57 changes: 57 additions & 0 deletions grammars/assistant.gbnf
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# - "turn on lights."
# - "set thermostat to 22."
# - "increase TV by 10."
# - "decrease oven by 50."
# - "play music."
# - "stop podcast."
# - "schedule cleaning at 3pm."
# - "cancel cleaning."
# - "remind me to buy milk at 5pm."
# - "show me security system."
# - "hide washing machine."
# - "what is the lights status?"
# - "what is the current thermostat value?"
# - "what is the security system status?"
# - "what is the door lock status?"
# - "what is the camera battery level?"
# - "what is the weather like today?"
# - "what is the forecast for tomorrow?"
# - "what is the time?"
# - "what is my schedule for today?"
# - "what tasks do I have?"
# - "what reminders do I have?"
#
# example:
#
# ./command -m ./models/ggml-tiny.en.bin -t 8 --grammar ./grammars/assistant.gbnf --prompt "Ok Whisper, start listening for commands." --context "Whisper is a home assistant. It recognizes voice commands. Time is 11pm." --grammar-penalty 10
#

root ::= init " " (command | question) "."
prompt ::= init

# leading space is very important!
init ::= " Ok Whisper, start listening for commands."

command ::= "Turn " ("on" | "off") " " device | "Set " device " to " value |
"Increase " device " by " value | "Decrease " device " by " value |
"Play " media | "Stop " media | "Schedule " task " at " time | "Cancel " task |
"Remind me to " task " at " time | "Show me " device | "Hide " device

question ::= "What is the " device " status?" | "What is the current " device " value?" |
"What is the " device " temperature?" | "What is the " device " humidity?" |
"What is the " device " power consumption?" | "What is the " device " battery level?" |
"What is the weather like today?" | "What is the forecast for tomorrow?" |
"What is the time?" | "What is my schedule for today?" | "What tasks do I have?" |
"What reminders do I have?"

device ::= "lights" | "thermostat" | "security system" | "door lock" | "camera" | "speaker" | "TV" |
"music player" | "coffee machine" | "oven" | "refrigerator" | "washing machine" |
"vacuum cleaner"

value ::= [0-9]+

media ::= "music" | "radio" | "podcast" | "audiobook" | "TV show" | "movie"

task ::= [a-zA-Z]+ (" " [a-zA-Z]+)?

time ::= [0-9] [0-9]? ("am" | "pm")?
Loading