Skip to content

Commit

Permalink
[Serving] Add ICHECK for running batch size (#2465)
Browse files Browse the repository at this point in the history
This PR adds ICHECKs to ensure that the running batch size
in BatchDecode, BatchDraft, and EagleBatchDraft does not exceed the
`max_num_sequence` specified in the engine config.

The prefill actions are expected to maintain this invariant, so the
added ICHECKs mainly serve to detect and report internal errors.
  • Loading branch information
MasterJH5574 authored May 29, 2024
1 parent f2c1582 commit e90f2e7
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 34 deletions.
52 changes: 26 additions & 26 deletions cpp/serve/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,21 +152,21 @@ class EngineImpl : public Engine {
ICHECK_GT(n->models_.size(), 1U);
switch (engine_config->speculative_mode) {
case SpeculativeMode::kEagle:
n->actions_ = {
EngineAction::EagleNewRequestPrefill(n->models_, //
logit_processor, //
sampler, //
n->model_workspaces_, //
draft_token_workspace_manager, //
engine_config, //
model_configs, //
n->trace_recorder_),
EngineAction::EagleBatchDraft(n->models_, logit_processor, sampler,
n->model_workspaces_, draft_token_workspace_manager,
n->trace_recorder_, engine_config->spec_draft_length),
EngineAction::EagleBatchVerify(n->models_, logit_processor, sampler,
n->model_workspaces_, draft_token_workspace_manager,
engine_config, n->trace_recorder_)};
n->actions_ = {EngineAction::EagleNewRequestPrefill(n->models_, //
logit_processor, //
sampler, //
n->model_workspaces_, //
draft_token_workspace_manager, //
engine_config, //
model_configs, //
n->trace_recorder_),
EngineAction::EagleBatchDraft(
n->models_, logit_processor, sampler, n->model_workspaces_,
draft_token_workspace_manager, engine_config, n->trace_recorder_,
engine_config->spec_draft_length),
EngineAction::EagleBatchVerify(
n->models_, logit_processor, sampler, n->model_workspaces_,
draft_token_workspace_manager, engine_config, n->trace_recorder_)};
break;
case SpeculativeMode::kMedusa:
n->actions_ = {EngineAction::EagleNewRequestPrefill(n->models_, //
Expand All @@ -191,22 +191,22 @@ class EngineImpl : public Engine {
model_configs, //
n->trace_recorder_),
EngineAction::BatchDraft(n->models_, logit_processor, sampler, n->model_workspaces_,
draft_token_workspace_manager, n->trace_recorder_,
engine_config->spec_draft_length),
draft_token_workspace_manager, engine_config,
n->trace_recorder_, engine_config->spec_draft_length),
EngineAction::BatchVerify(n->models_, logit_processor, sampler, n->model_workspaces_,
draft_token_workspace_manager, engine_config,
n->trace_recorder_)};
}
} else {
n->actions_ = {
EngineAction::NewRequestPrefill(n->models_, //
logit_processor, //
sampler, //
n->model_workspaces_, //
engine_config, //
model_configs, //
n->trace_recorder_),
EngineAction::BatchDecode(n->models_, logit_processor, sampler, n->trace_recorder_)};
n->actions_ = {EngineAction::NewRequestPrefill(n->models_, //
logit_processor, //
sampler, //
n->model_workspaces_, //
engine_config, //
model_configs, //
n->trace_recorder_),
EngineAction::BatchDecode(n->models_, logit_processor, sampler, engine_config,
n->trace_recorder_)};
}
// - Automatically set the threading backend max concurrency.
n->engine_config_ = engine_config;
Expand Down
8 changes: 7 additions & 1 deletion cpp/serve/engine_actions/action.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,13 @@ class EngineAction : public ObjectRef {
* \param models The model to run decode in. When there are multiple
* models, the `Step` function of the created action will not take effect.
* \param sampler The sampler to sample new tokens.
* \param engine_config The engine config.
* \param trace_recorder The event trace recorder for requests.
* \return The created action object.
*/
static EngineAction BatchDecode(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler, Optional<EventTraceRecorder> trace_recorder);
Sampler sampler, EngineConfig engine_config,
Optional<EventTraceRecorder> trace_recorder);

/*!
* \brief Create the action that runs one-step speculative draft proposal for
Expand All @@ -111,13 +113,15 @@ class EngineAction : public ObjectRef {
* \param sampler The sampler to sample new tokens.
* \param model_workspaces The workspace of each model.
* \param draft_token_workspace_manager The draft token workspace manager.
* \param engine_config The engine config.
* \param trace_recorder The event trace recorder for requests.
* \param draft_length The number of draft proposal rounds.
* \return The created action object.
*/
static EngineAction BatchDraft(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler, std::vector<ModelWorkspace> model_workspaces,
DraftTokenWorkspaceManager draft_token_workspace_manager,
EngineConfig engine_config,
Optional<EventTraceRecorder> trace_recorder, int draft_length);

/*!
Expand All @@ -129,13 +133,15 @@ class EngineAction : public ObjectRef {
* \param sampler The sampler to sample new tokens.
* \param model_workspaces The workspace of each model.
* \param draft_token_workspace_manager The draft token workspace manager.
* \param engine_config The engine config.
* \param trace_recorder The event trace recorder for requests.
* \param draft_length The number of draft proposal rounds.
* \return The created action object.
*/
static EngineAction EagleBatchDraft(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler, std::vector<ModelWorkspace> model_workspaces,
DraftTokenWorkspaceManager draft_token_workspace_manager,
EngineConfig engine_config,
Optional<EventTraceRecorder> trace_recorder,
int draft_length = 4);

Expand Down
18 changes: 13 additions & 5 deletions cpp/serve/engine_actions/batch_decode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ namespace serve {
class BatchDecodeActionObj : public EngineActionObj {
public:
explicit BatchDecodeActionObj(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler, Optional<EventTraceRecorder> trace_recorder)
Sampler sampler, EngineConfig engine_config,
Optional<EventTraceRecorder> trace_recorder)
: models_(std::move(models)),
logit_processor_(std::move(logit_processor)),
sampler_(std::move(sampler)),
engine_config_(std::move(engine_config)),
trace_recorder_(std::move(trace_recorder)) {}

Array<Request> Step(EngineState estate) final {
Expand Down Expand Up @@ -63,6 +65,10 @@ class BatchDecodeActionObj : public EngineActionObj {
ICHECK_GT(num_rsentries, 0)
<< "There should be at least one request state entry that can run decode. "
"Possible failure reason: none of the prefill phase of the running requests is finished";
ICHECK_LE(num_rsentries, engine_config_->max_num_sequence)
<< "The number of running requests exceeds the max number of sequence in EngineConfig. "
"Possible failure reason: the prefill action allows new sequence in regardless of the "
"max num sequence.";
// Collect
// - the last committed token,
// - the request id,
Expand Down Expand Up @@ -154,16 +160,18 @@ class BatchDecodeActionObj : public EngineActionObj {
LogitProcessor logit_processor_;
/*! \brief The sampler to sample new tokens. */
Sampler sampler_;
/*! \brief The engine config. */
EngineConfig engine_config_;
/*! \brief Event trace recorder. */
Optional<EventTraceRecorder> trace_recorder_;
};

EngineAction EngineAction::BatchDecode(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler,
Sampler sampler, EngineConfig engine_config,
Optional<EventTraceRecorder> trace_recorder) {
return EngineAction(
make_object<BatchDecodeActionObj>(std::move(models), std::move(logit_processor),
std::move(sampler), std::move(trace_recorder)));
return EngineAction(make_object<BatchDecodeActionObj>(
std::move(models), std::move(logit_processor), std::move(sampler), std::move(engine_config),
std::move(trace_recorder)));
}

} // namespace serve
Expand Down
14 changes: 13 additions & 1 deletion cpp/serve/engine_actions/batch_draft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ class BatchDraftActionObj : public EngineActionObj {
explicit BatchDraftActionObj(Array<Model> models, LogitProcessor logit_processor, Sampler sampler,
std::vector<ModelWorkspace> model_workspaces,
DraftTokenWorkspaceManager draft_token_workspace_manager,
EngineConfig engine_config,
Optional<EventTraceRecorder> trace_recorder, int draft_length)
: models_(std::move(models)),
logit_processor_(std::move(logit_processor)),
sampler_(std::move(sampler)),
model_workspaces_(std::move(model_workspaces)),
draft_token_workspace_manager_(std::move(draft_token_workspace_manager)),
engine_config_(std::move(engine_config)),
trace_recorder_(std::move(trace_recorder)),
draft_length_(draft_length) {
ICHECK_GT(draft_length_, 0);
Expand All @@ -56,6 +58,13 @@ class BatchDraftActionObj : public EngineActionObj {
auto tstart = std::chrono::high_resolution_clock::now();

int num_rsentries = running_rsentries.size();
ICHECK_GT(num_rsentries, 0)
<< "There should be at least one request state entry that can run decode. "
"Possible failure reason: none of the prefill phase of the running requests is finished";
ICHECK_LE(num_rsentries, engine_config_->max_num_sequence)
<< "The number of running requests exceeds the max number of sequence in EngineConfig. "
"Possible failure reason: the prefill action allows new sequence in regardless of the "
"max num sequence.";
Array<String> request_ids;
std::vector<int64_t> request_internal_ids;
Array<GenerationConfig> generation_cfg;
Expand Down Expand Up @@ -172,6 +181,8 @@ class BatchDraftActionObj : public EngineActionObj {
std::vector<ModelWorkspace> model_workspaces_;
/*! \brief The draft token workspace manager. */
DraftTokenWorkspaceManager draft_token_workspace_manager_;
/*! \brief The engine config. */
EngineConfig engine_config_;
/*! \brief Event trace recorder. */
Optional<EventTraceRecorder> trace_recorder_;
/*! \brief Draft proposal length */
Expand All @@ -183,12 +194,13 @@ class BatchDraftActionObj : public EngineActionObj {
EngineAction EngineAction::BatchDraft(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler, std::vector<ModelWorkspace> model_workspaces,
DraftTokenWorkspaceManager draft_token_workspace_manager,
EngineConfig engine_config,
Optional<EventTraceRecorder> trace_recorder,
int draft_length) {
return EngineAction(make_object<BatchDraftActionObj>(
std::move(models), std::move(logit_processor), std::move(sampler),
std::move(model_workspaces), std::move(draft_token_workspace_manager),
std::move(trace_recorder), draft_length));
std::move(engine_config), std::move(trace_recorder), draft_length));
}

} // namespace serve
Expand Down
15 changes: 14 additions & 1 deletion cpp/serve/engine_actions/eagle_batch_draft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ class EagleBatchDraftActionObj : public EngineActionObj {
explicit EagleBatchDraftActionObj(Array<Model> models, LogitProcessor logit_processor,
Sampler sampler, std::vector<ModelWorkspace> model_workspaces,
DraftTokenWorkspaceManager draft_token_workspace_manager,
EngineConfig engine_config,
Optional<EventTraceRecorder> trace_recorder, int draft_length)
: models_(std::move(models)),
logit_processor_(std::move(logit_processor)),
sampler_(std::move(sampler)),
model_workspaces_(std::move(model_workspaces)),
draft_token_workspace_manager_(std::move(draft_token_workspace_manager)),
engine_config_(std::move(engine_config)),
trace_recorder_(std::move(trace_recorder)),
draft_length_(draft_length) {
ICHECK_GT(draft_length_, 0);
Expand All @@ -56,6 +58,14 @@ class EagleBatchDraftActionObj : public EngineActionObj {
auto tstart = std::chrono::high_resolution_clock::now();

int num_rsentries = running_rsentries.size();
ICHECK_GT(num_rsentries, 0)
<< "There should be at least one request state entry that can run decode. "
"Possible failure reason: none of the prefill phase of the running requests is finished";
ICHECK_LE(num_rsentries, engine_config_->max_num_sequence)
<< "The number of running requests exceeds the max number of sequence in EngineConfig. "
"Possible failure reason: the prefill action allows new sequence in regardless of the "
"max num sequence.";

Array<String> request_ids;
std::vector<int64_t> request_internal_ids;
Array<GenerationConfig> generation_cfg;
Expand Down Expand Up @@ -189,6 +199,8 @@ class EagleBatchDraftActionObj : public EngineActionObj {
std::vector<ModelWorkspace> model_workspaces_;
/*! \brief The draft token workspace manager. */
DraftTokenWorkspaceManager draft_token_workspace_manager_;
/*! \brief The engine config. */
EngineConfig engine_config_;
/*! \brief Event trace recorder. */
Optional<EventTraceRecorder> trace_recorder_;
/*! \brief Draft proposal length */
Expand All @@ -201,12 +213,13 @@ EngineAction EngineAction::EagleBatchDraft(Array<Model> models, LogitProcessor l
Sampler sampler,
std::vector<ModelWorkspace> model_workspaces,
DraftTokenWorkspaceManager draft_token_workspace_manager,
EngineConfig engine_config,
Optional<EventTraceRecorder> trace_recorder,
int draft_length) {
return EngineAction(make_object<EagleBatchDraftActionObj>(
std::move(models), std::move(logit_processor), std::move(sampler),
std::move(model_workspaces), std::move(draft_token_workspace_manager),
std::move(trace_recorder), draft_length));
std::move(engine_config), std::move(trace_recorder), draft_length));
}

} // namespace serve
Expand Down

0 comments on commit e90f2e7

Please sign in to comment.