Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

main.cpp fixes, refactoring #571

Merged
merged 4 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 60 additions & 7 deletions examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,20 @@
#include <iterator>
#include <algorithm>

#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
#include <alloca.h>
#endif
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <malloc.h> // using malloc.h with MSC/MINGW
#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
#include <alloca.h>
#endif

#if defined (_WIN32)
#pragma comment(lib,"kernel32.lib")
extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle);
extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode);
extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode);
extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID);
extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID);
#endif

bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
// determine sensible default number of threads.
Expand Down Expand Up @@ -204,7 +213,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --in-prefix STRING string to prefix user inputs with (default: empty)\n");
fprintf(stderr, " -f FNAME, --file FNAME\n");
fprintf(stderr, " prompt file to start generation.\n");
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 - infinity)\n", params.n_predict);
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p);
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
Expand All @@ -216,7 +225,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n");
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " --perplexity compute perplexity over the prompt\n");
fprintf(stderr, " --keep number of tokens to keep from the initial prompt\n");
fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
if (ggml_mlock_supported()) {
fprintf(stderr, " --mlock force system to keep model in RAM rather than swapping or compressing\n");
}
Expand Down Expand Up @@ -256,3 +265,47 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s

return res;
}

/* Keep track of current color of output, and emit ANSI code if it changes. */
void set_console_color(console_state & con_st, console_color_t color) {
if (con_st.use_color && con_st.color != color) {
switch(color) {
case CONSOLE_COLOR_DEFAULT:
printf(ANSI_COLOR_RESET);
break;
case CONSOLE_COLOR_PROMPT:
printf(ANSI_COLOR_YELLOW);
break;
case CONSOLE_COLOR_USER_INPUT:
printf(ANSI_BOLD ANSI_COLOR_GREEN);
break;
}
con_st.color = color;
}
}

#if defined (_WIN32)
void win32_console_init(bool enable_color) {
unsigned long dwMode = 0;
void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11)
if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) {
hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12)
if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) {
hConOut = 0;
}
}
if (hConOut) {
// Enable ANSI colors on Windows 10+
if (enable_color && !(dwMode & 0x4)) {
SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
}
// Set console output codepage to UTF8
SetConsoleOutputCP(65001); // CP_UTF8
}
void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10)
if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) {
// Set console input codepage to UTF8
SetConsoleCP(65001); // CP_UTF8
}
}
#endif
30 changes: 30 additions & 0 deletions examples/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,33 @@ std::string gpt_random_prompt(std::mt19937 & rng);
//

std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos);

//
// Console utils
//

#define ANSI_COLOR_RED "\x1b[31m"
#define ANSI_COLOR_GREEN "\x1b[32m"
#define ANSI_COLOR_YELLOW "\x1b[33m"
#define ANSI_COLOR_BLUE "\x1b[34m"
#define ANSI_COLOR_MAGENTA "\x1b[35m"
#define ANSI_COLOR_CYAN "\x1b[36m"
#define ANSI_COLOR_RESET "\x1b[0m"
#define ANSI_BOLD "\x1b[1m"

enum console_color_t {
CONSOLE_COLOR_DEFAULT=0,
CONSOLE_COLOR_PROMPT,
CONSOLE_COLOR_USER_INPUT
};

struct console_state {
bool use_color = false;
console_color_t color = CONSOLE_COLOR_DEFAULT;
};

void set_console_color(console_state & con_st, console_color_t color);

#if defined (_WIN32)
void win32_console_init(bool enable_color);
#endif
164 changes: 53 additions & 111 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,58 +18,13 @@
#include <signal.h>
#endif

#if defined (_WIN32)
#pragma comment(lib,"kernel32.lib")
extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle);
extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode);
extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode);
extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID);
extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID);
#endif

#define ANSI_COLOR_RED "\x1b[31m"
#define ANSI_COLOR_GREEN "\x1b[32m"
#define ANSI_COLOR_YELLOW "\x1b[33m"
#define ANSI_COLOR_BLUE "\x1b[34m"
#define ANSI_COLOR_MAGENTA "\x1b[35m"
#define ANSI_COLOR_CYAN "\x1b[36m"
#define ANSI_COLOR_RESET "\x1b[0m"
#define ANSI_BOLD "\x1b[1m"

/* Keep track of current color of output, and emit ANSI code if it changes. */
enum console_state {
CONSOLE_STATE_DEFAULT=0,
CONSOLE_STATE_PROMPT,
CONSOLE_STATE_USER_INPUT
};

static console_state con_st = CONSOLE_STATE_DEFAULT;
static bool con_use_color = false;

void set_console_state(console_state new_st) {
if (!con_use_color) return;
// only emit color code if state changed
if (new_st != con_st) {
con_st = new_st;
switch(con_st) {
case CONSOLE_STATE_DEFAULT:
printf(ANSI_COLOR_RESET);
return;
case CONSOLE_STATE_PROMPT:
printf(ANSI_COLOR_YELLOW);
return;
case CONSOLE_STATE_USER_INPUT:
printf(ANSI_BOLD ANSI_COLOR_GREEN);
return;
}
}
}
static console_state con_st;

static bool is_interacting = false;

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
void sigint_handler(int signo) {
set_console_state(CONSOLE_STATE_DEFAULT);
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
printf("\n"); // this also force flush stdout.
if (signo == SIGINT) {
if (!is_interacting) {
Expand All @@ -81,32 +36,6 @@ void sigint_handler(int signo) {
}
#endif

#if defined (_WIN32)
void win32_console_init(void) {
unsigned long dwMode = 0;
void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11)
if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) {
hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12)
if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) {
hConOut = 0;
}
}
if (hConOut) {
// Enable ANSI colors on Windows 10+
if (con_use_color && !(dwMode & 0x4)) {
SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
}
// Set console output codepage to UTF8
SetConsoleOutputCP(65001); // CP_UTF8
}
void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10)
if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) {
// Set console input codepage to UTF8
SetConsoleCP(65001); // CP_UTF8
}
}
#endif

int main(int argc, char ** argv) {
gpt_params params;
params.model = "models/llama-7B/ggml-model.bin";
Expand All @@ -115,13 +44,12 @@ int main(int argc, char ** argv) {
return 1;
}


// save choice to use color for later
// (note for later: this is a slightly awkward choice)
con_use_color = params.use_color;
con_st.use_color = params.use_color;

#if defined (_WIN32)
win32_console_init();
win32_console_init(params.use_color);
#endif

if (params.perplexity) {
Expand Down Expand Up @@ -218,24 +146,23 @@ int main(int argc, char ** argv) {
return 1;
}

params.n_keep = std::min(params.n_keep, (int) embd_inp.size());
// number of tokens to keep when resetting context
if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) {
params.n_keep = (int)embd_inp.size();
}

// prefix & suffix for instruct mode
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true);
const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);

// in instruct mode, we inject a prefix and a suffix to each input by the user
if (params.instruct) {
params.interactive = true;
params.interactive_start = true;
params.antiprompt.push_back("### Instruction:\n\n");
}

// enable interactive mode if reverse prompt is specified
if (params.antiprompt.size() != 0) {
params.interactive = true;
}

if (params.interactive_start) {
// enable interactive mode if reverse prompt or interactive start is specified
if (params.antiprompt.size() != 0 || params.interactive_start) {
params.interactive = true;
}

Expand Down Expand Up @@ -297,17 +224,18 @@ int main(int argc, char ** argv) {
#endif
" - Press Return to return control to LLaMa.\n"
" - If you want to submit another line, end your input in '\\'.\n\n");
is_interacting = params.interactive_start || params.instruct;
is_interacting = params.interactive_start;
}

bool input_noecho = false;
bool is_antiprompt = false;
bool input_noecho = false;

int n_past = 0;
int n_remain = params.n_predict;
int n_consumed = 0;

// the first thing we will do is to output the prompt, so set color accordingly
set_console_state(CONSOLE_STATE_PROMPT);
set_console_color(con_st, CONSOLE_COLOR_PROMPT);

std::vector<llama_token> embd;

Expand Down Expand Up @@ -408,36 +336,38 @@ int main(int argc, char ** argv) {
}
// reset color to default if we there is no pending user input
if (!input_noecho && (int)embd_inp.size() == n_consumed) {
set_console_state(CONSOLE_STATE_DEFAULT);
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
}

// in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more
if (params.interactive && (int) embd_inp.size() <= n_consumed) {

// check for reverse prompt
std::string last_output;
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}
if (params.antiprompt.size()) {
std::string last_output;
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}

// Check if each of the reverse prompts appears at the end of the output.
for (std::string & antiprompt : params.antiprompt) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
is_interacting = true;
set_console_state(CONSOLE_STATE_USER_INPUT);
fflush(stdout);
break;
is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output.
for (std::string & antiprompt : params.antiprompt) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
is_interacting = true;
is_antiprompt = true;
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
fflush(stdout);
break;
}
}
}

if (n_past > 0 && is_interacting) {
// potentially set color to indicate we are taking user input
set_console_state(CONSOLE_STATE_USER_INPUT);
set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);

if (params.instruct) {
n_consumed = embd_inp.size();
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());

printf("\n> ");
}

Expand All @@ -463,16 +393,28 @@ int main(int argc, char ** argv) {
} while (another_line);

// done taking input, reset color
set_console_state(CONSOLE_STATE_DEFAULT);
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);

auto line_inp = ::llama_tokenize(ctx, buffer, false);
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
// Add tokens to embd only if the input buffer is non-empty
// Entering a empty line lets the user pass control back
if (buffer.length() > 1) {

if (params.instruct) {
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
}
// instruct mode: insert instruction prefix
if (params.instruct && !is_antiprompt) {
n_consumed = embd_inp.size();
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
}

n_remain -= line_inp.size();
auto line_inp = ::llama_tokenize(ctx, buffer, false);
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());

// instruct mode: insert response suffix
if (params.instruct) {
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
}

n_remain -= line_inp.size();
}

input_noecho = true; // do not echo this again
}
Expand Down Expand Up @@ -506,7 +448,7 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx);
llama_free(ctx);

set_console_state(CONSOLE_STATE_DEFAULT);
set_console_color(con_st, CONSOLE_COLOR_DEFAULT);

return 0;
}