Skip to content

Commit 89d5d90

Browse files
authored
Fix color codes emitting mid-UTF8 code. (#312)
1 parent 16ffc01 commit 89d5d90

File tree

1 file changed

+48
-13
lines changed

1 file changed

+48
-13
lines changed

main.cpp

+48-13
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,36 @@ extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHand
3636
#define ANSI_COLOR_RESET "\x1b[0m"
3737
#define ANSI_BOLD "\x1b[1m"
3838

39+
/* Keep track of current color of output, and emit ANSI code if it changes. */
40+
enum console_state {
41+
CONSOLE_STATE_DEFAULT=0,
42+
CONSOLE_STATE_PROMPT,
43+
CONSOLE_STATE_USER_INPUT
44+
};
45+
46+
static console_state con_st = CONSOLE_STATE_DEFAULT;
47+
static bool con_use_color = false;
48+
49+
void set_console_state(console_state new_st)
50+
{
51+
if (!con_use_color) return;
52+
// only emit color code if state changed
53+
if (new_st != con_st) {
54+
con_st = new_st;
55+
switch(con_st) {
56+
case CONSOLE_STATE_DEFAULT:
57+
printf(ANSI_COLOR_RESET);
58+
return;
59+
case CONSOLE_STATE_PROMPT:
60+
printf(ANSI_COLOR_YELLOW);
61+
return;
62+
case CONSOLE_STATE_USER_INPUT:
63+
printf(ANSI_BOLD ANSI_COLOR_GREEN);
64+
return;
65+
}
66+
}
67+
}
68+
3969
static const int EOS_TOKEN_ID = 2;
4070

4171
// determine number of model parts based on the dimension
@@ -866,7 +896,7 @@ static bool is_interacting = false;
866896

867897
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
868898
void sigint_handler(int signo) {
869-
printf(ANSI_COLOR_RESET);
899+
set_console_state(CONSOLE_STATE_DEFAULT);
870900
printf("\n"); // this also force flush stdout.
871901
if (signo == SIGINT) {
872902
if (!is_interacting) {
@@ -925,6 +955,10 @@ int main(int argc, char ** argv) {
925955
params.prompt = gpt_random_prompt(rng);
926956
}
927957

958+
// save choice to use color for later
959+
// (note for later: this is a slightly awkward choice)
960+
con_use_color = params.use_color;
961+
928962
// params.prompt = R"(// this function checks if the number n is prime
929963
//bool is_prime(int n) {)";
930964

@@ -1040,18 +1074,18 @@ int main(int argc, char ** argv) {
10401074

10411075
int remaining_tokens = params.n_predict;
10421076

1043-
// set the color for the prompt which will be output initially
1044-
if (params.use_color) {
10451077
#if defined (_WIN32)
1078+
if (params.use_color) {
10461079
// Enable ANSI colors on Windows 10+
10471080
unsigned long dwMode = 0;
10481081
void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11)
10491082
if (hConOut && hConOut != (void*)-1 && GetConsoleMode(hConOut, &dwMode) && !(dwMode & 0x4)) {
10501083
SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
10511084
}
1052-
#endif
1053-
printf(ANSI_COLOR_YELLOW);
10541085
}
1086+
#endif
1087+
// the first thing we will do is to output the prompt, so set color accordingly
1088+
set_console_state(CONSOLE_STATE_PROMPT);
10551089

10561090
while (remaining_tokens > 0 || params.interactive) {
10571091
// predict
@@ -1125,8 +1159,8 @@ int main(int argc, char ** argv) {
11251159
fflush(stdout);
11261160
}
11271161
// reset color to default if we there is no pending user input
1128-
if (!input_noecho && params.use_color && (int)embd_inp.size() == input_consumed) {
1129-
printf(ANSI_COLOR_RESET);
1162+
if (!input_noecho && (int)embd_inp.size() == input_consumed) {
1163+
set_console_state(CONSOLE_STATE_DEFAULT);
11301164
}
11311165

11321166
// in interactive mode, and not currently processing queued inputs;
@@ -1146,15 +1180,16 @@ int main(int argc, char ** argv) {
11461180
}
11471181
}
11481182
if (is_interacting) {
1183+
// potentially set color to indicate we are taking user input
1184+
set_console_state(CONSOLE_STATE_USER_INPUT);
1185+
11491186
if (params.instruct) {
11501187
input_consumed = embd_inp.size();
11511188
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
11521189

11531190
printf("\n> ");
11541191
}
11551192

1156-
// currently being interactive
1157-
if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
11581193
std::string buffer;
11591194
std::string line;
11601195
bool another_line = true;
@@ -1167,7 +1202,9 @@ int main(int argc, char ** argv) {
11671202
}
11681203
buffer += line + '\n'; // Append the line to the result
11691204
} while (another_line);
1170-
if (params.use_color) printf(ANSI_COLOR_RESET);
1205+
1206+
// done taking input, reset color
1207+
set_console_state(CONSOLE_STATE_DEFAULT);
11711208

11721209
std::vector<llama_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
11731210
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
@@ -1218,9 +1255,7 @@ int main(int argc, char ** argv) {
12181255

12191256
ggml_free(model.ctx);
12201257

1221-
if (params.use_color) {
1222-
printf(ANSI_COLOR_RESET);
1223-
}
1258+
set_console_state(CONSOLE_STATE_DEFAULT);
12241259

12251260
return 0;
12261261
}

0 commit comments

Comments
 (0)