Skip to content

Commit

Permalink
Add more user input via api-request
Browse files Browse the repository at this point in the history
also some clean up
  • Loading branch information
felrock committed Nov 17, 2023
1 parent 24c208f commit 67e6d2c
Showing 1 changed file with 82 additions and 74 deletions.
156 changes: 82 additions & 74 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

#ifndef SERVER_VERBOSE
#define SERVER_VERBOSE 1
#endif

using namespace httplib;
using json = nlohmann::json;

Expand Down Expand Up @@ -60,16 +56,6 @@ int timestamp_to_sample(int64_t t, int n_samples) {
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
}

// helper function to replace substrings
void replace_all(std::string & s, const std::string & search, const std::string & replace) {
for (size_t pos = 0; ; pos += replace.length()) {
pos = s.find(search, pos);
if (pos == std::string::npos) break;
s.erase(pos, search.length());
s.insert(pos, replace);
}
}

// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
Expand Down Expand Up @@ -112,14 +98,53 @@ struct whisper_params {
std::string openvino_encode_device = "CPU";
};

void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params,
const server_params& sparams) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options] \n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
// fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
// server params
fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str());
fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port);
fprintf(stderr, "\n");
}

bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
bool whisper_params_parse(int argc, char ** argv, whisper_params & params, server_params & sparams) {
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];

if (arg == "-h" || arg == "--help") {
whisper_print_usage(argc, argv, params);
whisper_print_usage(argc, argv, params, sparams);
exit(0);
}
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
Expand Down Expand Up @@ -151,53 +176,20 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
// server params
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
else if ( arg == "--host") { sparams.hostname = argv[++i]; }
else if (arg == "-ad" || arg == "--port") { params.openvino_encode_device = argv[++i]; }
else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
whisper_print_usage(argc, argv, params);
whisper_print_usage(argc, argv, params, sparams);
exit(0);
}
}

return true;
}

void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
fprintf(stderr, "\n");
fprintf(stderr, "usage: %s [options] \n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors);
fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms);
fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n);
fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms);
fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context);
fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len);
fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false");
fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of);
fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size);
fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold);
fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold);
fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold);
// fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false");
fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false");
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
fprintf(stderr, "\n");
}

struct whisper_print_user_data {
const whisper_params * params;

Expand Down Expand Up @@ -367,23 +359,24 @@ char *escape_double_quotes_and_backslashes(const char *str) {

int main(int argc, char ** argv) {
whisper_params params;
server_params sparams;

std::mutex whisper_mutex;

if (whisper_params_parse(argc, argv, params) == false) {
whisper_print_usage(argc, argv, params);
if (whisper_params_parse(argc, argv, params, sparams) == false) {
whisper_print_usage(argc, argv, params, sparams);
return 1;
}

if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
whisper_print_usage(argc, argv, params);
whisper_print_usage(argc, argv, params, sparams);
exit(0);
}

if (params.diarize && params.tinydiarize) {
fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n");
whisper_print_usage(argc, argv, params);
whisper_print_usage(argc, argv, params, sparams);
exit(0);
}

Expand All @@ -400,22 +393,41 @@ int main(int argc, char ** argv) {

Server svr;

std::string default_content = "<html>hello</html>";
std::string const default_content = "<html>hello</html>";

// this is only called if no index.html is found in the public --path
svr.Get("/", [&default_content](const Request &, Response &res){
res.set_content(default_content.c_str(), default_content.size(), "text/html");
return false;
});

svr.Post("/whisper", [&](const Request &req, Response &res){
svr.Post("/inference", [&](const Request &req, Response &res){

// aquire whisper model mutex lock
whisper_mutex.lock();

// user audio file

auto audio_file = req.get_file_value("audio_file");

// user model configuration
if (req.has_param("offset-t"))
{
params.offset_t_ms = std::stoi(req.get_param_value("offset-t"));
}
if (req.has_param("offset-n"))
{
params.offset_n = std::stoi(req.get_param_value("offset-n"));
}
if (req.has_param("duration"))
{
params.duration_ms = std::stoi(req.get_param_value("duration"));
}
if (req.has_param("max-context"))
{
params.max_context = std::stoi(req.get_param_value("max-context"));
}
// TODO add all

std::string filename{audio_file.filename};
printf("Received request: %s\n", filename.c_str());

Expand All @@ -434,6 +446,7 @@ int main(int argc, char ** argv) {
whisper_mutex.unlock();
return;
}
// remove temp file
std::remove(filename.c_str());

printf("Successfully loaded %s\n", filename.c_str());
Expand Down Expand Up @@ -569,26 +582,21 @@ int main(int argc, char ** argv) {
});

// set timeouts and change hostname and port
int read_timeout = 600;
int write_timeout = 600;
std::string hostname = "localhost";
std::string public_path = "examples/server/public";
int port = 8080;

svr.set_read_timeout(read_timeout);
svr.set_write_timeout(write_timeout);
svr.set_read_timeout(sparams.read_timeout);
svr.set_write_timeout(sparams.write_timeout);

if (!svr.bind_to_port(hostname, port))
if (!svr.bind_to_port(sparams.hostname, sparams.port))
{
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", hostname.c_str(), port);
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n",
sparams.hostname.c_str(), sparams.port);
return 1;
}

// Set the base directory for serving static files
svr.set_base_dir(public_path);
svr.set_base_dir(sparams.public_path);

// to make it ctrl+clickable:
printf("\nllama server listening at http://%s:%d\n\n", hostname.c_str(), port);
printf("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);

if (!svr.listen_after_bind())
{
Expand Down

0 comments on commit 67e6d2c

Please sign in to comment.