Skip to content

add command line mode #977

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 64 additions & 3 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
Expand All @@ -26,13 +27,18 @@
// Console state that is passed to set_console_color().
static console_state con_st;

// Set to true by sigint_handler() on the first SIGINT; a SIGINT that arrives
// while it is already true makes the handler exit the process.
static bool is_interacting = false;
// Set by main() from the result of command(): true when the last input line
// was handled as a "???" command. main() then clears the pending tokens
// before the display step so nothing is echoed for that line.
static bool is_command = false;

// llama context at file scope; both sigint_handler() and command() use it.
llama_context * ctx;

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
void sigint_handler(int signo) {
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
printf("\n"); // this also force flush stdout.
if (signo == SIGINT) {
if (!is_interacting) {
llama_print_timings(ctx);
llama_free(ctx);
is_interacting=true;
} else {
_exit(130);
Expand All @@ -41,6 +47,54 @@ void sigint_handler(int signo) {
}
#endif

// Handles an interactive "???" command line that tweaks generation settings.
//
// Expected form: `??? <option> <value>` (whitespace separated), for example
// `??? temp 0.7`. The `stats` option takes no value: `??? stats`.
// Tokens after the third one are ignored.
//
// Parameters:
//   buffer - the raw line entered by the user.
//   params - parameters updated in place when the option is recognised.
//   n_ctx  - context size, only used for the summary printed at the end.
//
// Returns 0 when `buffer` is not a command (the caller should treat it as
// ordinary input) and 1 when the line was consumed as a command, whether or
// not the command was valid.
int command(std::string buffer, gpt_params &params, const int n_ctx ) {
    // A command needs the "???" marker plus at least one more character.
    if (buffer.length() <= 3 || strncmp(buffer.c_str(), "???", 3) != 0) return 0;
    set_console_color(con_st, CONSOLE_COLOR_DEFAULT);

    // Token 1 is the marker, token 2 the option name, token 3 its value.
    // A missing token leaves the corresponding string empty.
    std::istringstream input(buffer);
    std::string marker, arg, cmd;
    input >> marker >> arg >> cmd;

    if (arg == "stats") {
        // `stats` needs no value, so handle it before the empty-value check.
        llama_print_timings(ctx);
    } else {
        if (cmd == "") {
            printf("Please enter a command value.\n");
            return 1;
        }
        // std::stoi / std::stof throw on malformed or out-of-range text; a typo
        // in an interactive session must not terminate the program.
        try {
            if (arg == "n_predict") {
                params.n_predict = std::stoi(cmd);
            } else if (arg == "top_k") {
                params.top_k = std::stoi(cmd);
            } else if (arg == "ctx_size") {
                // NOTE(review): the summary below prints the `n_ctx` argument,
                // not params.n_ctx, so this change is not reflected there.
                params.n_ctx = std::stoi(cmd);
            } else if (arg == "top_p") {
                params.top_p = std::stof(cmd);
            } else if (arg == "temp") {
                params.temp = std::stof(cmd);
            } else if (arg == "repeat_last_n") {
                params.repeat_last_n = std::stoi(cmd);
            } else if (arg == "repeat_penalty") {
                params.repeat_penalty = std::stof(cmd);
            } else if (arg == "batch_size") {
                // Batch size is capped at 512.
                params.n_batch = std::min(512, std::stoi(cmd));
            } else if (arg == "reverse-prompt") {
                params.antiprompt.push_back(cmd);
            } else if (arg == "keep") {
                params.n_keep = std::stoi(cmd);
            } else {
                printf("Invalid command: %s\nValid options are:\n n_predict, top_k, ctx_size, top_p, temp, repeat_last_n, repeat_penalty, batch_size, reverse-prompt, keep, stats\n", arg.c_str());
                return 1;
            }
        } catch (const std::invalid_argument &) {
            printf("Invalid value for %s: '%s' is not a number.\n", arg.c_str(), cmd.c_str());
            return 1;
        } catch (const std::out_of_range &) {
            printf("Invalid value for %s: '%s' is out of range.\n", arg.c_str(), cmd.c_str());
            return 1;
        }
    }

    // Echo the current settings so the user can confirm the change.
    printf("sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
        params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
    printf("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
    return 1;
}

int main(int argc, char ** argv) {
gpt_params params;
params.model = "models/llama-7B/ggml-model.bin";
Expand Down Expand Up @@ -92,8 +146,6 @@ int main(int argc, char ** argv) {
// params.prompt = R"(// this function checks if the number n is prime
//bool is_prime(int n) {)";

llama_context * ctx;

// load the model
{
auto lparams = llama_context_default_params();
Expand Down Expand Up @@ -336,6 +388,8 @@ int main(int argc, char ** argv) {

// display text
if (!input_noecho) {
// if a command was entered clear the output to stop printing of gibberish.
if (is_command == true) embd.clear();
for (auto id : embd) {
printf("%s", llama_token_to_str(ctx, id));
}
Expand Down Expand Up @@ -419,7 +473,14 @@ int main(int argc, char ** argv) {
// Add tokens to embd only if the input buffer is non-empty
// Entering an empty line lets the user pass control back
if (buffer.length() > 1) {

//check for commands
if (command(buffer, params, n_ctx) == 0) {
// this is not a command, run normally.
is_command = false;
} else {
// this was a command, so we need to stop anything more from printing.
is_command = true;
}
// instruct mode: insert instruction prefix
if (params.instruct && !is_antiprompt) {
n_consumed = embd_inp.size();
Expand Down