Skip to content

Commit

Permalink
Show progress bar for prompt processing in web ui
Browse files Browse the repository at this point in the history
  • Loading branch information
jart committed Dec 13, 2024
1 parent fe514ef commit 4158265
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 45 deletions.
2 changes: 1 addition & 1 deletion llamafile/flags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ float FLAG_frequency_penalty = 0;
float FLAG_presence_penalty = 0;
float FLAG_temperature = .8;
float FLAG_top_p = .95;
int FLAG_batch = 2048;
int FLAG_batch = 256;
int FLAG_ctx_size = 8192;
int FLAG_flash_attn = false;
int FLAG_gpu = 0;
Expand Down
56 changes: 47 additions & 9 deletions llamafile/server/slot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ Slot::eval_token(int token)
}

int
Slot::eval_tokens(const std::vector<int>& tokens)
Slot::eval_tokens(const std::vector<int>& tokens,
const ProgressCallback& progress)
{
if (!ctx_)
return uninitialized;
Expand All @@ -177,7 +178,8 @@ Slot::eval_tokens(const std::vector<int>& tokens)
int used = ctx_used();
if (used + N > ctx_size())
return out_of_context;
std::vector<int> toks(tokens); // TODO(jart): is copying really needed?
std::vector<int> toks(tokens);
int processed = 0;
for (int i = 0; i < N; i += FLAG_batch) {
int n_eval = N - i;
if (n_eval > FLAG_batch)
Expand All @@ -191,12 +193,16 @@ Slot::eval_tokens(const std::vector<int>& tokens)
for (int j = 0; j < n_eval; ++j)
history_.emplace_back(toks[i + j]);
used += n_eval;
processed += n_eval;
if (progress)
progress(processed, N);
}
return N;
}

int
Slot::eval_image(const std::string_view& bytes)
Slot::eval_image(const std::string_view& bytes,
const ProgressCallback& progress)
{
if (!ctx_)
return uninitialized;
Expand All @@ -215,6 +221,7 @@ Slot::eval_image(const std::string_view& bytes)
llava_image_embed_free(image_embed);
return out_of_context;
}
int processed = 0;
int n_embd = llama_n_embd(llama_get_model(ctx_));
for (int i = 0; i < N; i += FLAG_batch) {
int n_eval = N - i;
Expand All @@ -229,39 +236,70 @@ Slot::eval_image(const std::string_view& bytes)
return decode_image_failed;
}
used += n_eval;
processed += n_eval;
if (progress)
progress(processed, N);
}
llava_image_embed_free(image_embed);
history_.emplace_back(new Image(bytes, N));
return N;
}

int
Slot::eval_atoms(const std::vector<Atom>& atoms)
Slot::eval_atoms(const std::vector<Atom>& atoms,
const ProgressCallback& progress)
{
int total_work = 0;
if (progress) {
for (const Atom& atom : atoms) {
if (atom.is_token()) {
total_work += 1;
} else if (atom.is_image()) {
llava_image_embed* image_embed = llava_image_embed_make_with_bytes(
clip_ctx_,
FLAG_threads_batch,
(const unsigned char*)atom.image().bytes().data(),
atom.image().bytes().size());
if (image_embed) {
total_work += image_embed->n_image_pos;
llava_image_embed_free(image_embed);
}
}
}
if (total_work > FLAG_batch)
progress(0, total_work);
}
int processed = 0;
auto wrap_progress = [&](int curr, int subtotal) {
if (progress)
progress(processed + curr, total_work);
};
int rc;
int token_count = 0;
std::vector<int> tokens;
for (const Atom& atom : atoms) {
if (atom.is_token()) {
tokens.emplace_back(atom.token());
} else if (atom.is_image()) {
if ((rc = eval_tokens(tokens)) < 0)
if ((rc = eval_tokens(tokens, wrap_progress)) < 0)
return rc;
token_count += rc;
processed += rc;
tokens.clear();
if ((rc = eval_image(atom.image().bytes())) < 0)
if ((rc = eval_image(atom.image().bytes(), wrap_progress)) < 0)
return rc;
token_count += rc;
processed += rc;
}
}
if ((rc = eval_tokens(tokens)) < 0)
if ((rc = eval_tokens(tokens, wrap_progress)) < 0)
return rc;
token_count += rc;
return token_count;
}

int
Slot::prefill(const std::vector<Atom>& atoms_)
Slot::prefill(const std::vector<Atom>& atoms_, const ProgressCallback& progress)
{
if (!ctx_)
return uninitialized;
Expand Down Expand Up @@ -295,7 +333,7 @@ Slot::prefill(const std::vector<Atom>& atoms_)
}
std::vector<Atom> new_atoms(atoms.begin() + reuse_atoms, atoms.end());
int rc;
if ((rc = eval_atoms(new_atoms)) < 0)
if ((rc = eval_atoms(new_atoms, progress)) < 0)
return rc;
int token_count = reuse_tokens + rc;
SLOG("prefilled %zu tokens (after removing %zu and reusing %zu)",
Expand Down
11 changes: 7 additions & 4 deletions llamafile/server/slot.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#pragma once
#include <cosmo.h>
#include <functional>
#include <string>
#include <vector>

Expand All @@ -29,6 +30,8 @@ struct clip_ctx;
namespace lf {
namespace server {

using ProgressCallback = std::function<void(int processed, int total)>;

struct Atom;
struct Image;

Expand Down Expand Up @@ -59,10 +62,10 @@ struct Slot
int ctx_used() const;
bool start();
int eval_token(int);
int eval_image(const std::string_view&);
int eval_tokens(const std::vector<int>&);
int eval_atoms(const std::vector<Atom>&);
int prefill(const std::vector<Atom>&);
int eval_tokens(const std::vector<int>&, const ProgressCallback& = nullptr);
int eval_image(const std::string_view&, const ProgressCallback& = nullptr);
int eval_atoms(const std::vector<Atom>&, const ProgressCallback& = nullptr);
int prefill(const std::vector<Atom>&, const ProgressCallback& = nullptr);
void tokenize(std::vector<Atom>*, std::string_view, bool);
void dump(std::string*);
};
Expand Down
38 changes: 30 additions & 8 deletions llamafile/server/v1_chat_completions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,17 +475,11 @@ Client::v1_chat_completions()
return send_error(500, "failed to create sampler");
defer_cleanup(cleanup_sampler, sampler);

// prefill time
int prompt_tokens = 0;
if ((prompt_tokens = slot_->prefill(state->atoms)) < 0) {
SLOG("slot prefill failed: %s", Slot::describe_error(prompt_tokens));
return send_error(500, Slot::describe_error(prompt_tokens));
}

// setup response json
response->json["id"] = generate_id();
response->json["object"] = "chat.completion";
response->json["model"] = params->model;
response->json["x_prefill_progress"] = nullptr;
response->json["system_fingerprint"] = slot_->system_fingerprint_;
Json& choice = response->json["choices"][0];
choice["index"] = 0;
Expand All @@ -500,7 +494,35 @@ Client::v1_chat_completions()
return false;
choice["delta"]["role"] = "assistant";
choice["delta"]["content"] = "";
response->json["created"] = timespec_real().tv_sec;
}

// prefill time
int prompt_tokens = 0;
if (params->stream) {
auto progress_callback = [&](int processed, int total) {
if (processed < total) {
response->json["x_prefill_progress"] =
static_cast<float>(processed) / total;
response->json["created"] = timespec_real().tv_sec;
response->content = make_event(response->json);
if (!send_response_chunk(response->content)) {
return; // Note: Can't properly handle error in callback
}
}
};
prompt_tokens = slot_->prefill(state->atoms, progress_callback);
} else {
prompt_tokens = slot_->prefill(state->atoms);
}

if (prompt_tokens < 0) {
SLOG("slot prefill failed: %s", Slot::describe_error(prompt_tokens));
return send_error(500, Slot::describe_error(prompt_tokens));
}

// initialize response
if (params->stream) {
response->json.getObject().erase("x_prefill_progress");
response->content = make_event(response->json);
choice.getObject().erase("delta");
if (!send_response_chunk(response->content))
Expand Down
46 changes: 46 additions & 0 deletions llamafile/server/www/chatbot.css
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ p {
flex: 1;
overflow-y: auto;
padding: 1rem;
position: relative;
}

ol,
Expand Down Expand Up @@ -223,6 +224,7 @@ ul li:first-child {
border-radius: 4px;
overflow-x: auto;
position: relative;
white-space: pre-wrap;
}

.message blockquote {
Expand Down Expand Up @@ -629,3 +631,47 @@ ul li:first-child {
border: 1px solid #999;
}
}

.prefill-progress {
position: absolute;
bottom: 0;
left: 0;
right: 0;
height: 4px;
background: #eee;
}

.prefill-progress .progress-bar {
height: 100%;
background: #0d6efd;
width: 0;
transition: width 0.2s ease-out;
}

#prefill-status {
position: sticky;
bottom: 0;
left: 0;
right: 0;
padding: 8px;
background: rgba(255, 255, 255, 0.9);
backdrop-filter: blur(4px);
display: none;
align-items: center;
gap: 10px;
}

.prefill-progress {
flex: 1;
height: 4px;
background: #eee;
border-radius: 2px;
overflow: hidden;
}

.prefill-progress .progress-bar {
height: 100%;
background: #0d6efd;
width: 0;
transition: width 0.2s ease-out;
}
Loading

0 comments on commit 4158265

Please sign in to comment.