Use the token cache to infer greater n_past and reuse results (#3073)
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
cebtenzzre authored Oct 31, 2024
1 parent 62cab69 commit f07e2e6
Showing 15 changed files with 319 additions and 168 deletions.
1 change: 1 addition & 0 deletions common/common.cmake
@@ -12,6 +12,7 @@ function(gpt4all_add_warning_options target)
-Wformat=2
-Wmissing-include-dirs
-Wstrict-overflow=2
-Wsuggest-override
-Wvla
# errors
-Werror=format-security
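The new `-Wsuggest-override` flag pairs with the `override` keywords added throughout this commit. A minimal illustration of what the warning catches, using toy types that are not from this repository:

```cpp
// -Wsuggest-override (GCC, and recent Clang) warns when a virtual function
// overrides a base-class function but is not marked 'override'.
struct Base {
    virtual void decode();
    virtual ~Base() = default;
};

struct Derived : Base {
    void decode();             // warning: can be marked 'override'
    // void decode() override; // preferred: catches accidental signature drift
};
```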
22 changes: 15 additions & 7 deletions gpt4all-backend/include/gpt4all-backend/llmodel.h
@@ -124,9 +124,7 @@ class LLModel {
};

struct PromptContext {
std::vector<int32_t> tokens; // current tokens in the context window
int32_t n_past = 0; // number of tokens in past conversation
int32_t n_ctx = 0; // number of tokens possible in context window
int32_t n_predict = 200;
int32_t top_k = 40;
float top_p = 0.9f;
@@ -151,8 +149,8 @@ class LLModel {
virtual bool isModelLoaded() const = 0;
virtual size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) = 0;
virtual size_t stateSize() const = 0;
virtual size_t saveState(std::span<uint8_t> dest) const = 0;
virtual size_t restoreState(std::span<const uint8_t> src) = 0;
virtual size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const = 0;
virtual size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) = 0;

// This method requires the model to return true from supportsCompletion otherwise it will throw
// an error
@@ -210,6 +208,8 @@ class LLModel {

void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; }

virtual int32_t contextLength() const = 0;

protected:
// These are pure virtual because subclasses need to implement as the default implementation of
// 'prompt' above calls these functions
@@ -218,9 +218,15 @@
virtual std::string tokenToString(Token id) const = 0;
virtual void initSampler(PromptContext &ctx) = 0;
virtual Token sampleToken() const = 0;
virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const = 0;
virtual bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const = 0;
virtual void shiftContext(PromptContext &promptCtx) = 0;
virtual int32_t contextLength() const = 0;
virtual int32_t inputLength() const = 0;
virtual void setTokenizeInputPosition(int32_t pos) = 0;
virtual auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator = 0;
virtual void setModelInputPosition(PromptContext &ctx, int32_t pos) = 0;
virtual void appendInputToken(PromptContext &ctx, Token tok) = 0;
virtual std::span<const Token> inputTokens() const = 0;
virtual const std::vector<Token> &endTokens() const = 0;
virtual bool shouldAddBOS() const = 0;

@@ -252,11 +258,13 @@ class LLModel {
bool allowContextShift,
PromptContext &promptCtx,
std::vector<Token> embd_inp,
bool isResponse = false);
bool isResponse = false,
bool alwaysDecode = false);
void generateResponse(std::function<bool(int32_t, const std::string&)> responseCallback,
bool allowContextShift,
PromptContext &promptCtx);

protected:
Token m_tokenize_last_token = -1; // not serialized

friend class LLMImplementation;
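The `saveState`/`restoreState` signatures above now carry the token cache alongside the raw backend state. A minimal caller-side sketch, assuming a loaded `LLModel` instance and this include path; `roundTripState` is a hypothetical helper, not code from this commit:

```cpp
#include <cstdint>
#include <span>
#include <vector>
#include "llmodel.h" // assumed include path

// Save the backend state plus its token cache, then restore both together.
void roundTripState(LLModel &model)
{
    std::vector<uint8_t> state(model.stateSize());
    std::vector<LLModel::Token> cachedTokens;

    // saveState() now also hands back the tokens currently in the context window,
    // so callers can persist them next to the KV-cache blob.
    size_t written = model.saveState(state, cachedTokens);
    if (written == 0)
        return; // save failed

    // Later: restore the KV cache and the token cache it was built from in one call.
    model.restoreState(std::span(state.data(), written), cachedTokens);
}
```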
40 changes: 28 additions & 12 deletions gpt4all-backend/include/gpt4all-backend/llmodel_c.h
@@ -23,17 +23,19 @@ extern "C" {
*/
typedef void *llmodel_model;

/**
* A token.
*/
typedef int32_t token_t;

/**
* llmodel_prompt_context structure for holding the prompt context.
* NOTE: The implementation takes care of all the memory handling of the raw logits pointer and the
* raw tokens pointer. Attempting to resize them or modify them in any way can lead to undefined
* behavior.
*/
struct llmodel_prompt_context {
int32_t *tokens; // current tokens in the context window
size_t tokens_size; // the size of the raw tokens vector
int32_t n_past; // number of tokens in past conversation
int32_t n_ctx; // number of tokens possible in context window
int32_t n_predict; // number of tokens to predict
int32_t top_k; // top k logits to sample from
float top_p; // nucleus sampling probability threshold
@@ -141,27 +143,41 @@ bool llmodel_isModelLoaded(llmodel_model model);
* @param model A pointer to the llmodel_model instance.
* @return the size in bytes of the internal state of the model
*/
uint64_t llmodel_get_state_size(llmodel_model model);
uint64_t llmodel_state_get_size(llmodel_model model);

/**
* Saves the internal state of the model to the specified destination address.
* Saves the internal state of the model.
* NOTE: This state data is specific to the type of model you have created.
* @param model A pointer to the llmodel_model instance.
* @param dest A pointer to the destination.
* @param size The size of the destination buffer.
* @return the number of bytes copied, or zero on error.
* @param state Where to store the state. This must be a buffer of at least llmodel_state_get_size() bytes.
* @param state_size The size of the destination for the state.
* @param input_tokens_out Where to store the address of the token cache state. This is dynamically allocated and must
* be freed with llmodel_state_free_input_tokens.
* @param n_input_tokens Where to store the size of the token cache state.
* @return The number of bytes copied. On error, zero is returned, the token cache is set to NULL, and the token cache
* size is set to zero.
*/
uint64_t llmodel_state_get_data(llmodel_model model, uint8_t *state_out, uint64_t state_size,
token_t **input_tokens_out, uint64_t *n_input_tokens);

/**
* Frees the temporary token cache buffer created by a call to llmodel_state_get_data().
* @param input_tokens The token cache buffer.
*/
uint64_t llmodel_save_state_data(llmodel_model model, uint8_t *dest, uint64_t size);
void llmodel_state_free_input_tokens(token_t *input_tokens);

/**
* Restores the internal state of the model using data from the specified address.
* NOTE: This state data is specific to the type of model you have created.
* @param model A pointer to the llmodel_model instance.
* @param src A pointer to the state data.
* @param size The size of the source data.
* @param state A pointer to the state data.
* @param state_size The size of the state data.
* @param input_tokens The token cache associated with the saved state.
* @param n_input_tokens The number of tokens in input_tokens.
* @return The number of bytes read, or zero on error.
*/
uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src, size_t size);
uint64_t llmodel_state_set_data(llmodel_model model, const uint8_t *state, uint64_t state_size,
const token_t *input_tokens, uint64_t n_input_tokens);

/**
* Generate a response using the model.
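The renamed C functions split the opaque state blob from the token cache and make the caller responsible for freeing the latter. A usage sketch under assumed conditions (an already constructed `llmodel_model`, this include path); `roundTripState` is hypothetical:

```cpp
#include <cstdint>
#include <vector>
#include "llmodel_c.h" // assumed include path

bool roundTripState(llmodel_model model)
{
    std::vector<uint8_t> state(llmodel_state_get_size(model));

    token_t *inputTokens = nullptr;
    uint64_t nInputTokens = 0;
    uint64_t written = llmodel_state_get_data(model, state.data(), state.size(),
                                              &inputTokens, &nInputTokens);
    if (written == 0)
        return false; // on error, inputTokens is NULL and nInputTokens is 0

    // Later: hand the state blob and its token cache back to the model together.
    uint64_t read = llmodel_state_set_data(model, state.data(), written,
                                           inputTokens, nInputTokens);

    // The token cache returned by llmodel_state_get_data() is dynamically allocated
    // and must be released by the caller.
    llmodel_state_free_input_tokens(inputTokens);
    return read != 0;
}
```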
78 changes: 70 additions & 8 deletions gpt4all-backend/src/llamamodel.cpp
@@ -218,6 +218,7 @@ struct LLamaPrivate {
int64_t n_threads = 0;
std::vector<LLModel::Token> end_tokens;
const char *backend_name = nullptr;
std::vector<LLModel::Token> inputTokens;

llama_model *model = nullptr;
llama_context *ctx = nullptr;
@@ -501,14 +502,20 @@ size_t LLamaModel::stateSize() const
return llama_state_get_size(d_ptr->ctx);
}

size_t LLamaModel::saveState(std::span<uint8_t> dest) const
size_t LLamaModel::saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const
{
return llama_state_get_data(d_ptr->ctx, dest.data(), dest.size());
size_t bytesWritten = llama_state_get_data(d_ptr->ctx, stateOut.data(), stateOut.size());
if (bytesWritten)
inputTokensOut.assign(d_ptr->inputTokens.begin(), d_ptr->inputTokens.end());
return bytesWritten;
}

size_t LLamaModel::restoreState(std::span<const uint8_t> src)
size_t LLamaModel::restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens)
{
return llama_state_set_data(d_ptr->ctx, src.data(), src.size());
size_t bytesRead = llama_state_set_data(d_ptr->ctx, state.data(), state.size());
if (bytesRead)
d_ptr->inputTokens.assign(inputTokens.begin(), inputTokens.end());
return bytesRead;
}

std::vector<LLModel::Token> LLamaModel::tokenize(std::string_view str, bool special)
Expand Down Expand Up @@ -594,7 +601,7 @@ LLModel::Token LLamaModel::sampleToken() const
return llama_sampler_sample(d_ptr->sampler_chain, d_ptr->ctx, -1);
}

bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const
bool LLamaModel::evalTokens(PromptContext &ctx, std::span<const Token> tokens) const
{
llama_kv_cache_seq_rm(d_ptr->ctx, 0, ctx.n_past, -1);

@@ -625,7 +632,7 @@ void LLamaModel::shiftContext(PromptContext &promptCtx)
// erase up to n_ctx*contextErase tokens
int n_keep = shouldAddBOS();
int n_past = promptCtx.n_past;
int n_discard = std::min(n_past - n_keep, int(promptCtx.n_ctx * promptCtx.contextErase));
int n_discard = std::min(n_past - n_keep, int(contextLength() * promptCtx.contextErase));

assert(n_discard > 0);
if (n_discard <= 0)
@@ -638,15 +645,70 @@ void LLamaModel::shiftContext(PromptContext &promptCtx)
llama_kv_cache_seq_rm (d_ptr->ctx, 0, n_keep, n_keep + n_discard);
llama_kv_cache_seq_add(d_ptr->ctx, 0, n_keep + n_discard, n_past, -n_discard);

promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard);
promptCtx.n_past = promptCtx.tokens.size();
auto &inp = d_ptr->inputTokens;
inp.erase(inp.begin() + n_keep, inp.begin() + n_keep + n_discard);
promptCtx.n_past = inp.size();
}

int32_t LLamaModel::contextLength() const
{
return llama_n_ctx(d_ptr->ctx);
}

int32_t LLamaModel::inputLength() const
{
return d_ptr->inputTokens.size();
}

void LLamaModel::setTokenizeInputPosition(int32_t pos)
{
assert(pos >= 0);
m_tokenize_last_token = pos ? d_ptr->inputTokens.at(size_t(pos) - 1) : -1; // not serialized
}

auto LLamaModel::computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator
{
assert(ctx.n_past >= 0);
auto pos = size_t(ctx.n_past);
if (pos > d_ptr->inputTokens.size()) {
std::ostringstream ss;
ss << "n_past=" << pos << " is past end of token cache length=" << d_ptr->inputTokens.size();
throw std::out_of_range(ss.str());
}

// find common prefix
auto cacheIt = d_ptr->inputTokens.begin();
auto inputIt = input.begin();
while (cacheIt < d_ptr->inputTokens.end() && inputIt < input.end() && *cacheIt == *inputIt) {
++cacheIt; ++inputIt; ++pos;
}
// tell the caller to ignore the tokens between [begin, inputIt)
return inputIt;
}

void LLamaModel::setModelInputPosition(PromptContext &ctx, int32_t pos)
{
auto &inp = d_ptr->inputTokens;
assert(pos >= 0);
assert(pos <= inp.size());
// truncate token cache to end at the new n_past
if (pos < inp.size())
inp.resize(pos);
ctx.n_past = pos;
}

void LLamaModel::appendInputToken(PromptContext &ctx, Token tok)
{
d_ptr->inputTokens.push_back(tok);
ctx.n_past += 1;
}

auto LLamaModel::inputTokens() const -> std::span<const Token>
{
return d_ptr->inputTokens;
}

const std::vector<LLModel::Token> &LLamaModel::endTokens() const
{
return d_ptr->end_tokens;
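The common-prefix scan in `computeModelInputPosition()` is what lets a resubmitted conversation start decoding from a larger `n_past` instead of position zero. A standalone toy illustration of the same idea (hypothetical names and token values, not code from this commit):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Token = int32_t;

// Count how many leading tokens of the new input already match the cached input;
// those positions keep their KV-cache entries and are not decoded again.
static size_t commonPrefixLength(const std::vector<Token> &cache, const std::vector<Token> &input)
{
    size_t n = 0;
    while (n < cache.size() && n < input.size() && cache[n] == input[n])
        ++n;
    return n;
}

int main()
{
    std::vector<Token> cache {15496, 11, 995, 0};       // tokens evaluated on the previous turn
    std::vector<Token> input {15496, 11, 995, 0, 1374}; // same prompt plus one new token
    assert(commonPrefixLength(cache, input) == 4);      // only the one new token needs decoding
}
```

In the backend itself, `setModelInputPosition()` then truncates the token cache to that shared prefix and `appendInputToken()` re-extends it as the remaining tokens are evaluated.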
24 changes: 16 additions & 8 deletions gpt4all-backend/src/llamamodel_impl.h
@@ -28,8 +28,8 @@ class LLamaModel : public LLModel {
bool isModelLoaded() const override;
size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override;
size_t stateSize() const override;
size_t saveState(std::span<uint8_t> dest) const override;
size_t restoreState(std::span<const uint8_t> src) override;
size_t saveState(std::span<uint8_t> stateOut, std::vector<Token> &inputTokensOut) const override;
size_t restoreState(std::span<const uint8_t> state, std::span<const Token> inputTokens) override;
void setThreadCount(int32_t n_threads) override;
int32_t threadCount() const override;
std::vector<GPUDevice> availableGPUDevices(size_t memoryRequired = 0) const override;
@@ -48,20 +48,23 @@
void embed(const std::vector<std::string> &texts, float *embeddings, bool isRetrieval, int dimensionality = -1,
size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;

private:
std::unique_ptr<LLamaPrivate> d_ptr;
bool m_supportsEmbedding = false;
bool m_supportsCompletion = false;
int32_t contextLength() const override;

protected:
std::vector<Token> tokenize(std::string_view str, bool special) override;
bool isSpecialToken(Token id) const override;
std::string tokenToString(Token id) const override;
void initSampler(PromptContext &ctx) override;
Token sampleToken() const override;
bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) const override;
bool evalTokens(PromptContext &ctx, std::span<const Token> tokens) const override;
void shiftContext(PromptContext &promptCtx) override;
int32_t contextLength() const override;
int32_t inputLength() const override;
void setTokenizeInputPosition(int32_t pos) override;
auto computeModelInputPosition(PromptContext &ctx, const std::vector<Token> &input)
-> std::vector<Token>::const_iterator override;
void setModelInputPosition(PromptContext &ctx, int32_t pos) override;
void appendInputToken(PromptContext &ctx, Token tok) override;
std::span<const Token> inputTokens() const override;
const std::vector<Token> &endTokens() const override;
bool shouldAddBOS() const override;
int32_t maxContextLength(std::string const &modelPath) const override;
Expand All @@ -70,6 +73,11 @@ class LLamaModel : public LLModel {
void embedInternal(const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb,
const EmbModelSpec *spec);

private:
std::unique_ptr<LLamaPrivate> d_ptr;
bool m_supportsEmbedding = false;
bool m_supportsCompletion = false;
};

#endif // LLAMAMODEL_H