@@ -36,6 +36,36 @@ extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHand
36
36
#define ANSI_COLOR_RESET " \x1b [0m"
37
37
#define ANSI_BOLD " \x1b [1m"
38
38
39
+ /* Keep track of current color of output, and emit ANSI code if it changes. */
40
+ enum console_state {
41
+ CONSOLE_STATE_DEFAULT=0 ,
42
+ CONSOLE_STATE_PROMPT,
43
+ CONSOLE_STATE_USER_INPUT
44
+ };
45
+
46
+ static console_state con_st = CONSOLE_STATE_DEFAULT;
47
+ static bool con_use_color = false ;
48
+
49
+ void set_console_state (console_state new_st)
50
+ {
51
+ if (!con_use_color) return ;
52
+ // only emit color code if state changed
53
+ if (new_st != con_st) {
54
+ con_st = new_st;
55
+ switch (con_st) {
56
+ case CONSOLE_STATE_DEFAULT:
57
+ printf (ANSI_COLOR_RESET);
58
+ return ;
59
+ case CONSOLE_STATE_PROMPT:
60
+ printf (ANSI_COLOR_YELLOW);
61
+ return ;
62
+ case CONSOLE_STATE_USER_INPUT:
63
+ printf (ANSI_BOLD ANSI_COLOR_GREEN);
64
+ return ;
65
+ }
66
+ }
67
+ }
68
+
39
69
static const int EOS_TOKEN_ID = 2 ;
40
70
41
71
// determine number of model parts based on the dimension
@@ -866,7 +896,7 @@ static bool is_interacting = false;
866
896
867
897
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
868
898
void sigint_handler (int signo) {
869
- printf (ANSI_COLOR_RESET );
899
+ set_console_state (CONSOLE_STATE_DEFAULT );
870
900
printf (" \n " ); // this also force flush stdout.
871
901
if (signo == SIGINT) {
872
902
if (!is_interacting) {
@@ -925,6 +955,10 @@ int main(int argc, char ** argv) {
925
955
params.prompt = gpt_random_prompt (rng);
926
956
}
927
957
958
+ // save choice to use color for later
959
+ // (note for later: this is a slightly awkward choice)
960
+ con_use_color = params.use_color ;
961
+
928
962
// params.prompt = R"(// this function checks if the number n is prime
929
963
// bool is_prime(int n) {)";
930
964
@@ -1040,18 +1074,18 @@ int main(int argc, char ** argv) {
1040
1074
1041
1075
int remaining_tokens = params.n_predict ;
1042
1076
1043
- // set the color for the prompt which will be output initially
1044
- if (params.use_color ) {
1045
1077
#if defined (_WIN32)
1078
+ if (params.use_color ) {
1046
1079
// Enable ANSI colors on Windows 10+
1047
1080
unsigned long dwMode = 0 ;
1048
1081
void * hConOut = GetStdHandle ((unsigned long )-11 ); // STD_OUTPUT_HANDLE (-11)
1049
1082
if (hConOut && hConOut != (void *)-1 && GetConsoleMode (hConOut, &dwMode) && !(dwMode & 0x4 )) {
1050
1083
SetConsoleMode (hConOut, dwMode | 0x4 ); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
1051
1084
}
1052
- #endif
1053
- printf (ANSI_COLOR_YELLOW);
1054
1085
}
1086
+ #endif
1087
+ // the first thing we will do is to output the prompt, so set color accordingly
1088
+ set_console_state (CONSOLE_STATE_PROMPT);
1055
1089
1056
1090
while (remaining_tokens > 0 || params.interactive ) {
1057
1091
// predict
@@ -1125,8 +1159,8 @@ int main(int argc, char ** argv) {
1125
1159
fflush (stdout);
1126
1160
}
1127
1161
// reset color to default if we there is no pending user input
1128
- if (!input_noecho && params. use_color && (int )embd_inp.size () == input_consumed) {
1129
- printf (ANSI_COLOR_RESET );
1162
+ if (!input_noecho && (int )embd_inp.size () == input_consumed) {
1163
+ set_console_state (CONSOLE_STATE_DEFAULT );
1130
1164
}
1131
1165
1132
1166
// in interactive mode, and not currently processing queued inputs;
@@ -1146,15 +1180,16 @@ int main(int argc, char ** argv) {
1146
1180
}
1147
1181
}
1148
1182
if (is_interacting) {
1183
+ // potentially set color to indicate we are taking user input
1184
+ set_console_state (CONSOLE_STATE_USER_INPUT);
1185
+
1149
1186
if (params.instruct ) {
1150
1187
input_consumed = embd_inp.size ();
1151
1188
embd_inp.insert (embd_inp.end (), inp_pfx.begin (), inp_pfx.end ());
1152
1189
1153
1190
printf (" \n > " );
1154
1191
}
1155
1192
1156
- // currently being interactive
1157
- if (params.use_color ) printf (ANSI_BOLD ANSI_COLOR_GREEN);
1158
1193
std::string buffer;
1159
1194
std::string line;
1160
1195
bool another_line = true ;
@@ -1167,7 +1202,9 @@ int main(int argc, char ** argv) {
1167
1202
}
1168
1203
buffer += line + ' \n ' ; // Append the line to the result
1169
1204
} while (another_line);
1170
- if (params.use_color ) printf (ANSI_COLOR_RESET);
1205
+
1206
+ // done taking input, reset color
1207
+ set_console_state (CONSOLE_STATE_DEFAULT);
1171
1208
1172
1209
std::vector<llama_vocab::id> line_inp = ::llama_tokenize (vocab, buffer, false );
1173
1210
embd_inp.insert (embd_inp.end (), line_inp.begin (), line_inp.end ());
@@ -1218,9 +1255,7 @@ int main(int argc, char ** argv) {
1218
1255
1219
1256
ggml_free (model.ctx );
1220
1257
1221
- if (params.use_color ) {
1222
- printf (ANSI_COLOR_RESET);
1223
- }
1258
+ set_console_state (CONSOLE_STATE_DEFAULT);
1224
1259
1225
1260
return 0 ;
1226
1261
}
0 commit comments