Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PrefixCache] Defer sequence extension #2654

Merged
merged 2 commits into from
Jul 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 40 additions & 37 deletions cpp/serve/engine_actions/action_commons.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,35 +99,37 @@ void ProcessFinishedRequestStateEntries(std::vector<RequestStateEntry> finished_
}
}

void UpdatePrefixCache(Array<Request> requests, EngineState estate) {
for (Request request : requests) {
RequestState rstate = estate->GetRequestState(request);
void UpdatePrefixCache(const std::vector<RequestState>& rstates, EngineState estate) {
NVTXScopedRange nvtx_scope("Update prefix cache");
std::vector<int32_t> token_ids;
for (RequestState rstate : rstates) {
for (const RequestStateEntry& rsentry : rstate->entries) {
if (estate->prefix_cache->HasSequence(rsentry->mstates[0]->internal_id)) {
if (!rsentry->mstates[0]->prefilled_inputs.empty()) {
// Notify the prefix cache of the newly prefilled data.
for (Data data : rsentry->mstates[0]->prefilled_inputs) {
const TokenDataNode* token_data = data.as<TokenDataNode>();
estate->prefix_cache->ExtendSequence(rsentry->mstates[0]->internal_id,
token_data->token_ids);
}
rsentry->mstates[0]->prefilled_inputs.clear();
if (!rsentry->mstates[0]->prefilled_inputs.empty()) {
// Notify the prefix cache of the newly prefilled data.
token_ids.clear();
for (Data data : rsentry->mstates[0]->prefilled_inputs) {
const TokenDataNode* token_data = data.as<TokenDataNode>();
token_ids.reserve(token_ids.size() + token_data->token_ids.size());
token_ids.insert(token_ids.end(), token_data->token_ids->data,
token_data->token_ids->data + token_data->token_ids.size());
}
if (rsentry->mstates[0]->cached_committed_tokens <
static_cast<int64_t>(rsentry->mstates[0]->committed_tokens.size()) - 1) {
// Notify the prefix cache of the newly decoded data, except the last token as it is not
// in KVCache yet.
std::vector<int64_t> tokens;
tokens.reserve((static_cast<int64_t>(rsentry->mstates[0]->committed_tokens.size()) -
rsentry->mstates[0]->cached_committed_tokens));
for (int i = rsentry->mstates[0]->cached_committed_tokens;
i < static_cast<int64_t>(rsentry->mstates[0]->committed_tokens.size()) - 1; ++i) {
tokens.push_back(rsentry->mstates[0]->committed_tokens[i].GetTokenId());
}
estate->prefix_cache->ExtendSequence(rsentry->mstates[0]->internal_id, IntTuple(tokens));
rsentry->mstates[0]->cached_committed_tokens =
static_cast<int64_t>(rsentry->mstates[0]->committed_tokens.size()) - 1;
estate->prefix_cache->ExtendSequence(rsentry->mstates[0]->internal_id, token_ids);
rsentry->mstates[0]->prefilled_inputs.clear();
}
if (rsentry->mstates[0]->cached_committed_tokens <
static_cast<int64_t>(rsentry->mstates[0]->committed_tokens.size()) - 1) {
// Notify the prefix cache of the newly decoded data, except the last token as it is not
// in KVCache yet.
token_ids.clear();
token_ids.reserve((static_cast<int64_t>(rsentry->mstates[0]->committed_tokens.size()) -
rsentry->mstates[0]->cached_committed_tokens));
for (int i = rsentry->mstates[0]->cached_committed_tokens;
i < static_cast<int32_t>(rsentry->mstates[0]->committed_tokens.size()) - 1; ++i) {
token_ids.push_back(rsentry->mstates[0]->committed_tokens[i].GetTokenId());
}
estate->prefix_cache->ExtendSequence(rsentry->mstates[0]->internal_id, token_ids);
rsentry->mstates[0]->cached_committed_tokens =
static_cast<int64_t>(rsentry->mstates[0]->committed_tokens.size()) - 1;
}
}
}
Expand All @@ -139,14 +141,17 @@ void ActionStepPostProcess(Array<Request> requests, EngineState estate, Array<Mo
int64_t max_single_sequence_length,
Optional<EventTraceRecorder> trace_recorder) {
NVTXScopedRange nvtx_scope("EngineAction postproc");
int num_requests = requests.size();
std::vector<RequestState> rstates;
std::vector<RequestStateEntry> finished_rsentries;
finished_rsentries.reserve(requests.size());

Array<RequestStreamOutput> callback_delta_outputs;
callback_delta_outputs.reserve(requests.size());
rstates.reserve(num_requests);
finished_rsentries.reserve(num_requests);
callback_delta_outputs.reserve(num_requests);

for (Request request : requests) {
RequestState rstate = estate->GetRequestState(request);
for (int i = 0; i < num_requests; ++i) {
RequestState rstate = estate->GetRequestState(requests[i]);
rstates.push_back(rstate);
for (const RequestStateEntry& rsentry : rstate->entries) {
for (Data data : rsentry->mstates[0]->prefilled_inputs) {
// note that we are counting prefill tokens across all branches
Expand All @@ -155,15 +160,13 @@ void ActionStepPostProcess(Array<Request> requests, EngineState estate, Array<Mo
}
}

{
NVTXScopedRange nvtx_scope("ActionStepPostProcess updating prefix cache");
UpdatePrefixCache(requests, estate);
}
UpdatePrefixCache(rstates, estate);

// - Collect new generated tokens and finish reasons for requests.
for (Request request : requests) {
for (int r = 0; r < num_requests; ++r) {
Request request = requests[r];
int n = request->generation_cfg->n;
RequestState rstate = estate->GetRequestState(request);
RequestState rstate = rstates[r];
Array<IntTuple> group_delta_token_ids;
Array<Array<String>> group_delta_logprob_json_strs;
Array<Optional<String>> group_finish_reason;
Expand Down
4 changes: 4 additions & 0 deletions cpp/serve/engine_actions/batch_decode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ class BatchDecodeActionObj : public EngineActionObj {
NDArray probs_on_device =
logit_processor_->ComputeProbsFromLogits(logits, generation_cfg, request_ids);

// - Commit the prefix cache changes from previous round of action.
// Note: we commit prefix cache changes here to overlap this commit with the GPU execution.
estate->prefix_cache->CommitSequenceExtention();

// - Sample tokens.
// Fill range [0, num_rsentries) into `sample_indices`.
std::vector<int> sample_indices(num_rsentries);
Expand Down
9 changes: 5 additions & 4 deletions cpp/serve/engine_actions/batch_prefill_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -480,17 +480,18 @@ void BatchPrefillBaseActionObj::UpdateRequestStateEntriesWithSampleResults(
}
}

IntTuple BatchPrefillBaseActionObj::GetConcatPrefillInputData(const RequestModelState& mstate) {
std::vector<int64_t> tokens;
std::vector<int32_t> BatchPrefillBaseActionObj::GetConcatPrefillInputData(
const RequestModelState& mstate) {
std::vector<int32_t> tokens;
for (Data data : mstate->inputs) {
if (const TokenDataNode* token_data = data.as<TokenDataNode>()) {
tokens.reserve(tokens.size() + token_data->GetLength());
tokens.insert(tokens.end(), token_data->token_ids.begin(), token_data->token_ids.end());
} else {
return IntTuple({});
return {};
}
}
return IntTuple(tokens);
return tokens;
}

void BatchPrefillBaseActionObj::PopPrefillInputData(const RequestModelState& mstate,
Expand Down
2 changes: 1 addition & 1 deletion cpp/serve/engine_actions/batch_prefill_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class BatchPrefillBaseActionObj : public EngineActionObj {
* \param mstate The RequestModelState whose input data is to be concatenated.
* \return The concatenate IntTuple.
*/
IntTuple GetConcatPrefillInputData(const RequestModelState& mstate);
std::vector<int32_t> GetConcatPrefillInputData(const RequestModelState& mstate);

/*!
* \brief Pop the prefix tokens of the RequestModelState input data array.
Expand Down
4 changes: 4 additions & 0 deletions cpp/serve/engine_actions/batch_verify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ class BatchVerifyActionObj : public EngineActionObj {
NDArray probs_on_device = logit_processor_->ComputeProbsFromLogits(
logits, generation_cfg, request_ids, &cum_verify_lengths);

// - Commit the prefix cache changes from previous round of action.
// Note: we commit prefix cache changes here to overlap this commit with the GPU execution.
estate->prefix_cache->CommitSequenceExtention();

std::vector<int> sample_indices(num_rsentries);
std::iota(sample_indices.begin(), sample_indices.end(), 0);
NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(
Expand Down
5 changes: 5 additions & 0 deletions cpp/serve/engine_actions/eagle_batch_verify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
// - Compute probability distributions.
NDArray probs_on_device = logit_processor_->ComputeProbsFromLogits(
logits, generation_cfg, request_ids, &cum_verify_lengths);

// - Commit the prefix cache changes from previous round of action.
// Note: we commit prefix cache changes here to overlap this commit with the GPU execution.
estate->prefix_cache->CommitSequenceExtention();

std::vector<int> sample_indices(num_rsentries);
std::iota(sample_indices.begin(), sample_indices.end(), 0);
NDArray renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(
Expand Down
4 changes: 2 additions & 2 deletions cpp/serve/engine_actions/eagle_new_request_prefill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,8 @@ class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj {
}
if (rsentry->parent_idx == -1 && rsentry->status == RequestStateStatus::kPending &&
!estate->prefix_cache->HasSequence(rsentry->mstates[0]->internal_id)) {
IntTuple tokens = GetConcatPrefillInputData(rsentry->mstates[0]);
if (!tokens.size()) {
std::vector<int32_t> tokens = GetConcatPrefillInputData(rsentry->mstates[0]);
if (tokens.empty()) {
// If the RequestStateEntry is of empty input data, or not fully tokenized, do nothing
// and return.
return;
Expand Down
4 changes: 2 additions & 2 deletions cpp/serve/engine_actions/new_request_prefill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,8 @@ class NewRequestPrefillActionObj : public BatchPrefillBaseActionObj {
}
if (rsentry->parent_idx == -1 && rsentry->status == RequestStateStatus::kPending &&
!estate->prefix_cache->HasSequence(rsentry->mstates[0]->internal_id)) {
IntTuple tokens = GetConcatPrefillInputData(rsentry->mstates[0]);
if (!tokens.size()) {
std::vector<int32_t> tokens = GetConcatPrefillInputData(rsentry->mstates[0]);
if (tokens.empty()) {
// If the RequestStateEntry is of empty input data, or not fully tokenized, do nothing
// and return.
return;
Expand Down
54 changes: 42 additions & 12 deletions cpp/serve/prefix_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,17 @@ class PrefixCacheImpl : public PrefixCacheObj {
* \param attention_sink_size The attention sink size for the sequence, 0 by default.
* \return The matched result.
*/
PrefixCacheMatchedResult InsertSequence(int64_t seq_id, IntTuple tokens, int sliding_window_size,
int attention_sink_size) final {
PrefixCacheMatchedResult InsertSequence(int64_t seq_id, std::vector<int32_t> tokens,
int sliding_window_size, int attention_sink_size) final {
CHECK_NE(sliding_window_size, 0);
CHECK_GE(attention_sink_size, 0);
CHECK(seq_states_.find(seq_id) == seq_states_.end());
CHECK(seq_sliding_window_infos_.find(seq_id) == seq_sliding_window_infos_.end());
CHECK(!tokens.empty());
CommitSequenceExtention();
tokens.pop_back();
auto [matched_offset, matched_seqs] = radix_tree_->MatchPrefix(tokens);
std::pair<int, size_t> sliding_window_info{sliding_window_size, attention_sink_size};
IntTuple popped_tokens = IntTuple(std::vector<int64_t>(tokens.begin(), tokens.end() - 1));
auto [matched_offset, matched_seqs] = radix_tree_->MatchPrefix(popped_tokens);
// No prefix matched, directly adding new sequence.
if (!matched_offset) {
radix_tree_->AddSequence(seq_id);
Expand Down Expand Up @@ -142,9 +144,25 @@ class PrefixCacheImpl : public PrefixCacheObj {
* \param tokens The tokens of tokenized sequence suffix to extend.
* \throw Error if the given sequence id is not valid or active.
*/
void ExtendSequence(int64_t seq_id, IntTuple tokens) final {
CHECK(seq_states_.at(seq_id) == SequenceState::kActive);
radix_tree_->ExtendSequence(seq_id, tokens);
void ExtendSequence(int64_t seq_id, std::vector<int32_t> tokens) final {
const auto& it = seq_states_.find(seq_id);
CHECK(it == seq_states_.end() || it->second == SequenceState::kActive);
uncommitted_extended_token_ids_.emplace_back(seq_id, std::move(tokens));
}

void CommitSequenceExtention() final {
if (uncommitted_extended_token_ids_.empty()) {
return;
}
NVTXScopedRange nvtx_scope("PrefixCache commit sequence extension");
for (const auto& [seq_id, uncommitted_token_ids] : uncommitted_extended_token_ids_) {
if (!HasSequence(seq_id)) {
// The sequence has been removed. Hence no action is needed.
continue;
}
radix_tree_->ExtendSequence(seq_id, uncommitted_token_ids);
}
uncommitted_extended_token_ids_.clear();
}

/*!
Expand All @@ -154,6 +172,7 @@ class PrefixCacheImpl : public PrefixCacheObj {
* \throw Error if the given sequence id is not valid or active.
*/
void RollBackSequence(int64_t seq_id, size_t num_tokens) final {
CommitSequenceExtention();
CHECK(seq_states_.at(seq_id) == SequenceState::kActive);
radix_tree_->RollBackSequence(seq_id, num_tokens);
}
Expand All @@ -167,6 +186,7 @@ class PrefixCacheImpl : public PrefixCacheObj {
* \throw Error if the given sequence id is not valid.
*/
void RecycleSequence(int64_t seq_id, bool lazy = true) final {
CommitSequenceExtention();
CHECK(seq_states_.at(seq_id) == SequenceState::kActive);
CHECK(recycling_seq_lrus_.find(seq_id) == recycling_seq_lrus_.end());
if (lazy && max_num_recycling_seqs_ != 0) {
Expand Down Expand Up @@ -236,6 +256,7 @@ class PrefixCacheImpl : public PrefixCacheObj {
reversed_recycling_seq_lrus_.clear();
seq_states_.clear();
seq_sliding_window_infos_.clear();
uncommitted_extended_token_ids_.clear();
lru_counter_ = 0;
}

Expand Down Expand Up @@ -304,6 +325,12 @@ class PrefixCacheImpl : public PrefixCacheObj {
* non-negative and used when sliding window size is positive.
*/
std::unordered_map<int64_t, std::pair<int, size_t>> seq_sliding_window_infos_;
/*!
* \brief The collection of uncommitted extended token ids of sequences.
 * The "ExtendSequence" method only lazily adds token ids into this collection,
 * and these uncommitted token ids are committed when needed.
*/
std::vector<std::pair<int64_t, std::vector<int32_t>>> uncommitted_extended_token_ids_;
}; // namespace serve

TVM_REGISTER_OBJECT_TYPE(PrefixCacheImpl);
Expand All @@ -322,8 +349,8 @@ class NoPrefixCache : public PrefixCacheObj {
* \param attention_sink_size The attention sink size for the sequence, 0 by default.
* \return The matched result.
*/
PrefixCacheMatchedResult InsertSequence(int64_t seq_id, IntTuple tokens, int sliding_window_size,
int attention_sink_size) final {
PrefixCacheMatchedResult InsertSequence(int64_t seq_id, std::vector<int32_t> tokens,
int sliding_window_size, int attention_sink_size) final {
// Since there is no prefix cache, always return as new sequence.
return PrefixCacheMatchedResult{0, -1, -1, 0};
}
Expand All @@ -334,9 +361,12 @@ class NoPrefixCache : public PrefixCacheObj {
* \param tokens The tokens of tokenized sequence suffix to extend.
* \throw Error if called since this should never be called.
*/
void ExtendSequence(int64_t seq_id, IntTuple tokens) final {
// Since there is no prefix cache, this method should never be called.
LOG(FATAL) << "Unreachable code.";
void ExtendSequence(int64_t seq_id, std::vector<int32_t> tokens) final {
    // No-op.
}

void CommitSequenceExtention() final {
    // No-op.
}

/*!
Expand Down
10 changes: 7 additions & 3 deletions cpp/serve/prefix_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,21 @@ class PrefixCacheObj : public Object {
* \param attention_sink_size The attention sink size for the sequence, 0 by default.
* \return The matched result.
*/
virtual PrefixCacheMatchedResult InsertSequence(int64_t seq_id, IntTuple tokens,
virtual PrefixCacheMatchedResult InsertSequence(int64_t seq_id, std::vector<int32_t> tokens,
int sliding_window_size = -1,
int attention_sink_size = 0) = 0;

/*!
* \brief Extend a sequence with new tokenized sequence suffix.
* \param seq_id The sequence to be extneded.
* This extension might be cached and lazily committed later.
* \param seq_id The sequence to be extended.
* \param tokens The tokens of tokenized sequence suffix to extend.
* \throw Error if the given sequence id is not valid or active.
*/
virtual void ExtendSequence(int64_t seq_id, IntTuple tokens) = 0;
virtual void ExtendSequence(int64_t seq_id, std::vector<int32_t> tokens) = 0;

/*! \brief Commit the cached sequence extension from "ExtendSequence". */
virtual void CommitSequenceExtention() = 0;

/*!
* \brief Roll back a sequence by number of tokens.
Expand Down
Loading
Loading