Enable n-sampling for Medusa spec decoding #2495

Merged · 4 commits · Jun 3, 2024

Changes from all commits
25 changes: 23 additions & 2 deletions cpp/serve/draft_token_workspace_manager.cc
@@ -29,12 +29,33 @@ DraftTokenWorkspaceManagerObj::DraftTokenWorkspaceManagerObj(int max_num_tokens,
 void DraftTokenWorkspaceManagerObj::AllocSlots(int num_slots, std::vector<int>* result) {
   ICHECK_LE(num_slots, free_slots_.size());
   result->assign(free_slots_.rbegin(), free_slots_.rbegin() + num_slots);
-  std::vector<int> allocated(free_slots_.begin(), free_slots_.begin() + num_slots);
   free_slots_.resize(free_slots_.size() - num_slots);
+  for (int slot : (*result)) {
+    ref_count_[slot] = 1;
+  }
 }
 
+void DraftTokenWorkspaceManagerObj::AllocSlots(int num_slots,
+                                               const std::vector<int>& initial_ref_count,
+                                               std::vector<int>* result) {
+  ICHECK_LE(num_slots, free_slots_.size());
+  ICHECK_EQ(num_slots, initial_ref_count.size());
+  result->assign(free_slots_.rbegin(), free_slots_.rbegin() + num_slots);
+  free_slots_.resize(free_slots_.size() - num_slots);
+  for (int i = 0; i < num_slots; ++i) {
+    int slot = (*result)[i];
+    ICHECK(initial_ref_count[i] > 0);
+    ref_count_[slot] = initial_ref_count[i];
+  }
+}
+
 void DraftTokenWorkspaceManagerObj::FreeSlots(const std::vector<int>& slots) {
-  std::copy(slots.begin(), slots.end(), std::back_inserter(free_slots_));
+  for (int slot : slots) {
+    if (--ref_count_.at(slot) == 0) {
+      free_slots_.push_back(slot);
+      ref_count_.erase(slot);
+    }
+  }
 }
 
 void DraftTokenWorkspaceManagerObj::AllocWorkspace(ModelWorkspace* workspace,
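The switch from a plain free list to reference counting is the core of this change: with n-sampling, one draft-probability slot can back several forked requests, so it must stay allocated until the last of them releases it. A minimal standalone sketch of the same discipline (simplified names, not the actual MLC class):

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

// Ref-counted slot pool mirroring AllocSlots/FreeSlots above (sketch only).
class SlotPool {
 public:
  explicit SlotPool(int capacity) {
    for (int i = capacity - 1; i >= 0; --i) free_slots_.push_back(i);
  }
  // Allocate n slots; slot k starts with initial_ref_count[k] users.
  std::vector<int> Alloc(int n, const std::vector<int>& initial_ref_count) {
    assert(n <= static_cast<int>(free_slots_.size()));
    std::vector<int> result(free_slots_.end() - n, free_slots_.end());
    free_slots_.resize(free_slots_.size() - n);
    for (int k = 0; k < n; ++k) ref_count_[result[k]] = initial_ref_count[k];
    return result;
  }
  // Drop one reference; the slot is recycled only when its count hits zero.
  void Free(int slot) {
    if (--ref_count_.at(slot) == 0) {
      free_slots_.push_back(slot);
      ref_count_.erase(slot);
    }
  }

 private:
  std::vector<int> free_slots_;
  std::unordered_map<int, int> ref_count_;
};
```

A plain free list would return the slot to the pool on the first `Free`, corrupting the draft probabilities still referenced by the request's siblings.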
10 changes: 10 additions & 0 deletions cpp/serve/draft_token_workspace_manager.h
@@ -58,6 +58,15 @@ class DraftTokenWorkspaceManagerObj : public Object {
    */
   void AllocSlots(int num_slots, std::vector<int>* result);
 
+  /*!
+   * \brief Allocate slots for the draft tokens.
+   * \param num_slots The number of slots to allocate.
+   * \param initial_ref_count The initial reference count for each slot.
+   * \param result The vector to store the allocated slots.
+   */
+  void AllocSlots(int num_slots, const std::vector<int>& initial_ref_count,
+                  std::vector<int>* result);
+
   /*!
    * \brief Free the slots.
    * \param slots The slots to free.
@@ -74,6 +83,7 @@ class DraftTokenWorkspaceManagerObj : public Object {
   DataType hidden_states_dtype_;
   DLDevice device_;
   const FunctionTable& ft_;
+  std::unordered_map<int, int> ref_count_;
 };
 
 class DraftTokenWorkspaceManager : public ObjectRef {
5 changes: 3 additions & 2 deletions cpp/serve/engine_actions/action_commons.cc
@@ -295,7 +295,8 @@ std::pair<NDArray, std::vector<SampleResult>> ApplyLogitProcessorAndSample(
     const LogitProcessor& logit_processor, const Sampler& sampler, const NDArray& logits,
     const Array<GenerationConfig>& generation_cfg, const Array<String>& request_ids,
     const Array<RequestModelState>& mstates, const std::vector<RandomGenerator*>& rngs,
-    const std::vector<int>& sample_indices) {
+    const std::vector<int>& sample_indices, const Array<GenerationConfig>& child_generation_cfg,
+    const Array<String>& child_request_ids, const std::vector<int>& child_sample_indices) {
   // - Update logits.
   logit_processor->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids);
 
@@ -307,7 +308,7 @@ std::pair<NDArray, std::vector<SampleResult>> ApplyLogitProcessorAndSample(
   NDArray renormalized_probs = sampler->BatchRenormalizeProbsByTopP(probs_on_device, sample_indices,
                                                                     request_ids, generation_cfg);
   std::vector<SampleResult> sample_results = sampler->BatchSampleTokensWithProbAfterTopP(
-      renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);
+      renormalized_probs, child_sample_indices, child_request_ids, child_generation_cfg, rngs);
   return {std::move(probs_on_device), std::move(sample_results)};
 }
13 changes: 12 additions & 1 deletion cpp/serve/engine_actions/action_commons.h
@@ -83,6 +83,13 @@ inline std::vector<RequestStateEntry> GetRunningRequestStateEntries(const Engine
 
 /*!
  * \brief Apply the logit processor to the logits and sample one token for each request.
+ *
+ * Both the parent request configurations and the child request configurations need to be
+ * provided. The parent configurations are used to process the logits and renormalize the
+ * probabilities. The child configurations are used to sample the tokens.
+ *
+ * When a request doesn't have children, the parent and child configurations are the same.
+ *
  * \param logit_processor The logit processor to apply.
  * \param sampler The sampler to sample tokens.
  * \param logits The logits to process.
@@ -91,13 +98,17 @@ inline std::vector<RequestStateEntry> GetRunningRequestStateEntries(const Engine
  * \param mstates The model states of the requests.
  * \param rngs The random generators of the requests.
  * \param sample_indices The indices of the requests to sample.
+ * \param child_generation_cfg The generation configurations of the child requests.
+ * \param child_request_ids The request ids of the child requests.
+ * \param child_sample_indices The indices of the child requests to sample.
  * \return The processed logits and the sampled results.
  */
 std::pair<NDArray, std::vector<SampleResult>> ApplyLogitProcessorAndSample(
     const LogitProcessor& logit_processor, const Sampler& sampler, const NDArray& logits,
     const Array<GenerationConfig>& generation_cfg, const Array<String>& request_ids,
     const Array<RequestModelState>& mstates, const std::vector<RandomGenerator*>& rngs,
-    const std::vector<int>& sample_indices);
+    const std::vector<int>& sample_indices, const Array<GenerationConfig>& child_generation_cfg,
+    const Array<String>& child_request_ids, const std::vector<int>& child_sample_indices);
 
 }  // namespace serve
 }  // namespace llm
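To make the parent/child split concrete, here is a toy reading of the contract, with made-up shapes and a plain `std::discrete_distribution` standing in for the real sampler. The parent arrays have one entry per probability row; the child arrays have one entry per token drawn, and `child_sample_indices` maps each draw back to its row:

```cpp
#include <cstdio>
#include <random>
#include <vector>

// Toy stand-in for the sampling step (assumed shapes, not the real API).
int main() {
  // Parent side: one renormalized probability row per prefilled request (2 rows).
  std::vector<std::vector<double>> probs = {{0.7, 0.3}, {0.1, 0.9}};
  // Child side: request 0 runs with n = 3 parallel samples, request 1 with n = 1.
  std::vector<int> child_sample_indices = {0, 0, 0, 1};
  // One independent RNG per child, as each child entry carries its own rng.
  std::vector<std::mt19937> rngs = {std::mt19937(1), std::mt19937(2), std::mt19937(3),
                                    std::mt19937(4)};
  for (size_t j = 0; j < child_sample_indices.size(); ++j) {
    const std::vector<double>& row = probs[child_sample_indices[j]];
    std::discrete_distribution<int> dist(row.begin(), row.end());
    std::printf("child %zu sampled token %d\n", j, dist(rngs[j]));
  }
  return 0;
}
```

Logit processing and top-p renormalization run once per row under the parent config, while each child draws its own token from the shared row.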
8 changes: 4 additions & 4 deletions cpp/serve/engine_actions/eagle_batch_verify.cc
@@ -288,16 +288,16 @@ class EagleBatchVerifyActionObj : public EngineActionObj {
     std::iota(sample_indices.begin(), sample_indices.end(), 0);
 
     if (engine_config_->speculative_mode == SpeculativeMode::kEagle) {
-      const auto& [renormalized_probs, sample_results] =
-          ApplyLogitProcessorAndSample(logit_processor_, sampler_, logits, generation_cfg,
-                                       request_ids, mstates, rngs, sample_indices);
+      const auto& [renormalized_probs, sample_results] = ApplyLogitProcessorAndSample(
+          logit_processor_, sampler_, logits, generation_cfg, request_ids, mstates, rngs,
+          sample_indices, generation_cfg, request_ids, sample_indices);
       UpdateRequestStatesWithDraftProposals(mstates, sample_results, draft_model_id_,
                                             renormalized_probs, hidden_states, estate);
     } else if (engine_config_->speculative_mode == SpeculativeMode::kMedusa) {
       for (int draft_id = 0; draft_id < engine_config_->spec_draft_length; draft_id++) {
         const auto& [renormalized_probs, sample_results] = ApplyLogitProcessorAndSample(
             logit_processor_, sampler_, multi_step_logits[draft_id], generation_cfg, request_ids,
-            mstates, rngs, sample_indices);
+            mstates, rngs, sample_indices, generation_cfg, request_ids, sample_indices);
         UpdateRequestStatesWithDraftProposals(mstates, sample_results, draft_model_id_,
                                               renormalized_probs, hidden_states, estate);
       }
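Note that the verify action has no fan-out: each running request samples for itself, so the parent arrays are simply passed a second time as the child arrays — exactly the "parent and child configurations are the same" case documented in action_commons.h.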
113 changes: 65 additions & 48 deletions cpp/serve/engine_actions/eagle_new_request_prefill.cc
@@ -80,15 +80,6 @@ class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj {
     for (int i = 0; i < num_rsentries; ++i) {
       const RequestStateEntry& rsentry = prefill_inputs[i].rsentry;
       RequestModelState mstate = rsentry->mstates[model_id];
-      auto [input_data, input_length] =
-          ChunkPrefillInputData(mstate, prefill_inputs[i].max_prefill_length);
-      if (prefill_lengths[i] == -1) {
-        prefill_lengths[i] = input_length;
-      } else {
-        ICHECK_EQ(prefill_lengths[i], input_length);
-      }
-      mstate->num_prefilled_tokens += input_length;
-
       ICHECK(mstate->draft_output_tokens.empty());
       ICHECK(mstate->draft_token_slots.empty());
       if (status_before_prefill[i] == RequestStateStatus::kPending) {
@@ -127,6 +118,15 @@ class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj {
         // Embedding is only needed for the base model in Medusa.
         continue;
       }
+      auto [input_data, input_length] =
+          ChunkPrefillInputData(mstate, prefill_inputs[i].max_prefill_length);
+      if (prefill_lengths[i] == -1) {
+        prefill_lengths[i] = input_length;
+      } else {
+        ICHECK_EQ(prefill_lengths[i], input_length);
+      }
+      mstate->num_prefilled_tokens += input_length;
+
       RECORD_EVENT(trace_recorder_, prefill_inputs[i].rsentry->request->id, "start embedding");
       // Speculative models shift left the input tokens by 1 when base model has committed tokens.
       // Note: for n > 1 cases Eagle doesn't work because parent entry doesn't shift input tokens.
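The shift mentioned in the comment above can be pictured with hypothetical token ids (an illustration, not code from the PR):

```cpp
// base model committed tokens:  [t0, t1, t2, t3]      (t3 was just sampled)
// draft model input tokens:         [t1, t2, t3]      -> proposes t4, t5, ...
// The draft model predicts one position ahead of the base model, so the first
// committed token is dropped from its input. For n > 1, only child entries get
// their inputs extended with the sampled token; the parent entry's inputs are
// never shifted, which is why the note says Eagle doesn't support n > 1 yet.
```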
@@ -191,22 +191,22 @@ class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj {
       LOG(FATAL) << "unreachable";
     }
 
-    Array<String> request_ids_for_logitproc = request_ids;
-
+    Array<String> child_request_ids;
     // - Prepare the configurations for the sampler.
     // For prefill_inputs which have children, sample
    // one token for each rstate that depends on them.
     // Otherwise, sample a token for the current rstate.
-    std::vector<int> sample_indices;
+    std::vector<int> child_sample_indices;
     std::vector<RequestStateEntry> rsentries_for_sample;
     std::vector<RandomGenerator*> rngs;
     std::vector<bool> rsentry_activated;
-    Array<GenerationConfig> generation_cfg;
-    sample_indices.reserve(num_rsentries);
+    Array<GenerationConfig> child_generation_cfg;
+    child_sample_indices.reserve(num_rsentries);
+    child_generation_cfg.reserve(num_rsentries);
+    child_request_ids.reserve(num_rsentries);
     rsentries_for_sample.reserve(num_rsentries);
     rngs.reserve(num_rsentries);
     rsentry_activated.reserve(num_rsentries);
-    request_ids.clear();
     for (int i = 0; i < num_rsentries; ++i) {
       const RequestStateEntry& rsentry = prefill_inputs[i].rsentry;
       // No sample for rsentries with remaining inputs.
@@ -217,18 +217,21 @@ class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj {
       int remaining_num_child_to_activate = prefill_inputs[i].num_child_to_activate;
       for (int child_idx : rsentry->child_indices) {
         // Only use base model to judge if we need to add child entries.
-        if (rstates_of_entries[i]->entries[child_idx]->status == RequestStateStatus::kPending &&
-            (rstates_of_entries[i]->entries[child_idx]->mstates[0]->committed_tokens.empty() ||
+        if ((rstates_of_entries[i]->entries[child_idx]->status == RequestStateStatus::kPending &&
+                 rstates_of_entries[i]
+                     ->entries[child_idx]
+                     ->mstates[0]
+                     ->committed_tokens.empty() ||
             fork_rsentry_child_map[i].count(child_idx))) {
           // If rstates_of_entries[i]->entries[child_idx] has no committed token,
           // the prefill of the current rsentry will unblock
           // rstates_of_entries[i]->entries[child_idx],
           // and thus we want to sample a token for rstates_of_entries[i]->entries[child_idx].
           fork_rsentry_child_map[i].insert(child_idx);
-          sample_indices.push_back(i);
+          child_sample_indices.push_back(i);
           rsentries_for_sample.push_back(rstates_of_entries[i]->entries[child_idx]);
-          request_ids.push_back(rsentry->request->id);
-          generation_cfg.push_back(rsentry->request->generation_cfg);
+          child_request_ids.push_back(rsentry->request->id);
+          child_generation_cfg.push_back(rsentry->request->generation_cfg);
           rngs.push_back(&rstates_of_entries[i]->entries[child_idx]->rng);
 
           // We only fork the first `num_child_to_activate` children.
@@ -258,60 +261,67 @@ class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj {
         }
       }
       if (rsentry->child_indices.empty()) {
         // If rsentry has no child, we sample a token for itself.
-        sample_indices.push_back(i);
+        child_sample_indices.push_back(i);
         rsentries_for_sample.push_back(rsentry);
-        request_ids.push_back(rsentry->request->id);
-        generation_cfg.push_back(rsentry->request->generation_cfg);
+        child_request_ids.push_back(rsentry->request->id);
+        child_generation_cfg.push_back(rsentry->request->generation_cfg);
         rngs.push_back(&rsentry->rng);
         rsentry_activated.push_back(true);
       }
     }
 
     // - Prepare input for logit processor.
     ICHECK(logits_for_sample.defined());
-    Array<GenerationConfig> generation_cfg_for_logitproc;
+    Array<GenerationConfig> generation_cfg;
     Array<RequestModelState> mstates_for_logitproc;
-    generation_cfg_for_logitproc.reserve(num_rsentries);
+    std::vector<int> sample_indices(num_rsentries);
+    generation_cfg.reserve(num_rsentries);
     mstates_for_logitproc.reserve(num_rsentries);
+    std::iota(sample_indices.begin(), sample_indices.end(), 0);
     for (int i = 0; i < num_rsentries; ++i) {
-      generation_cfg_for_logitproc.push_back(prefill_inputs[i].rsentry->request->generation_cfg);
+      generation_cfg.push_back(prefill_inputs[i].rsentry->request->generation_cfg);
       mstates_for_logitproc.push_back(prefill_inputs[i].rsentry->mstates[model_id]);
     }
     if (model_id == 0 || engine_config_->speculative_mode == SpeculativeMode::kEagle) {
       const auto& [renormalized_probs, sample_results] = ApplyLogitProcessorAndSample(
-          logit_processor_, sampler_, logits_for_sample, generation_cfg_for_logitproc,
-          request_ids_for_logitproc, mstates_for_logitproc, rngs, sample_indices);
+          logit_processor_, sampler_, logits_for_sample, generation_cfg, request_ids,
+          mstates_for_logitproc, rngs, sample_indices, child_generation_cfg, child_request_ids,
+          child_sample_indices);
       if (model_id == 0) {
         UpdateRequestStateEntriesWithSampleResults(rsentries_for_sample, rsentry_activated,
                                                    sample_results);
         // Add the sampled token as an input of the eagle models.
-        for (int i = 0; i < static_cast<int>(rsentries_for_sample.size()); ++i) {
-          for (int mid = 1; mid < static_cast<int>(models_.size()); ++mid) {
-            TokenData token_data =
-                Downcast<TokenData>(rsentries_for_sample[i]->mstates[mid]->inputs.back());
-            std::vector<int32_t> token_ids = {token_data->token_ids.begin(),
-                                              token_data->token_ids.end()};
-            token_ids.push_back(sample_results[i].sampled_token_id.first);
-            int ninputs = static_cast<int>(rsentries_for_sample[i]->mstates[mid]->inputs.size());
-            rsentries_for_sample[i]->mstates[mid]->inputs.Set(
-                ninputs - 1, TokenData(IntTuple(token_ids.begin(), token_ids.end())));
+        if (engine_config_->speculative_mode == SpeculativeMode::kEagle) {
+          for (int i = 0; i < static_cast<int>(rsentries_for_sample.size()); ++i) {
+            for (int mid = 1; mid < static_cast<int>(models_.size()); ++mid) {
+              TokenData token_data =
+                  Downcast<TokenData>(rsentries_for_sample[i]->mstates[mid]->inputs.back());
+              std::vector<int32_t> token_ids = {token_data->token_ids.begin(),
                                                token_data->token_ids.end()};
+              token_ids.push_back(sample_results[i].sampled_token_id.first);
+              int ninputs =
+                  static_cast<int>(rsentries_for_sample[i]->mstates[mid]->inputs.size());
+              rsentries_for_sample[i]->mstates[mid]->inputs.Set(
+                  ninputs - 1, TokenData(IntTuple(token_ids.begin(), token_ids.end())));
+            }
           }
         }
       } else {
        // - Slice and save hidden_states_for_sample
         UpdateRequestStatesWithDraftProposals(rsentries_for_sample, sample_results, model_id,
                                               renormalized_probs, hidden_states_for_sample,
-                                              estate);
+                                              estate, child_sample_indices);
       }
     } else if (engine_config_->speculative_mode == SpeculativeMode::kMedusa) {
       for (int draft_id = 0; draft_id < engine_config_->spec_draft_length; ++draft_id) {
         const auto& [renormalized_probs, sample_results] = ApplyLogitProcessorAndSample(
-            logit_processor_, sampler_, multi_step_logits[draft_id], generation_cfg_for_logitproc,
-            request_ids_for_logitproc, mstates_for_logitproc, rngs, sample_indices);
+            logit_processor_, sampler_, multi_step_logits[draft_id], generation_cfg, request_ids,
+            mstates_for_logitproc, rngs, sample_indices, child_generation_cfg, child_request_ids,
+            child_sample_indices);
 
-        UpdateRequestStatesWithDraftProposals(rsentries_for_sample, sample_results, model_id,
-                                              renormalized_probs,
-                                              /*hidden_states=*/ObjectRef{nullptr}, estate);
+        UpdateRequestStatesWithDraftProposals(
+            rsentries_for_sample, sample_results, model_id, renormalized_probs,
+            /*hidden_states=*/ObjectRef{nullptr}, estate, child_sample_indices);
       }
     }
   }
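The two index mappings built above can be summarized in a standalone sketch (a hypothetical helper, not part of the PR): the parent side is the identity over the logits rows, while the child side repeats row i once per dependent entry.

```cpp
#include <algorithm>
#include <numeric>
#include <utility>
#include <vector>

// Sketch: parent indices are the identity; child indices fan out per entry.
std::pair<std::vector<int>, std::vector<int>> BuildSampleIndices(
    const std::vector<int>& num_children) {  // num_children[i]: children of entry i
  int num_rsentries = static_cast<int>(num_children.size());
  std::vector<int> sample_indices(num_rsentries);
  std::iota(sample_indices.begin(), sample_indices.end(), 0);  // {0, 1, ..., n-1}
  std::vector<int> child_sample_indices;
  for (int i = 0; i < num_rsentries; ++i) {
    // An entry with no children still samples once for itself.
    for (int k = 0; k < std::max(1, num_children[i]); ++k) {
      child_sample_indices.push_back(i);
    }
  }
  return {sample_indices, child_sample_indices};
}
```

In the real action the fan-out is gated on child status and `num_child_to_activate`, but the shape of the result is the same: one parent index per logits row, one child index per token drawn.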
@@ -328,8 +338,15 @@ class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj {
       const std::vector<RequestStateEntry>& rsentries_for_sample,
       const std::vector<SampleResult>& sample_results, int model_id,
       const NDArray& renormalized_probs, const ObjectRef& hidden_states_for_sample,
-      EngineState estate) {
-    draft_token_workspace_manager_->AllocSlots(rsentries_for_sample.size(), &draft_token_slots_);
+      EngineState estate, const std::vector<int>& sample_indices) {
+    std::vector<int> reuse_count(renormalized_probs->shape[0], 0);
+    for (int i = 0; i < static_cast<int>(sample_indices.size()); ++i) {
+      // The same probability may be sampled multiple times.
+      reuse_count[sample_indices[i]]++;
+    }
+    draft_token_workspace_manager_->AllocSlots(renormalized_probs->shape[0], reuse_count,
+                                               &draft_token_slots_);
+
     models_[0]->ScatterDraftProbs(renormalized_probs, draft_token_slots_,
                                   &model_workspaces_[0].draft_probs_storage);
     if (engine_config_->speculative_mode == SpeculativeMode::kEagle &&
@@ -338,8 +355,8 @@ class EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj {
                                       &model_workspaces_[0].draft_hidden_states_storage);
     }
     for (int i = 0; i < static_cast<int>(rsentries_for_sample.size()); ++i) {
-      rsentries_for_sample[i]->mstates[model_id]->AddDraftToken(sample_results[i],
-                                                                draft_token_slots_[i]);
+      rsentries_for_sample[i]->mstates[model_id]->AddDraftToken(
+          sample_results[i], draft_token_slots_[sample_indices[i]]);
     }
   }

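The reuse count is what ties n-sampling to the ref-counted workspace: duplicated child indices become the initial reference counts of the allocated slots. A small sketch of the bookkeeping (hypothetical standalone form):

```cpp
#include <vector>

// For child_sample_indices = {0, 0, 0, 1} over 2 probability rows, this
// returns {3, 1}: row 0's slot is released only after all three children
// of request 0 have freed their reference.
std::vector<int> ComputeReuseCount(int num_rows, const std::vector<int>& child_sample_indices) {
  std::vector<int> reuse_count(num_rows, 0);
  for (int idx : child_sample_indices) {
    reuse_count[idx]++;  // the same probability row may be sampled multiple times
  }
  return reuse_count;
}
```

Each child then records the shared slot `draft_token_slots_[sample_indices[i]]` via `AddDraftToken`, and `FreeSlots` decrements the slot's count once per child.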
2 changes: 1 addition & 1 deletion cpp/serve/model.cc
@@ -644,7 +644,7 @@ ModelMetadata GetMetadata() const final { return ft_.model_metadata_; }
   ModelMetadata GetMetadata() const final { return ft_.model_metadata_; }
 
   int GetNumAvailablePages() const final {
-    if (this->kind == KVStateKind::kRNNState) {
+    if (this->kind == KVStateKind::kRNNState || this->kind == KVStateKind::kNone) {
       // RNNState does not introduce new page at runtime
       return std::numeric_limits<int>::max();
     } else {
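The extra `KVStateKind::kNone` case makes models without any KV state — presumably the Medusa heads, which attach to the base model's hidden states and keep no cache of their own — report unbounded page capacity, so the scheduler never throttles on them.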