From 90175ee13ffeeca501a5623d2ff1d064ac568b45 Mon Sep 17 00:00:00 2001 From: Thiago Padilha Date: Wed, 22 Mar 2023 09:05:50 -0300 Subject: [PATCH 1/5] Move main.cpp to run.cpp Signed-off-by: Thiago Padilha --- CMakeLists.txt | 2 +- Makefile | 4 ++-- main.cpp => run.cpp | 0 3 files changed, 3 insertions(+), 3 deletions(-) rename main.cpp => run.cpp (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index d952afb4ff72b..92b45615afe04 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -239,7 +239,7 @@ target_link_libraries(llama PRIVATE utils ggml ${LLAMA_EXTRA_LIBS}) # Executables # -add_executable(main main.cpp) +add_executable(main run.cpp) target_link_libraries(main PRIVATE llama ggml utils) add_executable(quantize quantize.cpp) diff --git a/Makefile b/Makefile index edb0c64c82361..a8f09db7a39f8 100644 --- a/Makefile +++ b/Makefile @@ -229,8 +229,8 @@ utils.o: utils.cpp utils.h clean: rm -f *.o main quantize -main: main.cpp ggml.o llama.o utils.o - $(CXX) $(CXXFLAGS) main.cpp ggml.o llama.o utils.o -o main $(LDFLAGS) +main: run.cpp ggml.o llama.o utils.o + $(CXX) $(CXXFLAGS) run.cpp ggml.o llama.o utils.o -o main $(LDFLAGS) @echo "\x1b[36mrun ./main -h for help\x1b[0m" quantize: quantize.cpp ggml.o llama.o utils.o diff --git a/main.cpp b/run.cpp similarity index 100% rename from main.cpp rename to run.cpp From d7d53b84dbf709415c63bf1e07fff9ea5cb23298 Mon Sep 17 00:00:00 2001 From: Thiago Padilha Date: Wed, 22 Mar 2023 09:16:33 -0300 Subject: [PATCH 2/5] Add main.cpp back and invoke "run" from it Signed-off-by: Thiago Padilha --- CMakeLists.txt | 4 +++- Makefile | 7 +++++-- main.cpp | 5 +++++ run.cpp | 2 +- run.h | 3 +++ 5 files changed, 17 insertions(+), 4 deletions(-) create mode 100644 main.cpp create mode 100644 run.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 92b45615afe04..4db24fbbb598c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -239,7 +239,9 @@ target_link_libraries(llama PRIVATE utils ggml ${LLAMA_EXTRA_LIBS}) # Executables # -add_executable(main run.cpp) +add_executable(main + main.cpp + run.cpp) target_link_libraries(main PRIVATE llama ggml utils) add_executable(quantize quantize.cpp) diff --git a/Makefile b/Makefile index a8f09db7a39f8..2f11ea166e299 100644 --- a/Makefile +++ b/Makefile @@ -226,11 +226,14 @@ llama.o: llama.cpp llama.h utils.o: utils.cpp utils.h $(CXX) $(CXXFLAGS) -c utils.cpp -o utils.o +run.o: run.cpp run.h + $(CXX) $(CXXFLAGS) -c run.cpp -o run.o + clean: rm -f *.o main quantize -main: run.cpp ggml.o llama.o utils.o - $(CXX) $(CXXFLAGS) run.cpp ggml.o llama.o utils.o -o main $(LDFLAGS) +main: main.cpp ggml.o llama.o utils.o run.o + $(CXX) $(CXXFLAGS) main.cpp ggml.o llama.o utils.o run.o -o main $(LDFLAGS) @echo "\x1b[36mrun ./main -h for help\x1b[0m" quantize: quantize.cpp ggml.o llama.o utils.o diff --git a/main.cpp b/main.cpp new file mode 100644 index 0000000000000..61fec449a1ba9 --- /dev/null +++ b/main.cpp @@ -0,0 +1,5 @@ +#include "run.h" + +int main(int argc, char ** argv) { + return run(argc, argv); +} diff --git a/run.cpp b/run.cpp index 4569ef2a11fbb..e0db769747f1a 100644 --- a/run.cpp +++ b/run.cpp @@ -154,7 +154,7 @@ void sigint_handler(int signo) { } #endif -int main(int argc, char ** argv) { +int run(int argc, char ** argv) { // has to be called once at the start of the program to init ggml stuff ggml_time_init(); diff --git a/run.h b/run.h new file mode 100644 index 0000000000000..4a490bb98ba53 --- /dev/null +++ b/run.h @@ -0,0 +1,3 @@ +#pragma once + +int run(int argc, char ** argv); From 
b7f1fa6d8c118c4c9977bf57874b0584b2618856 Mon Sep 17 00:00:00 2001 From: Thiago Padilha Date: Wed, 22 Mar 2023 09:39:25 -0300 Subject: [PATCH 3/5] Move llama_context setup + perplexity back to main.cpp Signed-off-by: Thiago Padilha --- main.cpp | 124 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- run.cpp | 122 +----------------------------------------------------- run.h | 5 ++- 3 files changed, 128 insertions(+), 123 deletions(-) diff --git a/main.cpp b/main.cpp index 61fec449a1ba9..8ce9af8c385c5 100644 --- a/main.cpp +++ b/main.cpp @@ -1,5 +1,127 @@ #include "run.h" +#include "ggml.h" + + +std::vector softmax(const std::vector& logits) { + std::vector probs(logits.size()); + float max_logit = logits[0]; + for (float v : logits) max_logit = std::max(max_logit, v); + double sum_exp = 0.0; + for (size_t i = 0; i < logits.size(); i++) { + // Subtract the maximum logit value from the current logit value for numerical stability + float logit = logits[i] - max_logit; + double exp_logit = std::exp(logit); + sum_exp += exp_logit; + probs[i] = exp_logit; + } + for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp; + return probs; +} + +void perplexity(llama_context * ctx, const gpt_params & params) { + // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research + // Run `./main --perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` + // Output: `perplexity: 13.5106 [114/114]` + auto tokens = ::llama_tokenize(ctx, params.prompt, true); + + int count = 0; + double nll = 0.0; + int seq_count = tokens.size() / params.n_ctx; + + fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count); + + for (int i = 0; i < seq_count; ++i) { + int start = i * params.n_ctx; + int end = start + params.n_ctx - 1; + std::vector embd(tokens.begin() + start, tokens.begin() + end); + auto start_t = std::chrono::high_resolution_clock::now(); + if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return; + } + auto end_t = std::chrono::high_resolution_clock::now(); + if (i == 0) { + double seconds = std::chrono::duration(end_t - start_t).count(); + printf("%.2f seconds per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0)); + } + // We get the logits for all the tokens in the context window (params.n_ctx) + // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity, + // calculate the perplexity over the last half the window (so the model always has + // some context to predict the token). + // + // We rely on the fact that attention in the forward pass only looks at previous + // tokens here, so the logits returned for each token are an accurate representation + // of what the model would have predicted at that point. + // + // Example, we have a context window of 512, we will compute perplexity for each of the + // last 256 tokens. Then, we split the input up into context window size chunks to + // process the entire prompt. + + auto logits = llama_get_logits(ctx); + for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) { + // Calculate probability of next token, given the previous ones. 
+ int n_vocab = llama_n_vocab(ctx); + std::vector tok_logits( + logits + j * n_vocab, + logits + (j + 1) * n_vocab); + double prob = softmax(tok_logits)[tokens[start + j + 1]]; + nll += -std::log(prob); + ++count; + } + // perplexity is e^(average negative log-likelihood) + printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); + fflush(stdout); + } + printf("\n"); +} int main(int argc, char ** argv) { - return run(argc, argv); + // has to be called once at the start of the program to init ggml stuff + ggml_time_init(); + + gpt_params params; + params.model = "models/llama-7B/ggml-model.bin"; + + if (gpt_params_parse(argc, argv, params) == false) { + return 1; + } + + if (params.n_ctx > 2048) { + fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);" + "expect poor results\n", __func__, params.n_ctx); + } + + llama_context * ctx; + + // load the model + { + auto lparams = llama_context_default_params(); + + lparams.n_ctx = params.n_ctx; + lparams.n_parts = params.n_parts; + lparams.seed = params.seed; + lparams.f16_kv = params.memory_f16; + lparams.logits_all = params.perplexity; + + ctx = llama_init_from_file(params.model.c_str(), lparams); + + if (ctx == NULL) { + fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); + return 1; + } + } + + // print system information + { + fprintf(stderr, "\n"); + fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", + params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); + } + + if (params.perplexity) { + perplexity(ctx, params); + exit(0); + } + + return run(ctx, params); } diff --git a/run.cpp b/run.cpp index e0db769747f1a..7b0543732b946 100644 --- a/run.cpp +++ b/run.cpp @@ -1,5 +1,4 @@ #include "utils.h" -#include "ggml.h" #include "llama.h" #include @@ -65,79 +64,6 @@ void set_console_state(console_state new_st) } } -std::vector softmax(const std::vector& logits) { - std::vector probs(logits.size()); - float max_logit = logits[0]; - for (float v : logits) max_logit = std::max(max_logit, v); - double sum_exp = 0.0; - for (size_t i = 0; i < logits.size(); i++) { - // Subtract the maximum logit value from the current logit value for numerical stability - float logit = logits[i] - max_logit; - double exp_logit = std::exp(logit); - sum_exp += exp_logit; - probs[i] = exp_logit; - } - for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp; - return probs; -} - -void perplexity(llama_context * ctx, const gpt_params & params) { - // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research - // Run `./main --perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` - // Output: `perplexity: 13.5106 [114/114]` - auto tokens = ::llama_tokenize(ctx, params.prompt, true); - - int count = 0; - double nll = 0.0; - int seq_count = tokens.size() / params.n_ctx; - - fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count); - - for (int i = 0; i < seq_count; ++i) { - int start = i * params.n_ctx; - int end = start + params.n_ctx - 1; - std::vector embd(tokens.begin() + start, tokens.begin() + end); - auto start_t = std::chrono::high_resolution_clock::now(); - if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) { - fprintf(stderr, "%s : failed to eval\n", __func__); - return; - } - auto end_t = std::chrono::high_resolution_clock::now(); - if (i == 0) { - double seconds = std::chrono::duration(end_t - start_t).count(); - printf("%.2f seconds 
per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0)); - } - // We get the logits for all the tokens in the context window (params.n_ctx) - // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity, - // calculate the perplexity over the last half the window (so the model always has - // some context to predict the token). - // - // We rely on the fact that attention in the forward pass only looks at previous - // tokens here, so the logits returned for each token are an accurate representation - // of what the model would have predicted at that point. - // - // Example, we have a context window of 512, we will compute perplexity for each of the - // last 256 tokens. Then, we split the input up into context window size chunks to - // process the entire prompt. - - auto logits = llama_get_logits(ctx); - for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) { - // Calculate probability of next token, given the previous ones. - int n_vocab = llama_n_vocab(ctx); - std::vector tok_logits( - logits + j * n_vocab, - logits + (j + 1) * n_vocab); - double prob = softmax(tok_logits)[tokens[start + j + 1]]; - nll += -std::log(prob); - ++count; - } - // perplexity is e^(average negative log-likelihood) - printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); - fflush(stdout); - } - printf("\n"); -} - static bool is_interacting = false; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) @@ -154,21 +80,7 @@ void sigint_handler(int signo) { } #endif -int run(int argc, char ** argv) { - // has to be called once at the start of the program to init ggml stuff - ggml_time_init(); - - gpt_params params; - params.model = "models/llama-7B/ggml-model.bin"; - - if (gpt_params_parse(argc, argv, params) == false) { - return 1; - } - - if (params.n_ctx > 2048) { - fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);" - "expect poor results\n", __func__, params.n_ctx); - } +int run(llama_context * ctx, gpt_params params) { if (params.seed <= 0) { params.seed = time(NULL); @@ -188,33 +100,6 @@ int run(int argc, char ** argv) { // params.prompt = R"(// this function checks if the number n is prime //bool is_prime(int n) {)"; - llama_context * ctx; - - // load the model - { - auto lparams = llama_context_default_params(); - - lparams.n_ctx = params.n_ctx; - lparams.n_parts = params.n_parts; - lparams.seed = params.seed; - lparams.f16_kv = params.memory_f16; - lparams.logits_all = params.perplexity; - - ctx = llama_init_from_file(params.model.c_str(), lparams); - - if (ctx == NULL) { - fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); - return 1; - } - } - - // print system information - { - fprintf(stderr, "\n"); - fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", - params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); - } - // determine the required inference memory per token: // TODO: better way to do that { @@ -222,11 +107,6 @@ int run(int argc, char ** argv) { llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads); } - if (params.perplexity) { - perplexity(ctx, params); - exit(0); - } - int n_past = 0; // Add a space in front of the first character to match OG llama tokenizer behavior diff --git a/run.h b/run.h index 4a490bb98ba53..3603396dadb09 100644 --- a/run.h +++ b/run.h @@ -1,3 +1,6 @@ #pragma once -int run(int argc, char ** argv); +#include "llama.h" +#include "utils.h" + +int 
run(llama_context * ctx, gpt_params params); From bf44faa0ee0f8f2ff51f51e566446973af7d5c07 Mon Sep 17 00:00:00 2001 From: Thiago Padilha Date: Wed, 22 Mar 2023 09:55:45 -0300 Subject: [PATCH 4/5] Remove direct access to std streams from "run" The goal is to allow running "run" while connected to other streams, such as TCP sockets. Signed-off-by: Thiago Padilha --- main.cpp | 4 +++- run.cpp | 60 ++++++++++++++++++++++++++++++-------------------------- run.h | 6 +++++- 3 files changed, 40 insertions(+), 30 deletions(-) diff --git a/main.cpp b/main.cpp index 8ce9af8c385c5..0044025e95b49 100644 --- a/main.cpp +++ b/main.cpp @@ -1,6 +1,8 @@ #include "run.h" #include "ggml.h" +#include + std::vector softmax(const std::vector& logits) { std::vector probs(logits.size()); @@ -123,5 +125,5 @@ int main(int argc, char ** argv) { exit(0); } - return run(ctx, params); + return run(ctx, params, std::cin, stdout, stderr); } diff --git a/run.cpp b/run.cpp index 7b0543732b946..ab430eb9291d8 100644 --- a/run.cpp +++ b/run.cpp @@ -44,7 +44,7 @@ enum console_state { static console_state con_st = CONSOLE_STATE_DEFAULT; static bool con_use_color = false; -void set_console_state(console_state new_st) +void set_console_state(FILE *stream, console_state new_st) { if (!con_use_color) return; // only emit color code if state changed @@ -52,13 +52,13 @@ void set_console_state(console_state new_st) con_st = new_st; switch(con_st) { case CONSOLE_STATE_DEFAULT: - printf(ANSI_COLOR_RESET); + fprintf(stream, ANSI_COLOR_RESET); return; case CONSOLE_STATE_PROMPT: - printf(ANSI_COLOR_YELLOW); + fprintf(stream, ANSI_COLOR_YELLOW); return; case CONSOLE_STATE_USER_INPUT: - printf(ANSI_BOLD ANSI_COLOR_GREEN); + fprintf(stream, ANSI_BOLD ANSI_COLOR_GREEN); return; } } @@ -68,7 +68,7 @@ static bool is_interacting = false; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) void sigint_handler(int signo) { - set_console_state(CONSOLE_STATE_DEFAULT); + set_console_state(stdout, CONSOLE_STATE_DEFAULT); printf("\n"); // this also force flush stdout. 
if (signo == SIGINT) { if (!is_interacting) { @@ -80,13 +80,17 @@ void sigint_handler(int signo) { } #endif -int run(llama_context * ctx, gpt_params params) { +int run(llama_context * ctx, + gpt_params params, + std::istream & instream, + FILE *outstream, + FILE *errstream) { if (params.seed <= 0) { params.seed = time(NULL); } - fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); + fprintf(errstream, "%s: seed = %d\n", __func__, params.seed); std::mt19937 rng(params.seed); if (params.random_prompt) { @@ -138,13 +142,13 @@ int run(llama_context * ctx, gpt_params params) { params.interactive = true; } - fprintf(stderr, "\n"); - fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str()); - fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); + fprintf(errstream, "\n"); + fprintf(errstream, "%s: prompt: '%s'\n", __func__, params.prompt.c_str()); + fprintf(errstream, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); for (int i = 0; i < (int) embd_inp.size(); i++) { - fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i])); + fprintf(errstream, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i])); } - fprintf(stderr, "\n"); + fprintf(errstream, "\n"); if (params.interactive) { #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) struct sigaction sigint_action; @@ -156,16 +160,16 @@ int run(llama_context * ctx, gpt_params params) { signal(SIGINT, sigint_handler); #endif - fprintf(stderr, "%s: interactive mode on.\n", __func__); + fprintf(errstream, "%s: interactive mode on.\n", __func__); if(params.antiprompt.size()) { for (auto antiprompt : params.antiprompt) { - fprintf(stderr, "Reverse prompt: '%s'\n", antiprompt.c_str()); + fprintf(errstream, "Reverse prompt: '%s'\n", antiprompt.c_str()); } } } - fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); - fprintf(stderr, "\n\n"); + fprintf(errstream, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); + fprintf(errstream, "\n\n"); std::vector embd; @@ -174,7 +178,7 @@ int run(llama_context * ctx, gpt_params params) { std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); if (params.interactive) { - fprintf(stderr, "== Running in interactive mode. ==\n" + fprintf(errstream, "== Running in interactive mode. 
==\n" #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) " - Press Ctrl+C to interject at any time.\n" #endif @@ -199,13 +203,13 @@ int run(llama_context * ctx, gpt_params params) { } #endif // the first thing we will do is to output the prompt, so set color accordingly - set_console_state(CONSOLE_STATE_PROMPT); + set_console_state(outstream, CONSOLE_STATE_PROMPT); while (remaining_tokens > 0 || params.interactive) { // predict if (embd.size() > 0) { if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) { - fprintf(stderr, "%s : failed to eval\n", __func__); + fprintf(errstream, "%s : failed to eval\n", __func__); return 1; } } @@ -263,13 +267,13 @@ int run(llama_context * ctx, gpt_params params) { // display text if (!input_noecho) { for (auto id : embd) { - printf("%s", llama_token_to_str(ctx, id)); + fprintf(outstream, "%s", llama_token_to_str(ctx, id)); } - fflush(stdout); + fflush(outstream); } // reset color to default if we there is no pending user input if (!input_noecho && (int)embd_inp.size() == input_consumed) { - set_console_state(CONSOLE_STATE_DEFAULT); + set_console_state(outstream, CONSOLE_STATE_DEFAULT); } // in interactive mode, and not currently processing queued inputs; @@ -290,20 +294,20 @@ int run(llama_context * ctx, gpt_params params) { } if (is_interacting) { // potentially set color to indicate we are taking user input - set_console_state(CONSOLE_STATE_USER_INPUT); + set_console_state(outstream, CONSOLE_STATE_USER_INPUT); if (params.instruct) { input_consumed = embd_inp.size(); embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); - printf("\n> "); + fprintf(outstream, "\n> "); } std::string buffer; std::string line; bool another_line = true; do { - std::getline(std::cin, line); + std::getline(instream, line); if (line.empty() || line.back() != '\\') { another_line = false; } else { @@ -313,7 +317,7 @@ int run(llama_context * ctx, gpt_params params) { } while (another_line); // done taking input, reset color - set_console_state(CONSOLE_STATE_DEFAULT); + set_console_state(outstream, CONSOLE_STATE_DEFAULT); auto line_inp = ::llama_tokenize(ctx, buffer, false); embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); @@ -334,7 +338,7 @@ int run(llama_context * ctx, gpt_params params) { if (params.interactive) { is_interacting = true; } else { - fprintf(stderr, " [end of text]\n"); + fprintf(errstream, " [end of text]\n"); break; } } @@ -354,7 +358,7 @@ int run(llama_context * ctx, gpt_params params) { llama_free(ctx); - set_console_state(CONSOLE_STATE_DEFAULT); + set_console_state(outstream, CONSOLE_STATE_DEFAULT); return 0; } diff --git a/run.h b/run.h index 3603396dadb09..39c8e9f063dc1 100644 --- a/run.h +++ b/run.h @@ -3,4 +3,8 @@ #include "llama.h" #include "utils.h" -int run(llama_context * ctx, gpt_params params); +int run(llama_context * ctx, + gpt_params params, + std::istream & instream, + FILE *outstream, + FILE *errstream); From 3a0dcb39207a18ab3f8d825914d0c4359ae9736d Mon Sep 17 00:00:00 2001 From: Thiago Padilha Date: Wed, 22 Mar 2023 10:41:26 -0300 Subject: [PATCH 5/5] Implement server mode. This new mode works by first loading the model then listening for TCP connections on a port. When a connection is received, arguments will be parsed using a simple protocol: - First the number of arguments will be read followed by a newline character. - Then each argument will be read, separated by the 0 byte. 
- With this we build an argument vector, similar to what is passed to the program entry point. We pass this to gpt_params_parse. Finally `run` will be executed with the input/output streams connected to the socket. Signed-off-by: Thiago Padilha --- CMakeLists.txt | 4 + Makefile | 7 +- chat_tcp_client.sh | 45 +++++++++ chat_tcp_server.sh | 6 ++ main.cpp | 7 ++ tcp_server.cpp | 245 +++++++++++++++++++++++++++++++++++++++++++++ tcp_server.h | 7 ++ utils.cpp | 8 ++ utils.h | 4 + 9 files changed, 331 insertions(+), 2 deletions(-) create mode 100755 chat_tcp_client.sh create mode 100755 chat_tcp_server.sh create mode 100644 tcp_server.cpp create mode 100644 tcp_server.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 4db24fbbb598c..d95d93f99c27c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -244,6 +244,10 @@ add_executable(main run.cpp) target_link_libraries(main PRIVATE llama ggml utils) +if(NOT WIN32) + target_sources(main PRIVATE tcp_server.cpp) +endif() + add_executable(quantize quantize.cpp) target_link_libraries(quantize PRIVATE llama ggml utils) diff --git a/Makefile b/Makefile index 2f11ea166e299..59400a8033f34 100644 --- a/Makefile +++ b/Makefile @@ -229,11 +229,14 @@ utils.o: utils.cpp utils.h run.o: run.cpp run.h $(CXX) $(CXXFLAGS) -c run.cpp -o run.o +tcp_server.o: tcp_server.cpp tcp_server.h + $(CXX) $(CXXFLAGS) -c tcp_server.cpp -o tcp_server.o + clean: rm -f *.o main quantize -main: main.cpp ggml.o llama.o utils.o run.o - $(CXX) $(CXXFLAGS) main.cpp ggml.o llama.o utils.o run.o -o main $(LDFLAGS) +main: main.cpp ggml.o llama.o utils.o run.o tcp_server.o + $(CXX) $(CXXFLAGS) main.cpp ggml.o llama.o utils.o run.o tcp_server.o -o main $(LDFLAGS) @echo "\x1b[36mrun ./main -h for help\x1b[0m" quantize: quantize.cpp ggml.o llama.o utils.o diff --git a/chat_tcp_client.sh b/chat_tcp_client.sh new file mode 100755 index 0000000000000..f154ae57dc4a6 --- /dev/null +++ b/chat_tcp_client.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash + +PORT=${PORT:-8080} +PROMPT="${PROMPT:-"Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision. + +User:Hello, Bob. +Bob:Hello. How may I help you today? +User:Please tell me the largest city in Europe. +Bob:Sure. The largest city in Europe is Moscow, the capital of Russia. +User:"}" +RPROMPT="${RPROMPT:-"User:"}" +N_PREDICT="${N_PREDICT:-"4096"}" +REPEAT_PENALTY="${REPEAT_PENALTY:-"1.0"}" +N_THREADS="${N_THREADS:-"4"}" + +# Open connection to the chat server +exec 3<>/dev/tcp/127.0.0.1/${PORT} + +# Pass the arguments. The protocol is really simple: +# 1. Pass the number of arguments followed by a linefeed +# 2. Pass the arguments, with each being followed by "0" +( +echo -en "12\n" +echo -en "-t\x00" +echo -en "$N_THREADS\x00" +echo -en "-n\x00" +echo -en "$N_PREDICT\x00" +echo -en "--repeat_penalty\x00" +echo -en "$REPEAT_PENALTY\x00" +echo -en "--color\x00" +echo -en "-i\x00" +echo -en "-r\x00" +echo -en "$RPROMPT\x00" +echo -en "-p\x00" +echo -en "$PROMPT\x00" +) >&3 + +trap exit TERM + +# When we have passed the arguments, start printing socket data to the screen. +# This is done in a background job because we also want to send data when +# running in interactive mode. 
+cat <&3 && echo "(disconnected, press \"enter\" twice to exit)" & +cat >&3 +wait diff --git a/chat_tcp_server.sh b/chat_tcp_server.sh new file mode 100755 index 0000000000000..79320906d7b0b --- /dev/null +++ b/chat_tcp_server.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +PORT=${PORT:-8080} +MODEL=${MODEL:-models/7B/ggml-model-q4_0.bin} + +./main -l ${PORT} -m $MODEL diff --git a/main.cpp b/main.cpp index 0044025e95b49..975714f9382f7 100644 --- a/main.cpp +++ b/main.cpp @@ -1,5 +1,6 @@ #include "run.h" #include "ggml.h" +#include "tcp_server.h" #include @@ -125,5 +126,11 @@ int main(int argc, char ** argv) { exit(0); } +#ifndef _WIN32 + if (params.listen_port != "") { + return listen_tcp(ctx, params); + } +#endif + return run(ctx, params, std::cin, stdout, stderr); } diff --git a/tcp_server.cpp b/tcp_server.cpp new file mode 100644 index 0000000000000..9077c1807de1a --- /dev/null +++ b/tcp_server.cpp @@ -0,0 +1,245 @@ +#include "tcp_server.h" +#include "llama.h" +#include "utils.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +class PosixStream : public std::istream { + public: + PosixStream(int fd) : std::istream(&buf), buf(fd) {} + ~PosixStream() { close(buf.get_fd()); } + + private: + class PosixStreamBuf : public std::streambuf { + public: + PosixStreamBuf(int fd) : fd(fd) {} + int get_fd() const { return fd; } + + protected: + virtual int_type underflow() { + if (gptr() < egptr()) { + return traits_type::to_int_type(*gptr()); + } + + ssize_t num_read = ::read(fd, buffer, BUFFER_SIZE); + if (num_read <= 0) { + return traits_type::eof(); + } + + setg(buffer, buffer, buffer + num_read); + return traits_type::to_int_type(*gptr()); + } + + private: + static const int BUFFER_SIZE = 1024; + int fd; + char buffer[BUFFER_SIZE]; + }; + + PosixStreamBuf buf; +}; + +void die(const char *msg, ...) 
+{
+    va_list ap;
+
+    va_start(ap, msg);
+    vfprintf(stderr, msg, ap);
+    va_end(ap);
+    fputc('\n', stderr);
+    exit(1);
+}
+
+static char *read_argument(uint8_t **param_buf, size_t *param_buf_size, FILE *instream) {
+    bool done = false;
+    uint8_t *buf = *param_buf;
+    size_t bufsize = *param_buf_size;
+    size_t bufpos = 0;
+    while (!done) {
+        if (bufpos == bufsize) {
+            bufsize += 1024;
+            buf = (uint8_t *)realloc(buf, bufsize);
+            if (!buf) {
+                die("failed to allocate memory");
+            }
+        }
+
+        int c = fgetc(instream);
+        if (c == EOF) {
+            die("unexpected EOF on client socket");
+        }
+        buf[bufpos++] = (uint8_t)c;
+        if (c == 0) {
+            // done reading argument
+            break;
+        }
+    }
+    *param_buf = buf;
+    *param_buf_size = bufsize;
+    return strdup((char *)buf);
+}
+
+static int read_arguments(int argc, char **argv, FILE *instream) {
+    int i = 1;
+    size_t param_buf_size = 0;
+    uint8_t *param_buf = nullptr;
+
+    for (i = 1; i < argc; i++) {
+        argv[i] = read_argument(&param_buf, &param_buf_size, instream);
+    }
+
+    free(param_buf);
+    return i;
+}
+
+static int serve_model(llama_context * ctx,
+                       gpt_params params,
+                       int sock_fd)
+{
+    int argc;
+    char **argv;
+    FILE *instream = fdopen(sock_fd, "r");
+    FILE *outstream = fdopen(sock_fd, "w");
+    setvbuf(instream, NULL, _IONBF, 0);
+
+    // start by reading the parameter count
+    if (fscanf(instream, "%d\n", &argc) != 1) {
+        fprintf(outstream, "Error: First line must be the argument count\n");
+        fflush(outstream);
+        return 1;
+    }
+
+    argc += 1; // add one extra argument to emulate the program command line
+    argv = (char **)malloc(argc * sizeof *argv);
+    argv[0] = nullptr;
+    if (read_arguments(argc, argv, instream) != argc) {
+        fprintf(outstream, "Error: Failed to read arguments\n");
+        fflush(outstream);
+    }
+
+    if (gpt_params_parse(argc, argv, params) == false) {
+        fprintf(outstream, "Error: Failed to parse parameters\n");
+        fflush(outstream);
+        return 1;
+    }
+
+    for (int i = 1; i < argc; i++) {
+        free(argv[i]);
+    }
+    free(argv);
+
+    PosixStream tcp_instream(sock_fd);
+
+    return run(ctx, params, tcp_instream, outstream, outstream);
+}
+
+int listen_tcp(llama_context * ctx, gpt_params params) {
+    int listen_fd;
+    int status;
+    pid_t child;
+    struct addrinfo hints;
+    struct addrinfo *servinfo, *p;
+    int yes = 1;
+
+    memset(&hints, 0, sizeof hints);
+    hints.ai_family = AF_INET;
+    hints.ai_socktype = SOCK_STREAM;
+    hints.ai_flags = AI_PASSIVE;
+
+    // This should only ever listen on a loopback address. Access from outside
+    // should be proxied via socat or similar software
+    status = getaddrinfo("127.0.0.1", params.listen_port.c_str(), &hints, &servinfo);
+    if (status) {
+        die("getaddrinfo error: %s", gai_strerror(status));
+    }
+
+    // bind to the first addrinfo we can from the getaddrinfo results
+    for (p = servinfo; p != NULL; p = p->ai_next) {
+        listen_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
+        if (listen_fd == -1) {
+            perror("server: socket");
+            continue;
+        }
+
+        if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &yes, sizeof yes)) {
+            die("setsockopt error on port %s: %s", params.listen_port.c_str(), strerror(errno));
+        }
+
+        if (bind(listen_fd, p->ai_addr, p->ai_addrlen) == 0) {
+            struct sockaddr_in addr_in;
+            socklen_t addr_in_len = sizeof(addr_in);
+            memset(&addr_in, 0, addr_in_len);
+            getsockname(listen_fd, (struct sockaddr*)&addr_in, &addr_in_len);
+
+            printf("Listening on %s:%d\n", inet_ntoa(addr_in.sin_addr), ntohs(addr_in.sin_port));
+            break;
+        }
+
+        close(listen_fd);
+        perror("server: bind");
+    }
+
+    freeaddrinfo(servinfo);
+
+    if (p == NULL) {
+        die("failed to bind: %s", strerror(errno));
+    }
+
+    if (listen(listen_fd, 20)) {
+        die("listen error: %s", strerror(errno));
+    }
+    // Don't track child processes, so ignore SIGCHLD to prevent zombies
+    signal(SIGCHLD, SIG_IGN);
+
+    for (;;) {
+        struct sockaddr_in client_addr;
+        socklen_t client_addr_len = 0;
+        memset(&client_addr, 0, sizeof(client_addr));
+
+        int sock_fd = accept(listen_fd,
+                             (struct sockaddr *)&client_addr,
+                             &client_addr_len);
+        if (sock_fd < 0) {
+            fprintf(stderr, "accept error: %s\n", strerror(errno));
+            break;
+        }
+
+        child = fork();
+        if (child == 0) {
+            // close the listen_fd since we won't use it in the child
+            close(listen_fd);
+            int ret = serve_model(ctx, params, sock_fd);
+            close(sock_fd);
+            return ret;
+        } else {
+            // close the client since we won't use it in the server
+            close(sock_fd);
+            sock_fd = 0;
+        }
+    }
+    close(listen_fd);
+
+    // ignore SIGTERM since we'll send it to the group
+    signal(SIGTERM, SIG_IGN);
+    // tell children to exit
+    kill(0, SIGTERM);
+    // wait for children to terminate
+    wait(&status);
+    return 0;
+}
diff --git a/tcp_server.h b/tcp_server.h
new file mode 100644
index 0000000000000..38d6ecc810026
--- /dev/null
+++ b/tcp_server.h
@@ -0,0 +1,7 @@
+#pragma once
+
+#include "utils.h"
+#include "llama.h"
+#include "run.h"
+
+int listen_tcp(llama_context * ctx, gpt_params params);
diff --git a/utils.cpp b/utils.cpp
index 1d5309c3a4ca3..78baf924c4b87 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -77,6 +77,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.ignore_eos = true;
         } else if (arg == "--n_parts") {
             params.n_parts = std::stoi(argv[++i]);
+#ifndef _WIN32
+        } else if (arg == "-l" || arg == "--listen") {
+            params.listen_port = argv[++i];
+#endif
         } else if (arg == "-h" || arg == "--help") {
             gpt_print_usage(argc, argv, params);
             exit(0);
@@ -125,6 +129,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --perplexity          compute perplexity over the prompt\n");
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
+#ifndef _WIN32
+    fprintf(stderr, "  -l PORT, --listen PORT\n");
+    fprintf(stderr, "                        Run in TCP mode, listening on PORT\n");
+#endif
     fprintf(stderr, "\n");
 }
 
diff --git a/utils.h b/utils.h
index b0de556c95370..487892b1258c2 100644
--- a/utils.h
+++ b/utils.h
@@ -42,6 +42,10 @@ struct gpt_params {
     bool instruct = false; // instruction mode (used for Alpaca models)
     bool ignore_eos = false; // do not stop generating after eos
     bool perplexity = false; // compute perplexity over the prompt
+
+#ifndef _WIN32
+    std::string listen_port = ""; // TCP port for when running in server mode
+#endif
 };
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
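
The argument protocol described in the PATCH 5/5 commit message can also be exercised without the full
chat_tcp_client.sh script. A minimal one-shot sketch using the same bash /dev/tcp mechanism; the port,
prompt text, and choice of arguments are illustrative, and it assumes a server already started via
chat_tcp_server.sh on the default port 8080:

    #!/usr/bin/env bash
    exec 3<>/dev/tcp/127.0.0.1/8080

    # "2\n" -> number of arguments, terminated by a linefeed
    # "-p\0" and "Once upon a time\0" -> each argument terminated by a 0 byte
    printf '2\n-p\0Once upon a time\0' >&3

    # read the generated text until the server child closes the connection
    cat <&3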
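
Since listen_tcp() binds to 127.0.0.1 only, remote access has to go through a proxy, as the comment in
tcp_server.cpp suggests. A minimal socat sketch (assuming socat is installed; the external port 9090 is
arbitrary):

    # accept clients on port 9090 (all interfaces) and forward them to the local llama server;
    # ",fork" serves each client in its own child process, matching listen_tcp()'s fork-per-connection model
    socat TCP-LISTEN:9090,fork,reuseaddr TCP:127.0.0.1:8080

chat_tcp_client.sh would then need its hard-coded 127.0.0.1 pointed at the proxy host instead.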