diff --git a/examples/main/main.cpp b/examples/main/main.cpp index ba153cb82dcf6..f61a2ea4b21ee 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #include @@ -26,6 +27,9 @@ static console_state con_st; static bool is_interacting = false; +static bool is_command = false; + +llama_context * ctx; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) void sigint_handler(int signo) { @@ -33,6 +37,8 @@ void sigint_handler(int signo) { printf("\n"); // this also force flush stdout. if (signo == SIGINT) { if (!is_interacting) { + llama_print_timings(ctx); + llama_free(ctx); is_interacting=true; } else { _exit(130); @@ -41,6 +47,54 @@ void sigint_handler(int signo) { } #endif +int command(std::string buffer, gpt_params ¶ms, const int n_ctx ) { + // check buffer's first 3 chars equal '???' to enter command mode. + if (buffer.length() <= 3 || strncmp(buffer.c_str(), "???", 3) != 0) return 0; + set_console_color(con_st, CONSOLE_COLOR_DEFAULT); + std::istringstream command(buffer); + int j = 0; std::string test, arg, cmd; + while (command>>test) { + j++; + if ( j == 2 ) arg = test; + if ( j == 3 ) cmd = test; + } + if (cmd == "") { + printf("Please enter a command value.\n"); + return 1; + } + if (arg == "n_predict") { + params.n_predict = std::stoi(cmd); + } else if (arg == "top_k") { + params.top_k = std::stoi(cmd); + } else if (arg == "ctx_size") { + params.n_ctx = std::stoi(cmd); + } else if (arg == "top_p") { + params.top_p = std::stof(cmd); + } else if (arg == "temp") { + params.temp = std::stof(cmd); + } else if (arg == "repeat_last_n") { + params.repeat_last_n = std::stoi(cmd); + } else if (arg == "repeat_penalty") { + params.repeat_penalty = std::stof(cmd); + } else if (arg == "batch_size") { + params.n_batch = std::stoi(cmd); + params.n_batch = std::min(512, params.n_batch); + } else if (arg == "reverse-prompt") { + params.antiprompt.push_back(cmd); + } else if (arg == "keep") { + params.n_keep = std::stoi(cmd); + } else if (arg == "stats") { + llama_print_timings(ctx); + } else { + printf("Invalid command: %s\nValid options are:\n n_predict, top_k, ctx_size, top_p, temp, repeat_last_n, repeat_penalty, batch_size, reverse-prompt, keep, stats\n", arg.c_str()); + return 1; + } + printf("sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", + params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); + printf("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); + return 1; +} + int main(int argc, char ** argv) { gpt_params params; params.model = "models/llama-7B/ggml-model.bin"; @@ -92,8 +146,6 @@ int main(int argc, char ** argv) { // params.prompt = R"(// this function checks if the number n is prime //bool is_prime(int n) {)"; - llama_context * ctx; - // load the model { auto lparams = llama_context_default_params(); @@ -336,6 +388,8 @@ int main(int argc, char ** argv) { // display text if (!input_noecho) { + // if a command was entered clear the output to stop printing of gibberish. + if (is_command == true) embd.clear(); for (auto id : embd) { printf("%s", llama_token_to_str(ctx, id)); } @@ -419,7 +473,14 @@ int main(int argc, char ** argv) { // Add tokens to embd only if the input buffer is non-empty // Entering a empty line lets the user pass control back if (buffer.length() > 1) { - + //check for commands + if (command(buffer, params, n_ctx) == 0) { + // this is not a command, run normally. + is_command = false; + } else { + // this was a command, so we need to stop anything more from printing. + is_command = true; + } // instruct mode: insert instruction prefix if (params.instruct && !is_antiprompt) { n_consumed = embd_inp.size();