From 4158265e6baa627586d043da68d45849ba993604 Mon Sep 17 00:00:00 2001
From: Justine Tunney <jtunney@gmail.com>
Date: Fri, 13 Dec 2024 05:36:28 -0800
Subject: [PATCH] Show progress bar for prompt processing in web ui

---
 llamafile/flags.cpp                      |  2 +-
 llamafile/server/slot.cpp                | 56 ++++++++++++---
 llamafile/server/slot.h                  | 11 +--
 llamafile/server/v1_chat_completions.cpp | 38 +++++++---
 llamafile/server/www/chatbot.css         | 46 ++++++++++++
 llamafile/server/www/chatbot.js          | 92 ++++++++++++++++++------
 llamafile/server/www/index.html          |  5 ++
 7 files changed, 205 insertions(+), 45 deletions(-)

diff --git a/llamafile/flags.cpp b/llamafile/flags.cpp
index bcd1d2fc96..7bd7cee48e 100644
--- a/llamafile/flags.cpp
+++ b/llamafile/flags.cpp
@@ -69,7 +69,7 @@ float FLAG_frequency_penalty = 0;
 float FLAG_presence_penalty = 0;
 float FLAG_temperature = .8;
 float FLAG_top_p = .95;
-int FLAG_batch = 2048;
+int FLAG_batch = 256;
 int FLAG_ctx_size = 8192;
 int FLAG_flash_attn = false;
 int FLAG_gpu = 0;
diff --git a/llamafile/server/slot.cpp b/llamafile/server/slot.cpp
index 9bff95ab03..fee1ec80da 100644
--- a/llamafile/server/slot.cpp
+++ b/llamafile/server/slot.cpp
@@ -167,7 +167,8 @@ Slot::eval_token(int token)
 }
 
 int
-Slot::eval_tokens(const std::vector<int>& tokens)
+Slot::eval_tokens(const std::vector<int>& tokens,
+                  const ProgressCallback& progress)
 {
     if (!ctx_)
         return uninitialized;
@@ -177,7 +178,8 @@
     int used = ctx_used();
     if (used + N > ctx_size())
         return out_of_context;
-    std::vector<int> toks(tokens); // TODO(jart): is copying really needed?
+    std::vector<int> toks(tokens);
+    int processed = 0;
     for (int i = 0; i < N; i += FLAG_batch) {
         int n_eval = N - i;
         if (n_eval > FLAG_batch)
@@ -191,12 +193,16 @@
         for (int j = 0; j < n_eval; ++j)
             history_.emplace_back(toks[i + j]);
         used += n_eval;
+        processed += n_eval;
+        if (progress)
+            progress(processed, N);
     }
     return N;
 }
 
 int
-Slot::eval_image(const std::string_view& bytes)
+Slot::eval_image(const std::string_view& bytes,
+                 const ProgressCallback& progress)
 {
     if (!ctx_)
         return uninitialized;
@@ -215,6 +221,7 @@
         llava_image_embed_free(image_embed);
         return out_of_context;
     }
+    int processed = 0;
    int n_embd = llama_n_embd(llama_get_model(ctx_));
     for (int i = 0; i < N; i += FLAG_batch) {
         int n_eval = N - i;
@@ -229,6 +236,9 @@
             return decode_image_failed;
         }
         used += n_eval;
+        processed += n_eval;
+        if (progress)
+            progress(processed, N);
     }
     llava_image_embed_free(image_embed);
     history_.emplace_back(new Image(bytes, N));
@@ -236,8 +246,34 @@
 }
 
 int
-Slot::eval_atoms(const std::vector<Atom>& atoms)
+Slot::eval_atoms(const std::vector<Atom>& atoms,
+                 const ProgressCallback& progress)
 {
+    int total_work = 0;
+    if (progress) {
+        for (const Atom& atom : atoms) {
+            if (atom.is_token()) {
+                total_work += 1;
+            } else if (atom.is_image()) {
+                llava_image_embed* image_embed = llava_image_embed_make_with_bytes(
+                  clip_ctx_,
+                  FLAG_threads_batch,
+                  (const unsigned char*)atom.image().bytes().data(),
+                  atom.image().bytes().size());
+                if (image_embed) {
+                    total_work += image_embed->n_image_pos;
+                    llava_image_embed_free(image_embed);
+                }
+            }
+        }
+        if (total_work > FLAG_batch)
+            progress(0, total_work);
+    }
+    int processed = 0;
+    auto wrap_progress = [&](int curr, int subtotal) {
+        if (progress)
+            progress(processed + curr, total_work);
+    };
     int rc;
     int token_count = 0;
     std::vector<int> tokens;
@@ -245,23 +281,25 @@ Slot::eval_atoms(const std::vector<Atom>& atoms)
         if (atom.is_token()) {
             tokens.emplace_back(atom.token());
         } else if (atom.is_image()) {
-            if ((rc = eval_tokens(tokens)) < 0)
+            if ((rc = eval_tokens(tokens, wrap_progress)) < 0)
                 return rc;
             token_count += rc;
+            processed += rc;
             tokens.clear();
-            if ((rc = eval_image(atom.image().bytes())) < 0)
+            if ((rc = eval_image(atom.image().bytes(), wrap_progress)) < 0)
                 return rc;
             token_count += rc;
+            processed += rc;
         }
     }
-    if ((rc = eval_tokens(tokens)) < 0)
+    if ((rc = eval_tokens(tokens, wrap_progress)) < 0)
         return rc;
     token_count += rc;
     return token_count;
 }
 
 int
-Slot::prefill(const std::vector<Atom>& atoms_)
+Slot::prefill(const std::vector<Atom>& atoms_, const ProgressCallback& progress)
 {
     if (!ctx_)
         return uninitialized;
@@ -295,7 +333,7 @@ Slot::prefill(const std::vector<Atom>& atoms_)
     }
     std::vector<Atom> new_atoms(atoms.begin() + reuse_atoms, atoms.end());
     int rc;
-    if ((rc = eval_atoms(new_atoms)) < 0)
+    if ((rc = eval_atoms(new_atoms, progress)) < 0)
         return rc;
     int token_count = reuse_tokens + rc;
     SLOG("prefilled %zu tokens (after removing %zu and reusing %zu)",
diff --git a/llamafile/server/slot.h b/llamafile/server/slot.h
index c6d3d17228..9fe26afb8c 100644
--- a/llamafile/server/slot.h
+++ b/llamafile/server/slot.h
@@ -17,6 +17,7 @@
 #pragma once
 
 #include
+#include <functional>
 #include
 #include
@@ -29,6 +30,8 @@ struct clip_ctx;
 
 namespace lf {
 namespace server {
 
+using ProgressCallback = std::function<void(int, int)>;
+
 struct Atom;
 struct Image;
 
@@ -59,10 +62,10 @@ struct Slot
     int ctx_used() const;
     bool start();
     int eval_token(int);
-    int eval_image(const std::string_view&);
-    int eval_tokens(const std::vector<int>&);
-    int eval_atoms(const std::vector<Atom>&);
-    int prefill(const std::vector<Atom>&);
+    int eval_tokens(const std::vector<int>&, const ProgressCallback& = nullptr);
+    int eval_image(const std::string_view&, const ProgressCallback& = nullptr);
+    int eval_atoms(const std::vector<Atom>&, const ProgressCallback& = nullptr);
+    int prefill(const std::vector<Atom>&, const ProgressCallback& = nullptr);
     void tokenize(std::vector<int>*, std::string_view, bool);
     void dump(std::string*);
 };
diff --git a/llamafile/server/v1_chat_completions.cpp b/llamafile/server/v1_chat_completions.cpp
index fb4e640497..8b5cf49052 100644
--- a/llamafile/server/v1_chat_completions.cpp
+++ b/llamafile/server/v1_chat_completions.cpp
@@ -475,17 +475,11 @@ Client::v1_chat_completions()
         return send_error(500, "failed to create sampler");
     defer_cleanup(cleanup_sampler, sampler);
 
-    // prefill time
-    int prompt_tokens = 0;
-    if ((prompt_tokens = slot_->prefill(state->atoms)) < 0) {
-        SLOG("slot prefill failed: %s", Slot::describe_error(prompt_tokens));
-        return send_error(500, Slot::describe_error(prompt_tokens));
-    }
-
     // setup response json
     response->json["id"] = generate_id();
     response->json["object"] = "chat.completion";
     response->json["model"] = params->model;
+    response->json["x_prefill_progress"] = nullptr;
     response->json["system_fingerprint"] = slot_->system_fingerprint_;
     Json& choice = response->json["choices"][0];
     choice["index"] = 0;
@@ -500,7 +494,35 @@
             return false;
         choice["delta"]["role"] = "assistant";
         choice["delta"]["content"] = "";
-        response->json["created"] = timespec_real().tv_sec;
+    }
+
+    // prefill time
+    int prompt_tokens = 0;
+    if (params->stream) {
+        auto progress_callback = [&](int processed, int total) {
+            if (processed < total) {
+                response->json["x_prefill_progress"] =
+                  static_cast<double>(processed) / total;
+                response->json["created"] = timespec_real().tv_sec;
+                response->content = make_event(response->json);
+                if (!send_response_chunk(response->content)) {
+                    return; // Note: Can't properly handle error in callback
+                }
+            }
+        };
+        prompt_tokens = slot_->prefill(state->atoms, progress_callback);
+    } else {
+        prompt_tokens = slot_->prefill(state->atoms);
+    }
+
+    if (prompt_tokens < 0) {
+        SLOG("slot prefill failed: %s", Slot::describe_error(prompt_tokens));
+        return send_error(500, Slot::describe_error(prompt_tokens));
+    }
+
+    // initialize response
+    if (params->stream) {
+        response->json.getObject().erase("x_prefill_progress");
         response->content = make_event(response->json);
         choice.getObject().erase("delta");
         if (!send_response_chunk(response->content))
diff --git a/llamafile/server/www/chatbot.css b/llamafile/server/www/chatbot.css
index 9f50fea2fe..7b1e83a107 100644
--- a/llamafile/server/www/chatbot.css
+++ b/llamafile/server/www/chatbot.css
@@ -55,6 +55,7 @@ p {
   flex: 1;
   overflow-y: auto;
   padding: 1rem;
+  position: relative;
 }
 
 ol,
@@ -223,6 +224,7 @@ ul li:first-child {
   border-radius: 4px;
   overflow-x: auto;
   position: relative;
+  white-space: pre-wrap;
 }
 
 .message blockquote {
@@ -629,3 +631,47 @@
     border: 1px solid #999;
   }
 }
+
+.prefill-progress {
+  position: absolute;
+  bottom: 0;
+  left: 0;
+  right: 0;
+  height: 4px;
+  background: #eee;
+}
+
+.prefill-progress .progress-bar {
+  height: 100%;
+  background: #0d6efd;
+  width: 0;
+  transition: width 0.2s ease-out;
+}
+
+#prefill-status {
+  position: sticky;
+  bottom: 0;
+  left: 0;
+  right: 0;
+  padding: 8px;
+  background: rgba(255, 255, 255, 0.9);
+  backdrop-filter: blur(4px);
+  display: none;
+  align-items: center;
+  gap: 10px;
+}
+
+.prefill-progress {
+  flex: 1;
+  height: 4px;
+  background: #eee;
+  border-radius: 2px;
+  overflow: hidden;
+}
+
+.prefill-progress .progress-bar {
+  height: 100%;
+  background: #0d6efd;
+  width: 0;
+  transition: width 0.2s ease-out;
+}
diff --git a/llamafile/server/www/chatbot.js b/llamafile/server/www/chatbot.js
index e78635ede5..bb4a8b1182 100644
--- a/llamafile/server/www/chatbot.js
+++ b/llamafile/server/www/chatbot.js
@@ -104,11 +104,13 @@ async function handleChatStream(response) {
   const reader = response.body.getReader();
   const decoder = new TextDecoder();
   let buffer = "";
-  let currentMessageElement = createMessageElement("", "assistant");
-  chatMessages.appendChild(currentMessageElement);
-  let hdom = new HighlightDom(currentMessageElement);
-  const high = new RenderMarkdown(hdom);
+  let currentMessageElement = null;
+  let messageAppended = false;
+  let hdom = null;
+  let high = null;
   streamingMessageContent = [];
+  const prefillStatus = document.getElementById('prefill-status');
+  const progressBar = prefillStatus.querySelector('.progress-bar');
 
   try {
     while (true) {
@@ -123,14 +125,35 @@
         const line = lines[i].trim();
         if (line.startsWith("data: ")) {
           const data = line.slice(6);
-          if (data === "[DONE]")
+          if (data === "[DONE]") {
+            prefillStatus.style.display = "none";
             continue;
+          }
           try {
             const parsed = JSON.parse(data);
             const content = parsed.choices[0]?.delta?.content || "";
-            streamingMessageContent.push(content);
-            high.feed(content);
-            scrollToBottom();
+
+            // handle prefill progress
+            if (parsed.x_prefill_progress !== undefined) {
+              prefillStatus.style.display = "flex";
+              progressBar.style.width = `${parsed.x_prefill_progress * 100}%`;
+            } else {
+              prefillStatus.style.display = "none";
+            }
+
+            if (content && !messageAppended) {
+              currentMessageElement = createMessageElement("", "assistant");
+              chatMessages.appendChild(currentMessageElement);
+              hdom = new HighlightDom(currentMessageElement);
+              high = new RenderMarkdown(hdom);
+              messageAppended = true;
+            }
+
+            if (messageAppended && content) {
+              streamingMessageContent.push(content);
+              high.feed(content);
+              scrollToBottom();
+            }
           } catch (e) {
             console.error("Error parsing JSON:", e);
           }
@@ -145,7 +168,10 @@
       console.error("Error reading stream:", error);
     }
   } finally {
-    high.flush();
+    if (messageAppended) {
+      high.flush();
+    }
+    prefillStatus.style.display = "none";
     cleanupAfterMessage();
   }
 }
@@ -166,7 +192,8 @@
 
 async function sendMessage() {
   const message = fixUploads(chatInput.value.trim());
-  if (!message) return;
+  if (!message)
+    return;
 
   // disable input while processing
   chatInput.value = "";
@@ -327,13 +354,30 @@ async function onFile(file) {
     };
     reader.readAsText(file);
   } else {
-    console.warn('Only image and text files are supported');
+    alert('Only image and text files are supported');
     return;
   }
 }
 
+function checkSurroundingNewlines(text, pos) {
+  const beforeCaret = text.slice(0, pos);
+  const afterCaret = text.slice(pos);
+  const precedingNewlines = beforeCaret.match(/\n*$/)[0].length;
+  const followingNewlines = afterCaret.match(/^\n*/)[0].length;
+  return { precedingNewlines, followingNewlines };
+}
+
 function insertText(elem, text) {
   const pos = elem.selectionStart;
+  const isCodeBlock = text.includes('```');
+
+  if (isCodeBlock) {
+    const { precedingNewlines, followingNewlines } = checkSurroundingNewlines(elem.value, pos);
+    const needsLeadingNewlines = pos > 0 && precedingNewlines < 2 ? '\n'.repeat(2 - precedingNewlines) : '';
+    const needsTrailingNewlines = pos < elem.value.length && followingNewlines < 2 ? '\n'.repeat(2 - followingNewlines) : '';
+    text = needsLeadingNewlines + text + needsTrailingNewlines;
+  }
+
   elem.value = elem.value.slice(0, pos) + text + elem.value.slice(pos);
   const newPos = pos + text.length;
   elem.setSelectionRange(newPos, newPos);
@@ -448,7 +492,7 @@ function updateSettingsDisplay(settings) {
   }
 }
 
-function setupSettings() {
+function setupSettings() {
   settingsButton.addEventListener("click", () => {
     settingsModal.style.display = "flex";
     updateSettingsDisplay(loadSettings());
@@ -662,6 +706,17 @@ function setupCompletionsMode() {
   completionsInput.focus();
 }
 
+function onUploadButtonClick() {
+  fileUpload.click();
+}
+
+function onFileUploadChange(e) {
+  if (e.target.files[0]) {
+    onFile(e.target.files[0]);
+    e.target.value = '';
+  }
+}
+
 async function chatbot() {
   flagz = await fetchFlagz();
   updateModelInfo();
@@ -685,17 +740,8 @@ async function chatbot() {
   document.addEventListener("drop", onDragEnd);
   document.addEventListener("drop", onDrop);
   document.addEventListener("paste", onPaste);
-
-  uploadButton.addEventListener("click", () => {
-    fileUpload.click();
-  });
-
-  fileUpload.addEventListener("change", (e) => {
-    if (e.target.files[0]) {
-      onFile(e.target.files[0]);
-      e.target.value = '';
-    }
-  });
+  uploadButton.addEventListener("click", onUploadButtonClick);
+  fileUpload.addEventListener("change", onFileUploadChange);
 }
 
 chatbot();
diff --git a/llamafile/server/www/index.html b/llamafile/server/www/index.html
index b02dca84b0..76c2e544cf 100644
--- a/llamafile/server/www/index.html
+++ b/llamafile/server/www/index.html
@@ -27,6 +27,11 @@
         Loading...
+      <div id="prefill-status">
+        <div class="prefill-progress">
+          <div class="progress-bar"></div>
+        </div>
+      </div>