Show prompt loading progress in chatbot
jart committed Oct 13, 2024
1 parent 726f6e8 commit 28e98b6
Showing 2 changed files with 57 additions and 21 deletions.
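At a high level, the change prints a dim, overwritable "loading prompt NN%..." status line to stderr between prompt batches and erases it once evaluation finishes. A minimal, self-contained sketch of that terminal trick, with hypothetical names standing in for the diff's print_ephemeral()/clear_ephemeral() and plain ANSI escapes standing in for the BRIGHT_BLACK, UNFOREGROUND, and CLEAR_FORWARD macros (whose definitions are not part of this diff):

#include <stdio.h>

// "\033[90m" = bright black, "\033[39m" = default foreground,
// "\r" returns the cursor to column 0, "\033[K" erases to end of line.
static void show_status(const char *msg) {
    fprintf(stderr, " \033[90m%s\033[39m\r", msg);
}

static void erase_status(void) {
    fprintf(stderr, "\033[K");
}

int main(void) {
    for (int i = 0; i < 4; ++i) {
        char buf[64];
        snprintf(buf, sizeof(buf), "loading prompt %d%%...", i * 25);
        show_status(buf);
        // ... evaluate one batch of prompt tokens here ...
    }
    erase_status();
    return 0;
}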
76 changes: 55 additions & 21 deletions llamafile/chatbot.cpp
@@ -15,7 +15,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "llamafile/highlight.h"
#include <assert.h>
#include <cosmo.h>
#include <ctype.h>
@@ -29,6 +28,7 @@
#include "llama.cpp/common.h"
#include "llama.cpp/llama.h"
#include "llamafile/bestline.h"
#include "llamafile/highlight.h"
#include "llamafile/llamafile.h"

#define BOLD "\e[1m"
@@ -74,6 +74,21 @@ static std::string basename(const std::string_view path) {
}
}

__attribute__((format(printf, 1, 2))) static std::string format(const char *fmt, ...) {
va_list ap, ap2;
va_start(ap, fmt);
va_copy(ap2, ap);
int size = 512;
std::string res(size, '\0');
int need = vsnprintf(res.data(), size, fmt, ap);
res.resize(need + 1, '\0');
if (need + 1 > size)
vsnprintf(res.data(), need + 1, fmt, ap2);
va_end(ap2);
va_end(ap);
return res;
}
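The helper above renders printf-style arguments into a std::string: a first vsnprintf() into a 512-byte buffer reports how many characters the output actually needs, and only if that exceeds the initial size is the string regrown and rendered a second time from the copied va_list. An illustrative call, mirroring how the diff uses it further down for the progress message (the value 42 is made up):

    std::string msg = format("loading prompt %d%%...", 42);  // "loading prompt 42%..."
    print_ephemeral(msg);  // dim status line, defined later in this diff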

static void on_completion(const char *line, bestlineCompletions *comp) {
static const char *const kCompletions[] = {
"/context", //
@@ -134,6 +149,15 @@ static void print_logo(const char16_t *s) {
}
}

static void print_ephemeral(const std::string_view &description) {
fprintf(stderr, " " BRIGHT_BLACK "%.*s" UNFOREGROUND "\r", (int)description.size(),
description.data());
}

static void clear_ephemeral(void) {
fprintf(stderr, CLEAR_FORWARD);
}

static void die_out_of_context(void) {
fprintf(stderr,
"\n" BRIGHT_RED
@@ -145,7 +169,13 @@ static void die_out_of_context(void) {

static void eval_tokens(std::vector<llama_token> tokens, int n_batch) {
int N = (int)tokens.size();
if (n_past + N > llama_n_ctx(g_ctx)) {
n_past += N;
die_out_of_context();
}
for (int i = 0; i < N; i += n_batch) {
if (N > n_batch)
print_ephemeral(format("loading prompt %d%%...", (int)((double)i / N * 100)));
int n_eval = (int)tokens.size() - i;
if (n_eval > n_batch)
n_eval = n_batch;
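The status line only appears when the prompt spans more than one batch, and the percentage is simply the share of tokens already evaluated, (double)i / N * 100. As a worked example, a 1,200-token prompt with the new n_batch default of 512 would print:

    loading prompt 0%...    (i = 0)
    loading prompt 42%...   (i = 512)
    loading prompt 85%...   (i = 1024)

each line overwriting the previous one via the trailing carriage return in print_ephemeral().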
@@ -161,17 +191,8 @@ static void eval_id(int id) {
eval_tokens(tokens, 1);
}

static void eval_string(const char *str, int n_batch, bool add_special, bool parse_special) {
std::string str2 = str;
eval_tokens(llama_tokenize(g_ctx, str2, add_special, parse_special), n_batch);
}

static void print_ephemeral(const char *description) {
fprintf(stderr, " " BRIGHT_BLACK "%s" UNFOREGROUND "\r", description);
}

static void clear_ephemeral(void) {
fprintf(stderr, CLEAR_FORWARD);
static void eval_string(const std::string &str, int n_batch, bool add_special, bool parse_special) {
eval_tokens(llama_tokenize(g_ctx, str, add_special, parse_special), n_batch);
}

int chatbot_main(int argc, char **argv) {
@@ -180,8 +201,12 @@ int chatbot_main(int argc, char **argv) {
log_disable();

gpt_params params;
if (!gpt_params_parse(argc, argv, params))
return 1;
params.n_batch = 512; // for better progress indication
params.sparams.temp = 0; // don't believe in randomness by default
if (!gpt_params_parse(argc, argv, params)) {
fprintf(stderr, "error: failed to parse flags\n");
exit(1);
}
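Both defaults are assigned before gpt_params_parse() runs, so flags given on the command line should still override them. The smaller batch size mainly buys smoother feedback: eval_tokens() updates the status line once per batch, so, as a rough worked example, a 4,096-token prompt gets eight percentage updates at n_batch = 512 but only two if the batch were 2048.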

print_logo(u"\n\
██╗ ██╗ █████╗ ███╗ ███╗ █████╗ ███████╗██╗██╗ ███████╗\n\
@@ -203,27 +228,36 @@
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = llamafile_gpu_layers(35);
g_model = llama_load_model_from_file(params.model.c_str(), model_params);
if (g_model == NULL)
return 2;
if (g_model == NULL) {
clear_ephemeral();
fprintf(stderr, "%s: failed to load model\n", params.model.c_str());
exit(2);
}
if (!params.n_ctx)
params.n_ctx = llama_n_ctx_train(g_model);
if (params.n_ctx < params.n_batch)
params.n_batch = params.n_ctx;
clear_ephemeral();

print_ephemeral("initializing context...");
llama_context_params ctx_params = llama_context_params_from_gpt_params(params);
g_ctx = llama_new_context_with_model(g_model, ctx_params);
if (g_ctx == NULL)
return 3;
if (g_ctx == NULL) {
clear_ephemeral();
fprintf(stderr, "error: failed to initialize context\n");
exit(3);
}
clear_ephemeral();

if (params.prompt.empty())
params.prompt =
"A chat between a curious human and an artificial intelligence assistant. The "
"assistant gives helpful, detailed, and polite answers to the human's questions.";

print_ephemeral("loading prompt...");
bool add_bos = llama_should_add_bos_token(llama_get_model(g_ctx));
std::vector<llama_chat_msg> chat = {{"system", params.prompt}};
std::string msg = llama_chat_apply_template(g_model, params.chat_template, chat, false);
eval_string(msg.c_str(), params.n_batch, add_bos, true);
eval_string(msg, params.n_batch, add_bos, true);
clear_ephemeral();
printf("%s\n", params.special ? msg.c_str() : params.prompt.c_str());

@@ -254,7 +288,7 @@ int chatbot_main(int argc, char **argv) {
}
std::vector<llama_chat_msg> chat = {{"user", line}};
std::string msg = llama_chat_apply_template(g_model, params.chat_template, chat, true);
eval_string(msg.c_str(), params.n_batch, false, true);
eval_string(msg, params.n_batch, false, true);
while (!g_got_sigint) {
llama_token id = llama_sampling_sample(sampler, g_ctx, NULL);
llama_sampling_accept(sampler, g_ctx, id, true);
2 changes: 2 additions & 0 deletions llamafile/highlight_markdown.cpp
@@ -83,6 +83,8 @@ void HighlightMarkdown::feed(std::string *r, std::string_view input) {
if (c == '*') {
t_ = NORMAL;
*r += RESET;
} else {
t_ = STRONG;
}
break;
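For context on this highlight_markdown.cpp hunk: the case handles the character that follows a single '*' seen while inside **strong** text. If that character is another '*', the strong run ends and colors reset; the added else returns the state machine to STRONG instead of leaving it in the intermediate saw-one-star state, so a lone '*' inside bold text no longer derails the tracking. A hedged usage sketch, assuming HighlightMarkdown is default-constructible (only feed() appears in this diff):

    HighlightMarkdown hl;
    std::string out;
    hl.feed(&out, "**a * b**");  // the single '*' after "a " no longer breaks the strong span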

