Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 31 additions & 33 deletions cpp/serve/engine_actions/eagle_new_request_prefill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,45 +123,27 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
if (rsentry->child_indices.empty()) {
models_[model_id]->EnableSlidingWindowForSeq(mstate->internal_id);
}
// Shift the input tokens by 1 for eagle models.
if (model_id == 0) {
for (int j = 1; j < static_cast<int>(models_.size()); ++j) {
ICHECK(rsentry->mstates[j]->inputs.size());
TokenData token_data = Downcast<TokenData>(rsentry->mstates[j]->inputs[0]);
rsentry->mstates[j]->inputs.Set(
0, TokenData(
IntTuple(token_data->token_ids.begin() + 1, token_data->token_ids.end())));
}
}
}
request_internal_ids.push_back(mstate->internal_id);
RECORD_EVENT(trace_recorder_, prefill_inputs[i].rsentry->request->id, "start embedding");
// Speculative models shift left the input tokens by 1 when base model has committed tokens.
// Note: for n > 1 cases Eagle doesn't work because parent entry doesn't shift input tokens.
int embed_offset =
prefill_inputs[i].rsentry->mstates[model_id]->committed_tokens.empty() ? 0 : 1;
for (int j = 0; j < static_cast<int>(input_data.size()); ++j) {
if (j == static_cast<int>(input_data.size()) - 1) {
std::vector<int32_t> tail_tokens;
TokenData tk_data = Downcast<TokenData>(input_data[j]);
CHECK(tk_data.defined());
for (int k = embed_offset; k < static_cast<int>(tk_data->token_ids.size()); ++k) {
tail_tokens.push_back(tk_data->token_ids[k]);
}
embeddings = models_[model_id]->TokenEmbed(
{tail_tokens.begin(), tail_tokens.end()},
/*dst=*/!single_input ? &model_workspaces_[model_id].embeddings : nullptr,
/*offset=*/cum_prefill_length);
cum_prefill_length += input_data[j]->GetLength();
cum_prefill_length -= embed_offset;
} else {
embeddings = input_data[i]->GetEmbedding(
models_[model_id],
/*dst=*/!single_input ? &model_workspaces_[model_id].embeddings : nullptr,
/*offset=*/cum_prefill_length);
cum_prefill_length += input_data[j]->GetLength();
}
}
if (embed_offset > 0) {
std::vector<int32_t> new_tokens = {prefill_inputs[i]
.rsentry->mstates[model_id]
->committed_tokens.back()
.sampled_token_id.first};
embeddings =
models_[model_id]->TokenEmbed({new_tokens.begin(), new_tokens.end()},
/*dst=*/&model_workspaces_[model_id].embeddings,
/*offset=*/cum_prefill_length);
cum_prefill_length += new_tokens.size();
embeddings = input_data[j]->GetEmbedding(
models_[model_id],
/*dst=*/!single_input ? &model_workspaces_[model_id].embeddings : nullptr,
/*offset=*/cum_prefill_length);
cum_prefill_length += input_data[j]->GetLength();
}
RECORD_EVENT(trace_recorder_, rsentry->request->id, "finish embedding");
}
Expand Down Expand Up @@ -238,6 +220,11 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
generation_cfg.clear();
for (int i = 0; i < num_rsentries; ++i) {
const RequestStateEntry& rsentry = prefill_inputs[i].rsentry;
// No sample for rsentries with remaining inputs.
if (!rsentry->mstates[0]->inputs.empty()) {
continue;
}

int remaining_num_child_to_activate = prefill_inputs[i].num_child_to_activate;
for (int child_idx : rsentry->child_indices) {
// Only use base model to judge if we need to add child entries.
Expand Down Expand Up @@ -310,6 +297,17 @@ class EagleNewRequestPrefillActionObj : public EngineActionObj {
rsentries_for_sample[i]->mstates[mid]->inputs.push_back(
TokenData(std::vector<int64_t>{sample_results[i].sampled_token_id.first}));
}
if (mid > 0) {
// Add the sampled token as an input of the eagle models.
TokenData token_data =
Downcast<TokenData>(rsentries_for_sample[i]->mstates[mid]->inputs.back());
std::vector<int32_t> token_ids = {token_data->token_ids.begin(),
token_data->token_ids.end()};
token_ids.push_back(sample_results[i].sampled_token_id.first);
int ninputs = static_cast<int>(rsentries_for_sample[i]->mstates[mid]->inputs.size());
rsentries_for_sample[i]->mstates[mid]->inputs.Set(
ninputs - 1, TokenData(IntTuple(token_ids.begin(), token_ids.end())));
}
}
// Only base model trigger timing records.
if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) {
Expand Down