Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serve] MicroServing API refactor #3071

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions cpp/serve/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ Result<DisaggConfig> DisaggConfig::FromJSON(const picojson::object& config) {
DisaggConfig res;
std::optional<std::string> kind = json::LookupOptional<std::string>(config, "kind");
if (kind.has_value()) {
if (kind.value() == "prepare_prefill") {
res.kind = DisaggRequestKind::kPreparePrefill;
} else if (kind.value() == "remote_prefill") {
res.kind = DisaggRequestKind::kRemotePrefill;
} else if (kind.value() == "start_decode") {
res.kind = DisaggRequestKind::kStartDecode;
if (kind.value() == "prepare_receive") {
res.kind = DisaggRequestKind::kPrepareReceive;
} else if (kind.value() == "remote_send") {
res.kind = DisaggRequestKind::kRemoteSend;
} else if (kind.value() == "start_generation") {
res.kind = DisaggRequestKind::kStartGeneration;
} else {
return TResult::Error("Unknown disaggregation request kind " + kind.value());
}
Expand Down Expand Up @@ -125,16 +125,16 @@ Result<DisaggConfig> DisaggConfig::FromJSON(const picojson::object& config) {
picojson::object DisaggConfig::AsJSON() const {
picojson::object config;
switch (kind) {
case DisaggRequestKind::kPreparePrefill: {
config["kind"] = picojson::value("prepare_prefill");
case DisaggRequestKind::kPrepareReceive: {
config["kind"] = picojson::value("prepare_receive");
break;
}
case DisaggRequestKind::kRemotePrefill: {
config["kind"] = picojson::value("remote_prefill");
case DisaggRequestKind::kRemoteSend: {
config["kind"] = picojson::value("remote_send");
break;
}
case DisaggRequestKind::kStartDecode: {
config["kind"] = picojson::value("start_decode");
case DisaggRequestKind::kStartGeneration: {
config["kind"] = picojson::value("start_generation");
break;
}
default:
Expand Down
12 changes: 6 additions & 6 deletions cpp/serve/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ enum class SpecialRequestKind : int {

enum class DisaggRequestKind : int {
kNone = 0,
kPreparePrefill = 1,
kRemotePrefill = 2,
kStartDecode = 3,
kPrepareReceive = 1,
kRemoteSend = 2,
kStartGeneration = 3,
};

/*! \brief Controls the behavior of inference with grammar constraint. */
Expand All @@ -70,11 +70,11 @@ class DisaggConfig {
// "kv_window_begin" and "kv_window_end" denote the KV interval of interests.
// "kv_window_end" supports Python style negative indexing.
// The concrete meaning varies for different special request kind:
// - For "prepare_prefill", the begin is always 0, and "[0:end]" denotes
// - For "prepare_receive", the begin is always 0, and "[0:end]" denotes
// the KV range to prefill on a prefill instance.
// - For "remote_prefill", "[begin:end]" means the KV range to compute prefill
// - For "remote_send", "[begin:end]" means the KV range to compute prefill
// and send to the decode instance.
// - For "start_decode", the end is always nullopt, and "[begin:]" denotes
// - For "start_generation", the end is always nullopt, and "[begin:]" denotes
// the KV range to prefill locally on the decode instance.
std::optional<int> kv_window_begin = std::nullopt;
std::optional<int> kv_window_end = std::nullopt;
Expand Down
8 changes: 4 additions & 4 deletions cpp/serve/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -548,10 +548,10 @@ class EngineImpl : public Engine {
bool HandleDisaggRequest(Request request) {
DisaggConfig disagg_config = request->generation_cfg->debug_config.disagg_config;
DisaggRequestKind kind = disagg_config.kind;
if (kind == DisaggRequestKind::kPreparePrefill) {
if (kind == DisaggRequestKind::kPrepareReceive) {
// No-op.
return false;
} else if (kind == DisaggRequestKind::kRemotePrefill) {
} else if (kind == DisaggRequestKind::kRemoteSend) {
int input_length = 0;
for (Data input : request->inputs) {
input_length += input->GetLength();
Expand Down Expand Up @@ -586,13 +586,13 @@ class EngineImpl : public Engine {
updated_generation_cfg->n = 1;
request->generation_cfg = GenerationConfig(updated_generation_cfg);
return false;
} else if (kind == DisaggRequestKind::kStartDecode) {
} else if (kind == DisaggRequestKind::kStartGeneration) {
auto it_rstate = estate_->request_states.find(request->id);
CHECK(it_rstate != estate_->request_states.end());
ICHECK(!it_rstate->second->entries.empty());
request = it_rstate->second->entries[0]->request;
CHECK(request->generation_cfg->debug_config.disagg_config.kind ==
DisaggRequestKind::kPreparePrefill);
DisaggRequestKind::kPrepareReceive);
int input_length = 0;
for (Data input : request->inputs) {
input_length += input->GetLength();
Expand Down
4 changes: 2 additions & 2 deletions cpp/serve/engine_actions/action.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ class EngineAction : public ObjectRef {
* matched length in the prefix cache.
* \return The created action object.
*/
static EngineAction DisaggPreparePrefill(Array<Model> models, EngineConfig engine_config,
static EngineAction DisaggPrepareReceive(Array<Model> models, EngineConfig engine_config,
std::vector<picojson::object> model_configs,
Optional<EventTraceRecorder> trace_recorder,
FRequestStreamCallback request_stream_callback);
Expand All @@ -238,7 +238,7 @@ class EngineAction : public ObjectRef {
* \param device The device of the model for synchronization.
* \return The created action object.
*/
static EngineAction NewRequestPrefillWithKVSend(
static EngineAction DisaggRemoteSend(
Array<Model> models, std::vector<ModelWorkspace> model_workspaces, EngineConfig engine_config,
std::vector<picojson::object> model_configs, Optional<EventTraceRecorder> trace_recorder,
FRequestStreamCallback request_stream_callback, Device device);
Expand Down
11 changes: 5 additions & 6 deletions cpp/serve/engine_actions/action_commons.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,10 @@ Array<EngineAction> CreateEngineActions(
if (model_metadata.disaggregation) {
// Insert the disaggregation actions.
Array<EngineAction> disaggregation_actions = {
EngineAction::DisaggPreparePrefill(models, engine_config, model_configs, trace_recorder,
EngineAction::DisaggPrepareReceive(models, engine_config, model_configs, trace_recorder,
request_stream_callback),
EngineAction::NewRequestPrefillWithKVSend(models, model_workspaces, engine_config,
model_configs, trace_recorder,
request_stream_callback, device)};
EngineAction::DisaggRemoteSend(models, model_workspaces, engine_config, model_configs,
trace_recorder, request_stream_callback, device)};
actions.insert(actions.begin(), disaggregation_actions.begin(), disaggregation_actions.end());
}
return actions;
Expand Down Expand Up @@ -302,11 +301,11 @@ void ActionStepPostProcess(Array<Request> requests, EngineState estate, const Ar
}
}

// - For all disaggregation requests with "remote_prefill",
// - For all disaggregation requests with "remote_send",
// if it does not appear in the waiting queue, it means the prefill has been finished.
// In this case, we mark the request as finished.
if (request->generation_cfg->debug_config.disagg_config.kind ==
DisaggRequestKind::kRemotePrefill) {
DisaggRequestKind::kRemoteSend) {
auto it = std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), request);
if (it == estate->waiting_queue.end()) {
CHECK_EQ(rstate->entries.size(), 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ namespace serve {
* It picks a new request, reserve its KV data locations, and returns the
* KV data locations and the matched prefix length in prefix cache.
*/
class DisaggPreparePrefillActionObj : public BatchPrefillBaseActionObj {
class DisaggPrepareReceiveActionObj : public BatchPrefillBaseActionObj {
public:
explicit DisaggPreparePrefillActionObj(Array<Model> models, EngineConfig engine_config,
explicit DisaggPrepareReceiveActionObj(Array<Model> models, EngineConfig engine_config,
std::vector<picojson::object> model_configs,
Optional<EventTraceRecorder> trace_recorder,
FRequestStreamCallback request_stream_callback)
Expand Down Expand Up @@ -51,7 +51,7 @@ class DisaggPreparePrefillActionObj : public BatchPrefillBaseActionObj {
}

{
NVTXScopedRange nvtx_scope("DisaggPreparePrefill matching prefix");
NVTXScopedRange nvtx_scope("DisaggPrepareReceive matching prefix");
prefix_matched_length = MatchPrefixCache(estate, &prefill_input);
}

Expand Down Expand Up @@ -199,7 +199,7 @@ class DisaggPreparePrefillActionObj : public BatchPrefillBaseActionObj {
Request request{nullptr};
for (const Request& request_candidate : estate->waiting_queue) {
if (request_candidate->generation_cfg->debug_config.disagg_config.kind ==
DisaggRequestKind::kPreparePrefill) {
DisaggRequestKind::kPrepareReceive) {
request = request_candidate;
break;
}
Expand Down Expand Up @@ -427,11 +427,11 @@ class DisaggPreparePrefillActionObj : public BatchPrefillBaseActionObj {
FRequestStreamCallback request_stream_callback_;
};

EngineAction EngineAction::DisaggPreparePrefill(Array<Model> models, EngineConfig engine_config,
EngineAction EngineAction::DisaggPrepareReceive(Array<Model> models, EngineConfig engine_config,
std::vector<picojson::object> model_configs,
Optional<EventTraceRecorder> trace_recorder,
FRequestStreamCallback request_stream_callback) {
return EngineAction(make_object<DisaggPreparePrefillActionObj>(
return EngineAction(make_object<DisaggPrepareReceiveActionObj>(
std::move(models), std::move(engine_config), std::move(model_configs),
std::move(trace_recorder), std::move(request_stream_callback)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ namespace serve {
* Aside from that, this action sends the computed KV data to remote
* instances after computing the KV data.
*/
class NewRequestPrefillWithKVSendActionObj : public BatchPrefillBaseActionObj {
class DisaggRemoteSendActionObj : public BatchPrefillBaseActionObj {
public:
explicit NewRequestPrefillWithKVSendActionObj(
Array<Model> models, std::vector<ModelWorkspace> model_workspaces, EngineConfig engine_config,
std::vector<picojson::object> model_configs, Optional<EventTraceRecorder> trace_recorder,
FRequestStreamCallback request_stream_callback, Device device)
explicit DisaggRemoteSendActionObj(Array<Model> models,
std::vector<ModelWorkspace> model_workspaces,
EngineConfig engine_config,
std::vector<picojson::object> model_configs,
Optional<EventTraceRecorder> trace_recorder,
FRequestStreamCallback request_stream_callback, Device device)
: BatchPrefillBaseActionObj(std::move(models), std::move(engine_config),
std::move(model_configs), std::move(trace_recorder)),
model_workspaces_(std::move(model_workspaces)),
Expand All @@ -39,7 +41,7 @@ class NewRequestPrefillWithKVSendActionObj : public BatchPrefillBaseActionObj {
// - Find the requests in `waiting_queue` that can prefill in this step.
std::vector<PrefillInput> prefill_inputs;
{
NVTXScopedRange nvtx_scope("NewRequestPrefillWithKVSend getting requests");
NVTXScopedRange nvtx_scope("DisaggRemoteSend getting requests");
prefill_inputs = GetRequestStateEntriesToPrefill(estate);
if (prefill_inputs.empty()) {
return {};
Expand All @@ -48,7 +50,7 @@ class NewRequestPrefillWithKVSendActionObj : public BatchPrefillBaseActionObj {

int num_rsentries = prefill_inputs.size();
{
NVTXScopedRange nvtx_scope("NewRequestPrefillWithKVSend matching prefix");
NVTXScopedRange nvtx_scope("DisaggRemoteSend matching prefix");
for (int i = 0; i < num_rsentries; ++i) {
MatchPrefixCache(estate, &prefill_inputs[i]);
}
Expand Down Expand Up @@ -183,12 +185,12 @@ class NewRequestPrefillWithKVSendActionObj : public BatchPrefillBaseActionObj {
}

// Explicitly filter the waiting queue to only keep the requests
// with disaggregation request kind "kRemotePrefill".
// with disaggregation request kind "kRemoteSend".
std::vector<Request> waiting_queue;
waiting_queue.reserve(estate->waiting_queue.size());
for (Request request : estate->waiting_queue) {
if (request->generation_cfg->debug_config.disagg_config.kind ==
DisaggRequestKind::kRemotePrefill) {
DisaggRequestKind::kRemoteSend) {
waiting_queue.push_back(request);
}
}
Expand Down Expand Up @@ -481,11 +483,11 @@ class NewRequestPrefillWithKVSendActionObj : public BatchPrefillBaseActionObj {
TVMStreamHandle compute_stream_ = nullptr;
};

EngineAction EngineAction::NewRequestPrefillWithKVSend(
EngineAction EngineAction::DisaggRemoteSend(
Array<Model> models, std::vector<ModelWorkspace> model_workspaces, EngineConfig engine_config,
std::vector<picojson::object> model_configs, Optional<EventTraceRecorder> trace_recorder,
FRequestStreamCallback request_stream_callback, Device device) {
return EngineAction(make_object<NewRequestPrefillWithKVSendActionObj>(
return EngineAction(make_object<DisaggRemoteSendActionObj>(
std::move(models), std::move(model_workspaces), std::move(engine_config),
std::move(model_configs), std::move(trace_recorder), std::move(request_stream_callback),
device));
Expand Down
2 changes: 2 additions & 0 deletions python/mlc_llm/interface/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mlc_llm.serve.entrypoints import (
debug_entrypoints,
metrics_entrypoints,
microserving_entrypoints,
openai_entrypoints,
)
from mlc_llm.serve.server import ServerContext
Expand Down Expand Up @@ -95,6 +96,7 @@ def serve(

app.include_router(openai_entrypoints.app)
app.include_router(metrics_entrypoints.app)
app.include_router(microserving_entrypoints.app)

server_context.enable_debug = enable_debug

Expand Down
8 changes: 4 additions & 4 deletions python/mlc_llm/protocol/debug_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
class DisaggConfig(BaseModel):
"""The class of metadata used in microserving APIs."""

kind: Optional[Literal["prepare_prefill", "remote_prefill", "start_decode"]] = None
kind: Optional[Literal["prepare_receive", "remote_send", "start_generation"]] = None
# "kv_append_metadata" is base64-encoded and is thus a string.
kv_append_metadata: Optional[str] = None
# "kv_window_begin" and "kv_window_end" denote the KV interval of interests.
# "kv_window_end" supports Python style negative indexing.
# The concrete meaning varies for different special request kind:
# - For "prepare_prefill", the begin is always 0, and "[0:end]" denotes
# - For "prepare_receive", the begin is always 0, and "[0:end]" denotes
# the KV range to prefill on a prefill instance.
# - For "remote_prefill", "[begin:end]" means the KV range to compute prefill
# - For "remote_send", "[begin:end]" means the KV range to compute prefill
# and send to the decode instance.
# - For "start_decode", the end is always None, and "[begin:]" denotes
# - For "start_generation", the end is always None, and "[begin:]" denotes
# the KV range to prefill locally on the decode instance.
kv_window_begin: Optional[int] = None
kv_window_end: Optional[int] = None
Expand Down
76 changes: 76 additions & 0 deletions python/mlc_llm/protocol/microserving_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Protocols in MLC LLM for MicroServing."""

from pydantic import BaseModel

from mlc_llm.protocol.openai_api_protocol import CompletionRequest


class PrepRecvRequest(CompletionRequest):
"""The extra request body for prep_recv request in MicroServing.

Attributes
----------
kv_window_end : int
[0, kv_window_end] denotes the KV range of the prompt to prefill on
a prefill instance.
The entries of this KV range will be allocated on the decode instance.
"""

kv_window_end: int


class PrepRecvResponse(BaseModel):
"""The response body for prep_recv request in MicroServing.

Attributes
----------
prompt_length : int
The length of the request prompt in tokens.

prefix_matched_length : int
The matched common prefix length on the decode instance when
prefix cache is enabled, or 0 if there is no prefix cache.

kv_append_metadata : str
The metadata of the KV range on the destination decode instance.
"""

prompt_length: int
prefix_matched_length: int
kv_append_metadata: str


class RemoteSendRequest(CompletionRequest):
"""The extra request body for remote_send request in MicroServing.

Attributes
----------
kv_window_begin : int
Denote the start of the KV range to prefill.

kv_window_end : int
Denote the end of the KV range to prefill.

kv_append_metadata : str
The metadata of the KV range on the destination decode instance.

dst_group_offset : int
The node group offset of the destination decode instance.
"""

kv_window_begin: int
kv_window_end: int
kv_append_metadata: str
dst_group_offset: int


class StartGenerateRequest(CompletionRequest):
"""The extra request body for start_generate request in MicroServing.

Attributes
----------
kv_window_begin : int
Denote the start of the KV range to prefill on the decode instance.
"""

kv_window_begin: int
Loading
Loading