diff --git a/common/common.cmake b/common/common.cmake
index bf32054daa80..a3d3b1a0c005 100644
--- a/common/common.cmake
+++ b/common/common.cmake
@@ -12,6 +12,7 @@ function(gpt4all_add_warning_options target)
         -Wformat=2
         -Wmissing-include-dirs
         -Wstrict-overflow=2
+        -Wsuggest-override
         -Wvla
         # errors
         -Werror=format-security
diff --git a/gpt4all-backend/include/gpt4all-backend/llmodel.h b/gpt4all-backend/include/gpt4all-backend/llmodel.h
index ecb4a6ffe244..ed5c4878fc6c 100644
--- a/gpt4all-backend/include/gpt4all-backend/llmodel.h
+++ b/gpt4all-backend/include/gpt4all-backend/llmodel.h
@@ -124,9 +124,7 @@ class LLModel {
     };

     struct PromptContext {
-        std::vector<int32_t> tokens;    // current tokens in the context window
         int32_t n_past = 0;             // number of tokens in past conversation
-        int32_t n_ctx = 0;              // number of tokens possible in context window
         int32_t n_predict = 200;
         int32_t top_k = 40;
         float top_p = 0.9f;
@@ -151,8 +149,8 @@ class LLModel {
     virtual bool isModelLoaded() const = 0;
     virtual size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) = 0;
     virtual size_t stateSize() const = 0;
-    virtual size_t saveState(std::span<uint8_t> dest) const = 0;
-    virtual size_t restoreState(std::span<const uint8_t> src) = 0;
+    virtual size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const = 0;
+    virtual size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) = 0;

     // This method requires the model to return true from supportsCompletion otherwise it will throw
     // an error
@@ -210,6 +208,8 @@ class LLModel {

     void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; }

+    virtual int32_t contextLength() const = 0;
+
 protected:
     // These are pure virtual because subclasses need to implement as the default implementation of
     // 'prompt' above calls these functions
@@ -218,9 +218,15 @@ class LLModel {
     virtual std::string tokenToString(Token id) const = 0;
     virtual void initSampler(PromptContext &ctx) = 0;
     virtual Token sampleToken() const = 0;
-    virtual bool evalTokens(PromptContext &ctx, const std::vector<Token> &tokens) const = 0;
+    virtual bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const = 0;
     virtual void shiftContext(PromptContext &promptCtx) = 0;
-    virtual int32_t contextLength() const = 0;
+    virtual int32_t inputLength() const = 0;
+    virtual void setTokenizeInputPosition(int32_t pos) = 0;
+    virtual auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
+        -> std::vector<Token>::const_iterator = 0;
+    virtual void setModelInputPosition(PromptContext &ctx, int32_t pos) = 0;
+    virtual void appendInputToken(PromptContext &ctx, Token tok) = 0;
+    virtual std::span<const Token> inputTokens() const = 0;
     virtual const std::vector<Token> &endTokens() const = 0;
     virtual bool shouldAddBOS() const = 0;
@@ -252,11 +258,13 @@ class LLModel {
                       bool allowContextShift,
                       PromptContext &promptCtx,
                       std::vector<Token> embd_inp,
-                      bool isResponse = false);
+                      bool isResponse = false,
+                      bool alwaysDecode = false);
     void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
                           bool allowContextShift,
                           PromptContext &promptCtx);

+protected:
     Token m_tokenize_last_token = -1; // not serialized

     friend class LLMImplementation;
diff --git a/gpt4all-backend/include/gpt4all-backend/llmodel_c.h b/gpt4all-backend/include/gpt4all-backend/llmodel_c.h
index 44a5568b9ecd..e9497d0fafd9 100644
--- a/gpt4all-backend/include/gpt4all-backend/llmodel_c.h
+++ b/gpt4all-backend/include/gpt4all-backend/llmodel_c.h
@@ -23,6 +23,11 @@ extern "C" {
  */
 typedef void *llmodel_model;

+/**
+ * A token.
+ */
+typedef int32_t token_t;
+
 /**
  * llmodel_prompt_context structure for holding the prompt context.
  * NOTE: The implementation takes care of all the memory handling of the raw logits pointer and the
@@ -30,10 +35,7 @@ typedef void *llmodel_model;
  * behavior.
  */
 struct llmodel_prompt_context {
-    int32_t *tokens;       // current tokens in the context window
-    size_t tokens_size;    // the size of the raw tokens vector
     int32_t n_past;        // number of tokens in past conversation
-    int32_t n_ctx;         // number of tokens possible in context window
     int32_t n_predict;     // number of tokens to predict
     int32_t top_k;         // top k logits to sample from
     float top_p;           // nucleus sampling probability threshold
@@ -141,27 +143,41 @@ bool llmodel_isModelLoaded(llmodel_model model);
  * @param model A pointer to the llmodel_model instance.
  * @return the size in bytes of the internal state of the model
  */
-uint64_t llmodel_get_state_size(llmodel_model model);
+uint64_t llmodel_state_get_size(llmodel_model model);

 /**
- * Saves the internal state of the model to the specified destination address.
+ * Saves the internal state of the model.
  * NOTE: This state data is specific to the type of model you have created.
  * @param model A pointer to the llmodel_model instance.
- * @param dest A pointer to the destination.
- * @param size The size of the destination buffer.
- * @return the number of bytes copied, or zero on error.
+ * @param state Where to store the state. This must be a buffer of at least llmodel_state_get_size() bytes.
+ * @param state_size The size of the destination for the state.
+ * @param input_tokens_out Where to store the address of the token cache state. This is dynamically allocated and must
+ *                         be freed with llmodel_state_free_input_tokens.
+ * @param n_input_tokens Where to store the size of the token cache state.
+ * @return The number of bytes copied. On error, zero is returned, the token cache is set to NULL, and the token cache
+ *         size is set to zero.
+ */
+uint64_t llmodel_state_get_data(llmodel_model model, uint8_t *state_out, uint64_t state_size,
+                                token_t **input_tokens_out, uint64_t *n_input_tokens);
+
+/**
+ * Frees the temporary token cache buffer created by a call to llmodel_state_get_data().
+ * @param input_tokens The token cache buffer.
  */
-uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest, uint64_t size);
+void llmodel_state_free_input_tokens(token_t *input_tokens);

 /**
  * Restores the internal state of the model using data from the specified address.
  * NOTE: This state data is specific to the type of model you have created.
  * @param model A pointer to the llmodel_model instance.
- * @param src A pointer to the state data.
- * @param size The size of the source data.
+ * @param state A pointer to the state data.
+ * @param state_size The size of the state data.
+ * @param input_tokens The token cache associated with the saved state.
+ * @param n_input_tokens The number of tokens in input_tokens.
  * @return The number of bytes read, or zero on error.
  */
-uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src, size_t size);
+uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint64_t state_size,
+                                const token_t *input_tokens, uint64_t n_input_tokens);

 /**
  * Generate a response using the model.
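For illustration, here is a hypothetical caller-side sketch of the renamed state API (not part of the patch). It assumes `model` is an already-created and loaded llmodel_model handle; the function names and signatures are the ones declared above, and the token buffer returned through input_tokens_out is owned by the caller and must be released with llmodel_state_free_input_tokens().

    #include <cstdint>
    #include <vector>

    #include "llmodel_c.h"

    static bool snapshot_and_restore(llmodel_model model)
    {
        // query the size of the model's internal (KV cache) state and allocate a buffer for it
        uint64_t state_size = llmodel_state_get_size(model);
        std::vector<uint8_t> state(state_size);

        // save the state; the library also hands back a heap-allocated copy of the token cache
        token_t *input_tokens = nullptr;
        uint64_t n_input_tokens = 0;
        if (!llmodel_state_get_data(model, state.data(), state_size, &input_tokens, &n_input_tokens))
            return false; // on error the token buffer is NULL and its length is zero

        // ... later: restore the internal state together with its token cache
        uint64_t bytes_read = llmodel_state_set_data(model, state.data(), state_size,
                                                     input_tokens, n_input_tokens);

        // the token buffer belongs to the caller once llmodel_state_get_data() succeeds
        llmodel_state_free_input_tokens(input_tokens);
        return bytes_read != 0;
    }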
diff --git a/gpt4all-backend/src/llamamodel.cpp b/gpt4all-backend/src/llamamodel.cpp
index d34024f432a5..453dbd972bd8 100644
--- a/gpt4all-backend/src/llamamodel.cpp
+++ b/gpt4all-backend/src/llamamodel.cpp
@@ -218,6 +218,7 @@ struct LLamaPrivate {
     int64_t n_threads = 0;
     std::vector<LLModel::Token> end_tokens;
     const char *backend_name = nullptr;
+    std::vector<LLModel::Token> inputTokens;

     llama_model *model = nullptr;
     llama_context *ctx = nullptr;
@@ -501,14 +502,20 @@ size_t LLamaModel::stateSize() const
     return llama_state_get_size(d_ptr->ctx);
 }

-size_t LLamaModel::saveState(std::span<uint8_t> dest) const
+size_t LLamaModel::saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const
 {
-    return llama_state_get_data(d_ptr->ctx, dest.data(), dest.size());
+    size_t bytesWritten = llama_state_get_data(d_ptr->ctx, stateOut.data(), stateOut.size());
+    if (bytesWritten)
+        inputTokensOut.assign(d_ptr->inputTokens.begin(), d_ptr->inputTokens.end());
+    return bytesWritten;
 }

-size_t LLamaModel::restoreState(std::span<const uint8_t> src)
+size_t LLamaModel::restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens)
 {
-    return llama_state_set_data(d_ptr->ctx, src.data(), src.size());
+    size_t bytesRead = llama_state_set_data(d_ptr->ctx, state.data(), state.size());
+    if (bytesRead)
+        d_ptr->inputTokens.assign(inputTokens.begin(), inputTokens.end());
+    return bytesRead;
 }

 std::vector<LLModel::Token> LLamaModel::tokenize(std::string_view str, bool special)
@@ -594,7 +601,7 @@ LLModel::Token LLamaModel::sampleToken() const
     return llama_sampler_sample(d_ptr->sampler_chain, d_ptr->ctx, -1);
 }

-bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<Token> &tokens) const
+bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) const
 {
     llama_kv_cache_seq_rm(d_ptr->ctx, 0, ctx.n_past, -1);

@@ -625,7 +632,7 @@ void LLamaModel::shiftContext(PromptContext &promptCtx)
     // erase up to n_ctx*contextErase tokens
     int n_keep = shouldAddBOS();
     int n_past = promptCtx.n_past;
-    int n_discard = std::min(n_past - n_keep, int(promptCtx.n_ctx * promptCtx.contextErase));
+    int n_discard = std::min(n_past - n_keep, int(contextLength() * promptCtx.contextErase));

     assert(n_discard > 0);
     if (n_discard <= 0)
         return;
@@ -638,8 +645,9 @@ void LLamaModel::shiftContext(PromptContext &promptCtx)
     llama_kv_cache_seq_rm (d_ptr->ctx, 0, n_keep, n_keep + n_discard);
     llama_kv_cache_seq_add(d_ptr->ctx, 0, n_keep + n_discard, n_past, -n_discard);

-    promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);
-    promptCtx.n_past = promptCtx.tokens.size();
+    auto &inp = d_ptr->inputTokens;
+    inp.erase(inp.begin() + n_keep, inp.begin() + n_keep + n_discard);
+    promptCtx.n_past = inp.size();
 }
@@ -647,6 +655,60 @@ int32_t LLamaModel::contextLength() const
 {
     return llama_n_ctx(d_ptr->ctx);
 }

+int32_t LLamaModel::inputLength() const
+{
+    return d_ptr->inputTokens.size();
+}
+
+void LLamaModel::setTokenizeInputPosition(int32_t pos)
+{
+    assert(pos >= 0);
+    m_tokenize_last_token = pos ? d_ptr->inputTokens.at(size_t(pos) - 1) : -1; // not serialized
+}
+
+auto LLamaModel::computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
+    -> std::vector<Token>::const_iterator
+{
+    assert(ctx.n_past >= 0);
+    auto pos = size_t(ctx.n_past);
+    if (pos > d_ptr->inputTokens.size()) {
+        std::ostringstream ss;
+        ss << "n_past=" << pos << " is past end of token cache length=" << d_ptr->inputTokens.size();
+        throw std::out_of_range(ss.str());
+    }
+
+    // find common prefix
+    auto cacheIt = d_ptr->inputTokens.begin();
+    auto inputIt = input.begin();
+    while (cacheIt < d_ptr->inputTokens.end() && inputIt < input.end() && *cacheIt == *inputIt) {
+        ++cacheIt; ++inputIt; ++pos;
+    }
+
+    // tell the caller to ignore the tokens between [begin, inputIt)
+    return inputIt;
+}
+
+void LLamaModel::setModelInputPosition(PromptContext &ctx, int32_t pos)
+{
+    auto &inp = d_ptr->inputTokens;
+    assert(pos >= 0);
+    assert(pos <= inp.size());
+    // truncate token cache to end at the new n_past
+    if (pos < inp.size())
+        inp.resize(pos);
+    ctx.n_past = pos;
+}
+
+void LLamaModel::appendInputToken(PromptContext &ctx, Token tok)
+{
+    d_ptr->inputTokens.push_back(tok);
+    ctx.n_past += 1;
+}
+
+auto LLamaModel::inputTokens() const -> std::span<const Token>
+{
+    return d_ptr->inputTokens;
+}
+
 const std::vector<LLModel::Token> &LLamaModel::endTokens() const
 {
     return d_ptr->end_tokens;
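The common-prefix walk in computeModelInputPosition() above is what lets a re-sent conversation skip straight to its first new token. A standalone sketch of the same idea, using std::mismatch on plain vectors (an illustration, not the PR's code):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    using Token = int32_t;

    int main()
    {
        std::vector<Token> cache = {1, 2, 3, 4, 5};  // tokens the model has already evaluated
        std::vector<Token> input = {1, 2, 3, 9, 8};  // freshly tokenized prompt

        // find the first position where the cache and the new input disagree
        auto [cacheIt, inputIt] = std::mismatch(cache.begin(), cache.end(),
                                                input.begin(), input.end());

        size_t reused = size_t(inputIt - input.begin());
        assert(reused == 3);  // tokens [0, 3) need no decoding; the KV cache already covers them
        (void)cacheIt; (void)reused;
        return 0;
    }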
diff --git a/gpt4all-backend/src/llamamodel_impl.h b/gpt4all-backend/src/llamamodel_impl.h
index f7b3a47e0b6e..d6290a061316 100644
--- a/gpt4all-backend/src/llamamodel_impl.h
+++ b/gpt4all-backend/src/llamamodel_impl.h
@@ -28,8 +28,8 @@ class LLamaModel : public LLModel {
     bool isModelLoaded() const override;
     size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
     size_t stateSize() const override;
-    size_t saveState(std::span<uint8_t> dest) const override;
-    size_t restoreState(std::span<const uint8_t> src) override;
+    size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const override;
+    size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() const override;
     std::vector<GPUDevice> availableGPUDevices(size_t memoryRequired = 0) const override;
@@ -48,10 +48,7 @@ class LLamaModel : public LLModel {
     void embed(const std::vector<std::string> &texts, float *embeddings, bool isRetrieval,
                int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;

-private:
-    std::unique_ptr<LLamaPrivate> d_ptr;
-    bool m_supportsEmbedding = false;
-    bool m_supportsCompletion = false;
+    int32_t contextLength() const override;

 protected:
     std::vector<Token> tokenize(std::string_view str, bool special) override;
@@ -59,9 +56,15 @@ class LLamaModel : public LLModel {
     std::string tokenToString(Token id) const override;
     void initSampler(PromptContext &ctx) override;
     Token sampleToken() const override;
-    bool evalTokens(PromptContext &ctx, const std::vector<Token> &tokens) const override;
+    bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override;
     void shiftContext(PromptContext &promptCtx) override;
-    int32_t contextLength() const override;
+    int32_t inputLength() const override;
+    void setTokenizeInputPosition(int32_t pos) override;
+    auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
+        -> std::vector<Token>::const_iterator override;
+    void setModelInputPosition(PromptContext &ctx, int32_t pos) override;
+    void appendInputToken(PromptContext &ctx, Token tok) override;
+    std::span<const Token> inputTokens() const override;
     const std::vector<Token> &endTokens() const override;
     bool shouldAddBOS() const override;
     int32_t maxContextLength(std::string const &modelPath) const override;
@@ -70,6 +73,11 @@ class LLamaModel : public LLModel {
     void embedInternal(const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
                        size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb,
                        const EmbModelSpec *spec);
+
+private:
+    std::unique_ptr<LLamaPrivate> d_ptr;
+    bool m_supportsEmbedding = false;
+    bool m_supportsCompletion = false;
 };

 #endif // LLAMAMODEL_H
diff --git a/gpt4all-backend/src/llmodel_c.cpp b/gpt4all-backend/src/llmodel_c.cpp
index c8c537f6fb33..068052665f39 100644
--- a/gpt4all-backend/src/llmodel_c.cpp
+++ b/gpt4all-backend/src/llmodel_c.cpp
@@ -14,6 +14,11 @@
 #include
 #include
 #include
+#include
+
+namespace ranges = std::ranges;
+
+static_assert(sizeof(token_t) == sizeof(LLModel::Token));

 struct LLModelWrapper {
     LLModel *llModel = nullptr;
@@ -85,22 +90,40 @@ bool llmodel_isModelLoaded(llmodel_model model)
     return wrapper->llModel->isModelLoaded();
 }

-uint64_t llmodel_get_state_size(llmodel_model model)
+uint64_t llmodel_state_get_size(llmodel_model model)
 {
     auto *wrapper = static_cast<LLModelWrapper *>(model);
     return wrapper->llModel->stateSize();
 }

-uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest, uint64_t size)
+uint64_t llmodel_state_get_data(llmodel_model model, uint8_t *state_out, uint64_t state_size,
+                                token_t **input_tokens_out, uint64_t *n_input_tokens)
 {
     auto *wrapper = static_cast<LLModelWrapper *>(model);
-    return wrapper->llModel->saveState({dest, size_t(size)});
+    std::vector<LLModel::Token> inputTokens;
+    auto bytesWritten = wrapper->llModel->saveState({state_out, size_t(state_size)}, inputTokens);
+    if (bytesWritten) {
+        auto *buf = new LLModel::Token[inputTokens.size()];
+        ranges::copy(inputTokens, buf);
+        *input_tokens_out = buf;
+        *n_input_tokens = uint64_t(inputTokens.size());
+    } else {
+        *input_tokens_out = nullptr;
+        *n_input_tokens = 0;
+    }
+    return bytesWritten;
 }

-uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src, uint64_t size)
+void llmodel_state_free_input_tokens(LLModel::Token *input_tokens)
+{
+    delete[] input_tokens;
+}
+
+uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint64_t state_size,
+                                const token_t *input_tokens, uint64_t n_input_tokens)
 {
     auto *wrapper = static_cast<LLModelWrapper *>(model);
-    return wrapper->llModel->restoreState({src, size_t(size)});
+    return wrapper->llModel->restoreState({state, size_t(state_size)}, {input_tokens, size_t(n_input_tokens)});
 }

 void llmodel_prompt(llmodel_model model, const char *prompt,
@@ -120,7 +143,6 @@ void llmodel_prompt(llmodel_model model, const char *prompt,

     // Copy the C prompt context
     wrapper->promptContext.n_past = ctx->n_past;
-    wrapper->promptContext.n_ctx = ctx->n_ctx;
     wrapper->promptContext.n_predict = ctx->n_predict;
     wrapper->promptContext.top_k = ctx->top_k;
     wrapper->promptContext.top_p = ctx->top_p;
@@ -136,14 +158,8 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
                              wrapper->promptContext, special,
                              fake_reply ? std::make_optional<std::string_view>(fake_reply) : std::nullopt);

-    // Update the C context by giving access to the wrappers raw pointers to std::vector data
-    // which involves no copies
-    ctx->tokens = wrapper->promptContext.tokens.data();
-    ctx->tokens_size = wrapper->promptContext.tokens.size();
-
     // Update the rest of the C prompt context
     ctx->n_past = wrapper->promptContext.n_past;
-    ctx->n_ctx = wrapper->promptContext.n_ctx;
     ctx->n_predict = wrapper->promptContext.n_predict;
     ctx->top_k = wrapper->promptContext.top_k;
     ctx->top_p = wrapper->promptContext.top_p;
diff --git a/gpt4all-backend/src/llmodel_shared.cpp b/gpt4all-backend/src/llmodel_shared.cpp
index 3868f0d07ad3..ef046433a217 100644
--- a/gpt4all-backend/src/llmodel_shared.cpp
+++ b/gpt4all-backend/src/llmodel_shared.cpp
@@ -6,6 +6,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -66,19 +67,14 @@ void LLModel::prompt(const std::string &prompt,
         ss << "n_past=" << promptCtx.n_past << " is past end of context length=" << contextLength();
         throw std::out_of_range(ss.str());
     }
-    if (promptCtx.n_past > promptCtx.tokens.size()) {
+    if (promptCtx.n_past > inputLength()) {
         std::ostringstream ss;
-        ss << "n_past=" << promptCtx.n_past << " is past end of token cache length=" << promptCtx.tokens.size();
+        ss << "n_past=" << promptCtx.n_past << " is past end of token cache length=" << inputLength();
         throw std::out_of_range(ss.str());
     }

-    promptCtx.n_ctx = contextLength();
     promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH);

-    if (promptCtx.n_past < promptCtx.tokens.size())
-        promptCtx.tokens.resize(promptCtx.n_past);
-    m_tokenize_last_token = promptCtx.tokens.empty() ? -1 : promptCtx.tokens.back(); // not serialized
-
     // parse the prompt template
     std::vector<std::smatch> placeholders;
     {
@@ -90,6 +86,8 @@ void LLModel::prompt(const std::string &prompt,
         }
     }

+    setTokenizeInputPosition(promptCtx.n_past);
+
     // tokenize the user prompt
     std::vector<Token> embd_inp;
     if (placeholders.empty()) {
@@ -118,7 +116,8 @@ void LLModel::prompt(const std::string &prompt,
     }

     // decode the user prompt
-    if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp))
+    if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp, /*isResponse*/ false,
+                      /*alwaysDecode*/ true))
         return; // error

     // decode the assistant's reply, either generated or spoofed
@@ -151,36 +150,67 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,
                            bool allowContextShift,
                            PromptContext &promptCtx,
                            std::vector<Token> embd_inp,
-                           bool isResponse) {
-    if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
+                           bool isResponse,
+                           bool alwaysDecode) {
+    if ((int) embd_inp.size() > contextLength() - 4) {
         // FIXME: (Adam) We should find a way to bubble these strings to the UI level to allow for
         // translation
         responseCallback(-1, "Your message was too long and could not be processed. Please try again with something shorter.");
         std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() <<
-            " tokens and the context window is " << promptCtx.n_ctx << "!\n";
+            " tokens and the context window is " << contextLength() << "!\n";
         return false;
     }

     // FIXME(jared): There are mitigations for this situation, such as making room before
     // copying the prompt context, or restoring the KV cache when we restore the prompt
     // context.
-    if (!allowContextShift && promptCtx.n_past + embd_inp.size() > promptCtx.n_ctx) {
+    if (!allowContextShift && promptCtx.n_past + embd_inp.size() > contextLength()) {
         std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_eval=" << embd_inp.size()
-                  << ", n_ctx=" << promptCtx.n_ctx << "\n";
+                  << ", n_ctx=" << contextLength() << "\n";
         return false;
     }

-    // process the prompt in batches
+    // always decode something before generating, even if cached
+    if (alwaysDecode && embd_inp.empty()) {
+        auto cache = inputTokens();
+        if (!promptCtx.n_past)
+            throw std::runtime_error("zero token prompt is not supported");
+        assert(!cache.empty());
+        embd_inp.push_back(cache.back());
+        promptCtx.n_past--;
+    }
+
+    // Find the greatest n_past where the beginning of embd_inp matches the end of the token cache, starting at the
+    // requested n_past.
+    // This is used to skip unnecessary work when the prompt shares a common prefix with the previous result.
+    auto embd_inp_start = computeModelInputPosition(promptCtx, embd_inp);
+    size_t start_offset = embd_inp_start - embd_inp.begin();
+
+    // always decode up to a full batch before generating, even if cached
+    if (alwaysDecode)
+        start_offset -= std::min(promptCtx.n_batch, int32_t(start_offset));
+
+    setModelInputPosition(promptCtx, promptCtx.n_past + start_offset);
+
+    // execute the callback even for skipped tokens
     size_t i = 0;
+    for (; i < start_offset; i++) {
+        Token tok = embd_inp[i];
+        bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
+        if (!res)
+            return false;
+    }
+
+    // process the prompt in batches
     while (i < embd_inp.size()) {
         size_t batch_end = std::min(i + promptCtx.n_batch, embd_inp.size());
-        std::vector<Token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
+        std::span<const Token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);

         // Check if the context has run out...
-        if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) {
+        if (promptCtx.n_past + int32_t(batch.size()) > contextLength()) {
             assert(allowContextShift);
             shiftContext(promptCtx);
-            assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
+            assert(promptCtx.n_past + int32_t(batch.size()) <= contextLength());
         }

         if (!evalTokens(promptCtx, batch)) {
@@ -190,9 +220,8 @@ bool LLModel::decodePrompt(std::function<bool(int32_t)> promptCallback,

         size_t tokens = batch_end - i;
         for (size_t t = 0; t < tokens; ++t) {
-            promptCtx.tokens.push_back(batch.at(t));
-            promptCtx.n_past += 1;
-            Token tok = batch.at(t);
+            Token tok = batch[t];
+            appendInputToken(promptCtx, tok);
             bool res = isResponse ? responseCallback(tok, tokenToString(tok)) : promptCallback(tok);
             if (!res)
                 return false;
@@ -232,8 +261,8 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
     // Don't even start if there is no room
     if (!promptCtx.n_predict)
         return;
-    if (!allowContextShift && promptCtx.n_past >= promptCtx.n_ctx) {
-        std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" << promptCtx.n_ctx
+    if (!allowContextShift && promptCtx.n_past >= contextLength()) {
+        std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" << contextLength()
                   << "\n";
         return;
     }
@@ -254,23 +283,22 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
     auto accept = [this, &promptCtx, &new_tok, allowContextShift]() -> bool {
         // Shift context if out of space
-        if (promptCtx.n_past >= promptCtx.n_ctx) {
+        if (promptCtx.n_past >= contextLength()) {
             (void)allowContextShift;
             assert(allowContextShift);
             shiftContext(promptCtx);
-            assert(promptCtx.n_past < promptCtx.n_ctx);
+            assert(promptCtx.n_past < contextLength());
         }

         // Accept the token
         Token tok = std::exchange(new_tok, std::nullopt).value();
-        if (!evalTokens(promptCtx, { tok })) {
+        if (!evalTokens(promptCtx, { &tok, 1 })) {
             // TODO(jared): raise an exception
             std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n";
             return false;
         }

-        promptCtx.tokens.push_back(tok);
-        promptCtx.n_past += 1;
-
+        appendInputToken(promptCtx, tok);
         return true;
     };
@@ -309,9 +337,9 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
         }

         // Optionally stop if the context will run out
-        if (!allowContextShift && promptCtx.n_past + cachedTokens.size() >= promptCtx.n_ctx) {
+        if (!allowContextShift && promptCtx.n_past + cachedTokens.size() >= contextLength()) {
             std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx="
-                      << promptCtx.n_ctx << "\n";
+                      << contextLength() << "\n";
             stop = true;
         }
@@ -357,16 +385,17 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
         }
     }

-    auto &tokens = promptCtx.tokens;
-    if (tokens.size() < cachedTokens.size()) {
+    if (inputLength() < cachedTokens.size()) {
        /* This is theoretically possible if the longest stop sequence is greater than
         * n_ctx * contextErase tokens. */
        throw std::runtime_error("shifted too much context, can't go back");
    }

-    auto discard_start = tokens.end() - cachedTokens.size();
-    assert(std::equal(discard_start, tokens.end(), cachedTokens.begin()));
-    tokens.erase(discard_start, tokens.end());
+#ifndef NDEBUG
+    auto inp = inputTokens();
+    auto discard_start = inp.end() - cachedTokens.size();
+    assert(std::equal(discard_start, inp.end(), cachedTokens.begin()));
+#endif

     promptCtx.n_past -= cachedTokens.size();
 }
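With the cache in place, decodePrompt() skips the shared prefix and feeds evalTokens() std::span views into embd_inp rather than copying each batch into a temporary vector. A minimal standalone sketch of that slicing pattern (fakeEvalTokens() is a stand-in for a real backend; not the PR's code):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <span>
    #include <vector>

    using Token = int32_t;

    // stand-in for evalTokens(); a real backend would run the model here
    static bool fakeEvalTokens(std::span<const Token> batch)
    {
        std::printf("decoding %zu token(s)\n", batch.size());
        return true;
    }

    int main()
    {
        std::vector<Token> embd_inp = {10, 11, 12, 13, 14, 15, 16};
        size_t start_offset = 3;   // tokens [0, 3) matched the cache and are skipped
        const size_t n_batch = 2;

        for (size_t i = start_offset; i < embd_inp.size(); ) {
            size_t batch_end = std::min(i + n_batch, embd_inp.size());
            std::span<const Token> batch(embd_inp.begin() + i, embd_inp.begin() + batch_end);
            if (!fakeEvalTokens(batch))
                return 1;
            i = batch_end;
        }
        return 0;
    }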
diff --git a/gpt4all-bindings/cli/app.py b/gpt4all-bindings/cli/app.py
index e584a318038e..be6b5745877d 100755
--- a/gpt4all-bindings/cli/app.py
+++ b/gpt4all-bindings/cli/app.py
@@ -113,10 +113,7 @@ def _old_loop(gpt4all_instance):
     full_response = gpt4all_instance.chat_completion(
         MESSAGES,
         # preferential kwargs for chat ux
-        logits_size=0,
-        tokens_size=0,
         n_past=0,
-        n_ctx=0,
         n_predict=200,
         top_k=40,
         top_p=0.9,
diff --git a/gpt4all-bindings/python/CHANGELOG.md b/gpt4all-bindings/python/CHANGELOG.md
index 97ad1b7e3bbc..ec3ce4686b11 100644
--- a/gpt4all-bindings/python/CHANGELOG.md
+++ b/gpt4all-bindings/python/CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

 ### Added
 - Warn on Windows if the Microsoft Visual C++ runtime libraries are not found ([#2920](https://github.com/nomic-ai/gpt4all/pull/2920))
+- Basic cache for faster prefill when the input shares a prefix with previous context ([#3073](https://github.com/nomic-ai/gpt4all/pull/3073))

 ### Changed
 - Rebase llama.cpp on latest upstream as of September 26th ([#2998](https://github.com/nomic-ai/gpt4all/pull/2998))
diff --git a/gpt4all-bindings/python/gpt4all/_pyllmodel.py b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
index 8357731e9f42..136cf685aa2b 100644
--- a/gpt4all-bindings/python/gpt4all/_pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/_pyllmodel.py
@@ -116,10 +116,7 @@ def load_llmodel_library():

 class LLModelPromptContext(ctypes.Structure):
     _fields_ = [
-        ("tokens", ctypes.POINTER(ctypes.c_int32)),
-        ("tokens_size", ctypes.c_size_t),
         ("n_past", ctypes.c_int32),
-        ("n_ctx", ctypes.c_int32),
         ("n_predict", ctypes.c_int32),
         ("top_k", ctypes.c_int32),
         ("top_p", ctypes.c_float),
@@ -393,9 +390,7 @@ def _set_context(
     ):
         if self.context is None:
             context = LLModelPromptContext(
-                tokens_size=0,
                 n_past=0,
-                n_ctx=0,
                 n_predict=n_predict,
                 top_k=top_k,
                 top_p=top_p,
diff --git a/gpt4all-chat/CHANGELOG.md b/gpt4all-chat/CHANGELOG.md
index bf34d424f85d..ea57e14be773 100644
--- a/gpt4all-chat/CHANGELOG.md
+++ b/gpt4all-chat/CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
 ### Added
 - Add ability to attach text, markdown, and rst files to chat ([#3135](https://github.com/nomic-ai/gpt4all/pull/3135))
 - Add feature to minimize to system tray (by [@bgallois](https://github.com/bgallois) in ([#3109](https://github.com/nomic-ai/gpt4all/pull/3109))
+- Basic cache for faster prefill when the input shares a prefix with previous context ([#3073](https://github.com/nomic-ai/gpt4all/pull/3073))

 ### Changed
 - Implement Qt 6.8 compatibility ([#3121](https://github.com/nomic-ai/gpt4all/pull/3121))
diff --git a/gpt4all-chat/src/chatapi.cpp b/gpt4all-chat/src/chatapi.cpp
index e8bcca709a93..27f64f0d6730 100644
--- a/gpt4all-chat/src/chatapi.cpp
+++ b/gpt4all-chat/src/chatapi.cpp
@@ -51,7 +51,6 @@ bool ChatAPI::loadModel(const std::string &modelPath, int n_ctx, int ngl)
 void ChatAPI::setThreadCount(int32_t n_threads)
 {
     Q_UNUSED(n_threads);
-    qt_noop();
 }

 int32_t ChatAPI::threadCount() const
@@ -68,24 +67,6 @@ bool ChatAPI::isModelLoaded() const
     return true;
 }

-// All three of the state virtual functions are handled custom inside of chatllm save/restore
-size_t ChatAPI::stateSize() const
-{
-    throw std::logic_error("not implemented");
-}
-
-size_t ChatAPI::saveState(std::span<uint8_t> dest) const
-{
-    Q_UNUSED(dest);
-    throw std::logic_error("not implemented");
-}
-
-size_t ChatAPI::restoreState(std::span<const uint8_t> src)
-{
-    Q_UNUSED(src);
-    throw std::logic_error("not implemented");
-}
-
 void ChatAPI::prompt(const std::string &prompt,
                      const std::string &promptTemplate,
                      std::function<bool(int32_t)> promptCallback,
diff --git a/gpt4all-chat/src/chatapi.h b/gpt4all-chat/src/chatapi.h
index 31d05310c6a7..f37a105d29f1 100644
--- a/gpt4all-chat/src/chatapi.h
+++ b/gpt4all-chat/src/chatapi.h
@@ -3,7 +3,7 @@

 #include

-#include
+#include // IWYU pragma: keep
 #include
 #include
 #include
@@ -13,6 +13,8 @@
 #include
 #include
 #include
+#include
+#include
 #include
 #include
 #include
@@ -63,9 +65,15 @@ class ChatAPI : public QObject, public LLModel {
     bool loadModel(const std::string &modelPath, int n_ctx, int ngl) override;
     bool isModelLoaded() const override;
     size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
-    size_t stateSize() const override;
-    size_t saveState(std::span<uint8_t> dest) const override;
-    size_t restoreState(std::span<const uint8_t> src) override;
+
+    // All three of the state virtual functions are handled custom inside of chatllm save/restore
+    size_t stateSize() const override
+        { throwNotImplemented(); }
+    size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const override
+        { Q_UNUSED(stateOut); Q_UNUSED(inputTokensOut); throwNotImplemented(); }
+    size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) override
+        { Q_UNUSED(state); Q_UNUSED(inputTokens); throwNotImplemented(); }
+
     void prompt(const std::string &prompt,
                 const std::string &promptTemplate,
                 std::function<bool(int32_t)> promptCallback,
@@ -88,6 +96,10 @@ class ChatAPI : public QObject, public LLModel {

     bool callResponse(int32_t token, const std::string &string);

+    [[noreturn]]
+    int32_t contextLength() const override
+        { throwNotImplemented(); }
+
 Q_SIGNALS:
     void request(const QString &apiKey,
                  LLModel::PromptContext *ctx,
@@ -98,60 +110,69 @@ class ChatAPI : public QObject, public LLModel {
     // them as they are only called from the default implementation of 'prompt' which we override and
     // completely replace

+    [[noreturn]]
+    static void throwNotImplemented() { throw std::logic_error("not implemented"); }
+
+    [[noreturn]]
     std::vector<Token> tokenize(std::string_view str, bool special) override
-    {
-        (void)str;
-        (void)special;
-        throw std::logic_error("not implemented");
-    }
+        { Q_UNUSED(str); Q_UNUSED(special); throwNotImplemented(); }

+    [[noreturn]]
     bool isSpecialToken(Token id) const override
-    {
-        (void)id;
-        throw std::logic_error("not implemented");
-    }
+        { Q_UNUSED(id); throwNotImplemented(); }

+    [[noreturn]]
     std::string tokenToString(Token id) const override
-    {
-        (void)id;
-        throw std::logic_error("not implemented");
-    }
+        { Q_UNUSED(id); throwNotImplemented(); }

+    [[noreturn]]
     void initSampler(PromptContext &ctx) override
-    {
-        (void)ctx;
-        throw std::logic_error("not implemented");
-    }
+        { Q_UNUSED(ctx); throwNotImplemented(); }

-    Token sampleToken() const override { throw std::logic_error("not implemented"); }
+    [[noreturn]]
+    Token sampleToken() const override
+        { throwNotImplemented(); }

-    bool evalTokens(PromptContext &ctx, const std::vector<Token> &tokens) const override
-    {
-        (void)ctx;
-        (void)tokens;
-        throw std::logic_error("not implemented");
-    }
+    [[noreturn]]
+    bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override
+        { Q_UNUSED(ctx); Q_UNUSED(tokens); throwNotImplemented(); }

+    [[noreturn]]
     void shiftContext(PromptContext &promptCtx) override
-    {
-        (void)promptCtx;
-        throw std::logic_error("not implemented");
-    }
+        { Q_UNUSED(promptCtx); throwNotImplemented(); }

-    int32_t contextLength() const override
-    {
-        throw std::logic_error("not implemented");
-    }
+    [[noreturn]]
+    int32_t inputLength() const override
+        { throwNotImplemented(); }

+    [[noreturn]]
+    void setTokenizeInputPosition(int32_t pos) override
+        { Q_UNUSED(pos); throwNotImplemented(); }
+
+    [[noreturn]]
+    auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
+        -> std::vector<Token>::const_iterator override
+        { Q_UNUSED(ctx); Q_UNUSED(input); throwNotImplemented(); }
+
+    [[noreturn]]
+    void setModelInputPosition(PromptContext &ctx, int32_t pos) override
+        { Q_UNUSED(ctx); Q_UNUSED(pos); throwNotImplemented(); }
+
+    [[noreturn]]
+    void appendInputToken(PromptContext &ctx, Token tok) override
+        { Q_UNUSED(ctx); Q_UNUSED(tok); throwNotImplemented(); }
+
+    [[noreturn]]
     const std::vector<Token> &endTokens() const override
-    {
-        throw std::logic_error("not implemented");
-    }
+        { throwNotImplemented(); }

+    [[noreturn]]
     bool shouldAddBOS() const override
-    {
-        throw std::logic_error("not implemented");
-    }
+        { throwNotImplemented(); }
+
+    [[noreturn]]
+    std::span<const Token> inputTokens() const override
+        { throwNotImplemented(); }

 private:
     std::function<bool(int32_t, const std::string&)> m_responseCallback;
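The ChatAPI stubs above lean on a small C++ detail: because throwNotImplemented() is declared [[noreturn]], a non-void override whose body is only that call still compiles without a return statement or -Wreturn-type warnings. A standalone sketch of the pattern (hypothetical names, not tied to the patch):

    #include <cstdint>
    #include <iostream>
    #include <stdexcept>

    [[noreturn]]
    static void throwNotImplemented() { throw std::logic_error("not implemented"); }

    // no return value needed: the compiler knows control never falls off the end
    static int32_t contextLengthStub()
        { throwNotImplemented(); }

    int main()
    {
        try {
            contextLengthStub();
        } catch (const std::logic_error &e) {
            std::cout << e.what() << '\n'; // prints "not implemented"
        }
        return 0;
    }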
diff --git a/gpt4all-chat/src/chatllm.cpp b/gpt4all-chat/src/chatllm.cpp
index e693c03b60d1..e1cae8c8c1c9 100644
--- a/gpt4all-chat/src/chatllm.cpp
+++ b/gpt4all-chat/src/chatllm.cpp
@@ -33,6 +33,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -404,7 +405,6 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro
     QString requestedDevice = MySettings::globalInstance()->device();
     int n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
-    m_ctx.n_ctx = n_ctx;
     int ngl = MySettings::globalInstance()->modelGpuLayers(modelInfo);

     std::string backend = "auto";
@@ -632,7 +632,6 @@ void ChatLLM::regenerateResponse()
     else
         m_ctx.n_past -= m_promptResponseTokens;
     m_ctx.n_past = std::max(0, m_ctx.n_past);
-    m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end());
     m_promptResponseTokens = 0;
     m_promptTokens = 0;
     m_response = m_trimmedResponse = std::string();
@@ -1078,12 +1077,13 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
         stream << responseLogits;
     }
     stream << m_ctx.n_past;
+    saveState();
     if (version >= 7) {
-        stream << m_ctx.n_ctx;
+        stream << m_stateContextLength;
     }
-    stream << quint64(m_ctx.tokens.size());
-    stream.writeRawData(reinterpret_cast<const char *>(m_ctx.tokens.data()), m_ctx.tokens.size() * sizeof(int));
-    saveState();
+    stream << quint64(m_stateInputTokens.size());
+    stream.writeRawData(reinterpret_cast<const char *>(m_stateInputTokens.data()),
+                        m_stateInputTokens.size() * sizeof(m_stateInputTokens[0]));
     QByteArray compressed = qCompress(m_state);
     stream << compressed;
 #if defined(DEBUG)
@@ -1145,7 +1145,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
     if (version >= 7) {
         uint32_t n_ctx;
         stream >> n_ctx;
-        if (!discardKV) m_ctx.n_ctx = n_ctx;
+        if (!discardKV) m_stateContextLength = n_ctx;
     }

     if (version < 9) {
@@ -1157,10 +1157,10 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
     quint64 tokensSize;
     stream >> tokensSize;
     if (!discardKV) {
-        m_ctx.tokens.resize(tokensSize);
-        stream.readRawData(reinterpret_cast<char *>(m_ctx.tokens.data()), tokensSize * sizeof(int));
+        m_stateInputTokens.resize(tokensSize);
+        stream.readRawData(reinterpret_cast<char *>(m_stateInputTokens.data()), tokensSize * sizeof(m_stateInputTokens[0]));
     } else {
-        stream.skipRawData(tokensSize * sizeof(int));
+        stream.skipRawData(tokensSize * sizeof(m_stateInputTokens[0]));
     }

     if (version >= 1) {
@@ -1202,13 +1202,16 @@ void ChatLLM::saveState()
 #if defined(DEBUG)
     qDebug() << "saveState" << m_llmThread.objectName() << "size:" << m_state.size();
 #endif
-    bool ok = m_llModelInfo.model->saveState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())});
+    bool ok = m_llModelInfo.model->saveState({reinterpret_cast<uint8_t *>(m_state.data()), size_t(m_state.size())},
+                                             m_stateInputTokens);
     if (!ok) {
         // FIXME(jared): how badly does this situation break GPT4All?
         qWarning() << "ChatLLM failed to save LLModel state";
         m_state.clear();
         m_state.squeeze();
+        m_stateContextLength = -1;
     }
+    m_stateContextLength = m_llModelInfo.model->contextLength();
 }

 void ChatLLM::restoreState()
@@ -1235,13 +1238,22 @@ void ChatLLM::restoreState()
     if (m_state.isEmpty())
         return;

-    size_t bytesRead = m_llModelInfo.model->restoreState({reinterpret_cast<const uint8_t *>(m_state.data()), size_t(m_state.size())});
-    if (bytesRead) {
-        m_processedSystemPrompt = true;
-        m_pristineLoadedState = true;
-    } else {
-        qWarning() << "restoring state from text because of error reading state (mismatch or corrupt data)";
+    if (m_llModelInfo.model->contextLength() != m_stateContextLength) {
+        qWarning() << "restoring state from text because of n_ctx mismatch (state"
+                   << m_stateContextLength << "model" << m_llModelInfo.model->contextLength() << ")";
         m_restoreStateFromText = true;
+    } else {
+        size_t bytesRead = m_llModelInfo.model->restoreState(
+            {reinterpret_cast<const uint8_t *>(m_state.data()), size_t(m_state.size())},
+            m_stateInputTokens
+        );
+        if (!bytesRead) {
+            qWarning() << "restoring state from text because of error reading state (mismatch or corrupt data)";
+            m_restoreStateFromText = true;
+        } else {
+            m_processedSystemPrompt = true;
+            m_pristineLoadedState = true;
+        }
     }

     // free local state copy unless unload is pending
diff --git a/gpt4all-chat/src/chatllm.h b/gpt4all-chat/src/chatllm.h
index 98348f5f891a..4b9936cb038c 100644
--- a/gpt4all-chat/src/chatllm.h
+++ b/gpt4all-chat/src/chatllm.h
@@ -9,7 +9,7 @@
 #include
 #include
 #include
-#include
+#include // IWYU pragma: keep
 #include
 #include
 #include
@@ -22,6 +22,7 @@
 #include
 #include
 #include
+#include

 using namespace Qt::Literals::StringLiterals;

@@ -277,6 +278,8 @@ public Q_SLOTS:
     ModelInfo m_modelInfo;
     TokenTimer *m_timer;
     QByteArray m_state;
+    std::vector<LLModel::Token> m_stateInputTokens;
+    int32_t m_stateContextLength = -1;
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;
     std::atomic<bool> m_shouldBeLoaded;
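For reference, the raw-data round trip used for m_stateInputTokens in serialize()/deserialize() boils down to the following QDataStream pattern; this is a simplified standalone sketch without the chat's stream versioning:

    #include <QByteArray>
    #include <QDataStream>
    #include <QIODevice>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main()
    {
        std::vector<int32_t> tokens = {1, 2, 3, 4};

        // write: length prefix followed by the raw token bytes
        QByteArray blob;
        {
            QDataStream out(&blob, QIODevice::WriteOnly);
            out << quint64(tokens.size());
            out.writeRawData(reinterpret_cast<const char *>(tokens.data()),
                             tokens.size() * sizeof(tokens[0]));
        }

        // read it back the same way
        std::vector<int32_t> restored;
        {
            QDataStream in(&blob, QIODevice::ReadOnly);
            quint64 n;
            in >> n;
            restored.resize(n);
            in.readRawData(reinterpret_cast<char *>(restored.data()), n * sizeof(restored[0]));
        }

        assert(restored == tokens);
        return 0;
    }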