diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 732f575da2..a05fa5274a 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -70,12 +70,12 @@ Result DisaggConfig::FromJSON(const picojson::object& config) { DisaggConfig res; std::optional kind = json::LookupOptional(config, "kind"); if (kind.has_value()) { - if (kind.value() == "prepare_prefill") { - res.kind = DisaggRequestKind::kPreparePrefill; - } else if (kind.value() == "remote_prefill") { - res.kind = DisaggRequestKind::kRemotePrefill; - } else if (kind.value() == "start_decode") { - res.kind = DisaggRequestKind::kStartDecode; + if (kind.value() == "prepare_receive") { + res.kind = DisaggRequestKind::kPrepareReceive; + } else if (kind.value() == "remote_send") { + res.kind = DisaggRequestKind::kRemoteSend; + } else if (kind.value() == "start_generation") { + res.kind = DisaggRequestKind::kStartGeneration; } else { return TResult::Error("Unknown disaggregation request kind " + kind.value()); } @@ -125,16 +125,16 @@ Result DisaggConfig::FromJSON(const picojson::object& config) { picojson::object DisaggConfig::AsJSON() const { picojson::object config; switch (kind) { - case DisaggRequestKind::kPreparePrefill: { - config["kind"] = picojson::value("prepare_prefill"); + case DisaggRequestKind::kPrepareReceive: { + config["kind"] = picojson::value("prepare_receive"); break; } - case DisaggRequestKind::kRemotePrefill: { - config["kind"] = picojson::value("remote_prefill"); + case DisaggRequestKind::kRemoteSend: { + config["kind"] = picojson::value("remote_send"); break; } - case DisaggRequestKind::kStartDecode: { - config["kind"] = picojson::value("start_decode"); + case DisaggRequestKind::kStartGeneration: { + config["kind"] = picojson::value("start_generation"); break; } default: diff --git a/cpp/serve/config.h b/cpp/serve/config.h index af0fba04c5..6711c2e867 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -48,9 +48,9 @@ enum class SpecialRequestKind : int { enum class DisaggRequestKind : int { kNone = 0, - kPreparePrefill = 1, - kRemotePrefill = 2, - kStartDecode = 3, + kPrepareReceive = 1, + kRemoteSend = 2, + kStartGeneration = 3, }; /*! \brief Controls the behavior of inference with grammar constraint. */ @@ -70,11 +70,11 @@ class DisaggConfig { // "kv_window_begin" and "kv_window_end" denote the KV interval of interests. // "kv_window_end" supports Python style negative indexing. // The concrete meaning varies for different special request kind: - // - For "prepare_prefill", the begin is always 0, and "[0:end]" denotes + // - For "prepare_receive", the begin is always 0, and "[0:end]" denotes // the KV range to prefill on a prefill instance. - // - For "remote_prefill", "[begin:end]" means the KV range to compute prefill + // - For "remote_send", "[begin:end]" means the KV range to compute prefill // and send to the decode instance. - // - For "start_decode", the end is always nullopt, and "[begin:]" denotes + // - For "start_generation", the end is always nullopt, and "[begin:]" denotes // the KV range to prefill locally on the decode instance. std::optional kv_window_begin = std::nullopt; std::optional kv_window_end = std::nullopt; diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index b0e0bb9eb6..19cad78dc9 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -548,10 +548,10 @@ class EngineImpl : public Engine { bool HandleDisaggRequest(Request request) { DisaggConfig disagg_config = request->generation_cfg->debug_config.disagg_config; DisaggRequestKind kind = disagg_config.kind; - if (kind == DisaggRequestKind::kPreparePrefill) { + if (kind == DisaggRequestKind::kPrepareReceive) { // No-op. return false; - } else if (kind == DisaggRequestKind::kRemotePrefill) { + } else if (kind == DisaggRequestKind::kRemoteSend) { int input_length = 0; for (Data input : request->inputs) { input_length += input->GetLength(); @@ -586,13 +586,13 @@ class EngineImpl : public Engine { updated_generation_cfg->n = 1; request->generation_cfg = GenerationConfig(updated_generation_cfg); return false; - } else if (kind == DisaggRequestKind::kStartDecode) { + } else if (kind == DisaggRequestKind::kStartGeneration) { auto it_rstate = estate_->request_states.find(request->id); CHECK(it_rstate != estate_->request_states.end()); ICHECK(!it_rstate->second->entries.empty()); request = it_rstate->second->entries[0]->request; CHECK(request->generation_cfg->debug_config.disagg_config.kind == - DisaggRequestKind::kPreparePrefill); + DisaggRequestKind::kPrepareReceive); int input_length = 0; for (Data input : request->inputs) { input_length += input->GetLength(); diff --git a/cpp/serve/engine_actions/action.h b/cpp/serve/engine_actions/action.h index 1e2ad75589..210cea6969 100644 --- a/cpp/serve/engine_actions/action.h +++ b/cpp/serve/engine_actions/action.h @@ -220,7 +220,7 @@ class EngineAction : public ObjectRef { * matched length in the prefix cache. * \return The created action object. */ - static EngineAction DisaggPreparePrefill(Array models, EngineConfig engine_config, + static EngineAction DisaggPrepareReceive(Array models, EngineConfig engine_config, std::vector model_configs, Optional trace_recorder, FRequestStreamCallback request_stream_callback); @@ -238,7 +238,7 @@ class EngineAction : public ObjectRef { * \param device The device of the model for synchronization. * \return The created action object. */ - static EngineAction NewRequestPrefillWithKVSend( + static EngineAction DisaggRemoteSend( Array models, std::vector model_workspaces, EngineConfig engine_config, std::vector model_configs, Optional trace_recorder, FRequestStreamCallback request_stream_callback, Device device); diff --git a/cpp/serve/engine_actions/action_commons.cc b/cpp/serve/engine_actions/action_commons.cc index e635bbb976..def48869b9 100644 --- a/cpp/serve/engine_actions/action_commons.cc +++ b/cpp/serve/engine_actions/action_commons.cc @@ -122,11 +122,10 @@ Array CreateEngineActions( if (model_metadata.disaggregation) { // Insert the disaggregation actions. Array disaggregation_actions = { - EngineAction::DisaggPreparePrefill(models, engine_config, model_configs, trace_recorder, + EngineAction::DisaggPrepareReceive(models, engine_config, model_configs, trace_recorder, request_stream_callback), - EngineAction::NewRequestPrefillWithKVSend(models, model_workspaces, engine_config, - model_configs, trace_recorder, - request_stream_callback, device)}; + EngineAction::DisaggRemoteSend(models, model_workspaces, engine_config, model_configs, + trace_recorder, request_stream_callback, device)}; actions.insert(actions.begin(), disaggregation_actions.begin(), disaggregation_actions.end()); } return actions; @@ -302,11 +301,11 @@ void ActionStepPostProcess(Array requests, EngineState estate, const Ar } } - // - For all disaggregation requests with "remote_prefill", + // - For all disaggregation requests with "remote_send", // if it does not appear in the waiting queue, it means the prefill has been finished. // In this case, we mark the request as finished. if (request->generation_cfg->debug_config.disagg_config.kind == - DisaggRequestKind::kRemotePrefill) { + DisaggRequestKind::kRemoteSend) { auto it = std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), request); if (it == estate->waiting_queue.end()) { CHECK_EQ(rstate->entries.size(), 1); diff --git a/cpp/serve/engine_actions/prefill_prepare.cc b/cpp/serve/engine_actions/disagg_prepare_recv.cc similarity index 98% rename from cpp/serve/engine_actions/prefill_prepare.cc rename to cpp/serve/engine_actions/disagg_prepare_recv.cc index a28d33125b..73240af817 100644 --- a/cpp/serve/engine_actions/prefill_prepare.cc +++ b/cpp/serve/engine_actions/disagg_prepare_recv.cc @@ -18,9 +18,9 @@ namespace serve { * It picks a new request, reserve its KV data locations, and returns the * KV data locations and the matched prefix length in prefix cache. */ -class DisaggPreparePrefillActionObj : public BatchPrefillBaseActionObj { +class DisaggPrepareReceiveActionObj : public BatchPrefillBaseActionObj { public: - explicit DisaggPreparePrefillActionObj(Array models, EngineConfig engine_config, + explicit DisaggPrepareReceiveActionObj(Array models, EngineConfig engine_config, std::vector model_configs, Optional trace_recorder, FRequestStreamCallback request_stream_callback) @@ -51,7 +51,7 @@ class DisaggPreparePrefillActionObj : public BatchPrefillBaseActionObj { } { - NVTXScopedRange nvtx_scope("DisaggPreparePrefill matching prefix"); + NVTXScopedRange nvtx_scope("DisaggPrepareReceive matching prefix"); prefix_matched_length = MatchPrefixCache(estate, &prefill_input); } @@ -199,7 +199,7 @@ class DisaggPreparePrefillActionObj : public BatchPrefillBaseActionObj { Request request{nullptr}; for (const Request& request_candidate : estate->waiting_queue) { if (request_candidate->generation_cfg->debug_config.disagg_config.kind == - DisaggRequestKind::kPreparePrefill) { + DisaggRequestKind::kPrepareReceive) { request = request_candidate; break; } @@ -427,11 +427,11 @@ class DisaggPreparePrefillActionObj : public BatchPrefillBaseActionObj { FRequestStreamCallback request_stream_callback_; }; -EngineAction EngineAction::DisaggPreparePrefill(Array models, EngineConfig engine_config, +EngineAction EngineAction::DisaggPrepareReceive(Array models, EngineConfig engine_config, std::vector model_configs, Optional trace_recorder, FRequestStreamCallback request_stream_callback) { - return EngineAction(make_object( + return EngineAction(make_object( std::move(models), std::move(engine_config), std::move(model_configs), std::move(trace_recorder), std::move(request_stream_callback))); } diff --git a/cpp/serve/engine_actions/new_request_prefill_with_kv_send.cc b/cpp/serve/engine_actions/disagg_remote_send.cc similarity index 96% rename from cpp/serve/engine_actions/new_request_prefill_with_kv_send.cc rename to cpp/serve/engine_actions/disagg_remote_send.cc index f53fbc8837..18d438eb7f 100644 --- a/cpp/serve/engine_actions/new_request_prefill_with_kv_send.cc +++ b/cpp/serve/engine_actions/disagg_remote_send.cc @@ -16,12 +16,14 @@ namespace serve { * Aside from that, this action sends the computed KV data to remote * instances after computing the KV data. */ -class NewRequestPrefillWithKVSendActionObj : public BatchPrefillBaseActionObj { +class DisaggRemoteSendActionObj : public BatchPrefillBaseActionObj { public: - explicit NewRequestPrefillWithKVSendActionObj( - Array models, std::vector model_workspaces, EngineConfig engine_config, - std::vector model_configs, Optional trace_recorder, - FRequestStreamCallback request_stream_callback, Device device) + explicit DisaggRemoteSendActionObj(Array models, + std::vector model_workspaces, + EngineConfig engine_config, + std::vector model_configs, + Optional trace_recorder, + FRequestStreamCallback request_stream_callback, Device device) : BatchPrefillBaseActionObj(std::move(models), std::move(engine_config), std::move(model_configs), std::move(trace_recorder)), model_workspaces_(std::move(model_workspaces)), @@ -39,7 +41,7 @@ class NewRequestPrefillWithKVSendActionObj : public BatchPrefillBaseActionObj { // - Find the requests in `waiting_queue` that can prefill in this step. std::vector prefill_inputs; { - NVTXScopedRange nvtx_scope("NewRequestPrefillWithKVSend getting requests"); + NVTXScopedRange nvtx_scope("DisaggRemoteSend getting requests"); prefill_inputs = GetRequestStateEntriesToPrefill(estate); if (prefill_inputs.empty()) { return {}; @@ -48,7 +50,7 @@ class NewRequestPrefillWithKVSendActionObj : public BatchPrefillBaseActionObj { int num_rsentries = prefill_inputs.size(); { - NVTXScopedRange nvtx_scope("NewRequestPrefillWithKVSend matching prefix"); + NVTXScopedRange nvtx_scope("DisaggRemoteSend matching prefix"); for (int i = 0; i < num_rsentries; ++i) { MatchPrefixCache(estate, &prefill_inputs[i]); } @@ -183,12 +185,12 @@ class NewRequestPrefillWithKVSendActionObj : public BatchPrefillBaseActionObj { } // Explicitly filter the waiting queue to only keep the requests - // with disaggregation request kind "kRemotePrefill". + // with disaggregation request kind "kRemoteSend". std::vector waiting_queue; waiting_queue.reserve(estate->waiting_queue.size()); for (Request request : estate->waiting_queue) { if (request->generation_cfg->debug_config.disagg_config.kind == - DisaggRequestKind::kRemotePrefill) { + DisaggRequestKind::kRemoteSend) { waiting_queue.push_back(request); } } @@ -481,11 +483,11 @@ class NewRequestPrefillWithKVSendActionObj : public BatchPrefillBaseActionObj { TVMStreamHandle compute_stream_ = nullptr; }; -EngineAction EngineAction::NewRequestPrefillWithKVSend( +EngineAction EngineAction::DisaggRemoteSend( Array models, std::vector model_workspaces, EngineConfig engine_config, std::vector model_configs, Optional trace_recorder, FRequestStreamCallback request_stream_callback, Device device) { - return EngineAction(make_object( + return EngineAction(make_object( std::move(models), std::move(model_workspaces), std::move(engine_config), std::move(model_configs), std::move(trace_recorder), std::move(request_stream_callback), device)); diff --git a/python/mlc_llm/interface/serve.py b/python/mlc_llm/interface/serve.py index 5821cd8563..c00ed1adc5 100644 --- a/python/mlc_llm/interface/serve.py +++ b/python/mlc_llm/interface/serve.py @@ -11,6 +11,7 @@ from mlc_llm.serve.entrypoints import ( debug_entrypoints, metrics_entrypoints, + microserving_entrypoints, openai_entrypoints, ) from mlc_llm.serve.server import ServerContext @@ -95,6 +96,7 @@ def serve( app.include_router(openai_entrypoints.app) app.include_router(metrics_entrypoints.app) + app.include_router(microserving_entrypoints.app) server_context.enable_debug = enable_debug diff --git a/python/mlc_llm/protocol/debug_protocol.py b/python/mlc_llm/protocol/debug_protocol.py index 057a0086b6..f233ce69a2 100644 --- a/python/mlc_llm/protocol/debug_protocol.py +++ b/python/mlc_llm/protocol/debug_protocol.py @@ -8,17 +8,17 @@ class DisaggConfig(BaseModel): """The class of metadata used in microserving APIs.""" - kind: Optional[Literal["prepare_prefill", "remote_prefill", "start_decode"]] = None + kind: Optional[Literal["prepare_receive", "remote_send", "start_generation"]] = None # "kv_append_metadata" is base64-encoded and is thus a string. kv_append_metadata: Optional[str] = None # "kv_window_begin" and "kv_window_end" denote the KV interval of interests. # "kv_window_end" supports Python style negative indexing. # The concrete meaning varies for different special request kind: - # - For "prepare_prefill", the begin is always 0, and "[0:end]" denotes + # - For "prepare_receive", the begin is always 0, and "[0:end]" denotes # the KV range to prefill on a prefill instance. - # - For "remote_prefill", "[begin:end]" means the KV range to compute prefill + # - For "remote_send", "[begin:end]" means the KV range to compute prefill # and send to the decode instance. - # - For "start_decode", the end is always None, and "[begin:]" denotes + # - For "start_generation", the end is always None, and "[begin:]" denotes # the KV range to prefill locally on the decode instance. kv_window_begin: Optional[int] = None kv_window_end: Optional[int] = None diff --git a/python/mlc_llm/protocol/microserving_protocol.py b/python/mlc_llm/protocol/microserving_protocol.py new file mode 100644 index 0000000000..fa8cdd63c6 --- /dev/null +++ b/python/mlc_llm/protocol/microserving_protocol.py @@ -0,0 +1,76 @@ +"""Protocols in MLC LLM for MicroServing.""" + +from pydantic import BaseModel + +from mlc_llm.protocol.openai_api_protocol import CompletionRequest + + +class PrepRecvRequest(CompletionRequest): + """The extra request body for prep_recv request in MicroServing. + + Attributes + ---------- + kv_window_end : int + [0, kv_window_end] denotes the KV range of the prompt to prefill on + a prefill instance. + The entries of this KV range will be allocated on the decode instance. + """ + + kv_window_end: int + + +class PrepRecvResponse(BaseModel): + """The response body for prep_recv request in MicroServing. + + Attributes + ---------- + prompt_length : int + The length of the request prompt in tokens. + + prefix_matched_length : int + The matched common prefix length on the decode instance when + prefix cache is enabled, or 0 if there is no prefix cache. + + kv_append_metadata : str + The metadata of the KV range on the destination decode instance. + """ + + prompt_length: int + prefix_matched_length: int + kv_append_metadata: str + + +class RemoteSendRequest(CompletionRequest): + """The extra request body for remote_send request in MicroServing. + + Attributes + ---------- + kv_window_begin : int + Denote the start of the KV range to prefill. + + kv_window_end : int + Denote the end of the KV range to prefill. + + kv_append_metadata : str + The metadata of the KV range on the destination decode instance. + + dst_group_offset : int + The node group offset of the destination decode instance. + """ + + kv_window_begin: int + kv_window_end: int + kv_append_metadata: str + dst_group_offset: int + + +class StartGenerateRequest(CompletionRequest): + """The extra request body for start_generate request in MicroServing. + + Attributes + ---------- + kv_window_begin : int + Denote the start of the KV range to prefill on the decode instance. + """ + + kv_window_begin: int diff --git a/python/mlc_llm/router/router.py b/python/mlc_llm/router/router.py index 25d6b0db43..3ab8a8e210 100644 --- a/python/mlc_llm/router/router.py +++ b/python/mlc_llm/router/router.py @@ -5,11 +5,12 @@ import threading from typing import Any, AsyncGenerator, Iterable, List, Literal, Optional, Tuple -import aiohttp # pylint: disable=import-outside-toplevel,import-error +import aiohttp # pylint: disable=import-error import tvm -from mlc_llm.protocol import debug_protocol, openai_api_protocol +from mlc_llm.protocol import openai_api_protocol from mlc_llm.serve import EngineConfig, PopenServer +from mlc_llm.serve.entrypoints import microserving_entrypoints from mlc_llm.tokenizers import Tokenizer @@ -20,36 +21,43 @@ def __init__( self, model: str, model_lib: Optional[str] = None, - hosts: List[str] = ["127.0.0.1"], - ports: List[int] = [8080], - num_gpus: List[int] = [1], + hosts: Optional[List[str]] = None, + ports: Optional[List[int]] = None, + num_gpus: Optional[List[int]] = None, enable_prefix_cache: bool = False, router_mode: Literal["disagg", "round-robin"] = "disagg", pd_balance_factor: float = 0.0, - ): # pylint: disable=too-many-arguments,too-many-locals,dangerous-default-value + ): # pylint: disable=too-many-arguments,too-many-locals """ Spawn len(host_list) server endpoints with Popen. """ + if hosts is None: + hosts = ["127.0.0.1"] + if ports is None: + ports = [8080] + if num_gpus is None: + num_gpus = [1] + self.router_mode = router_mode self.pd_balance_factor = pd_balance_factor # Get endpoint urls - self.num_endpoints = len(hosts) - assert self.num_endpoints == len(ports) == len(num_gpus) + self.num_servers = len(hosts) + assert self.num_servers == len(ports) == len(num_gpus) self.hosts = hosts self.ports = ports - self.endpoints = [] - for i in range(self.num_endpoints): - self.endpoints.append(f"http://{hosts[i]}:{ports[i]}/v1/completions") + self.server_urls = [] + for i in range(self.num_servers): + self.server_urls.append(f"http://{hosts[i]}:{ports[i]}") # Misc self.headers = {"Content-Type": "application/json"} - self.num_running_requests = [0] * self.num_endpoints + self.num_running_requests = [0] * self.num_servers # Call nvshmem_init here to get uid, then pass to env variables to server.start() below f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid") uid = list(f_init_nvshmem_uid()) - # Start underlying endpoints concurrently. Otherwise 1 server cannot start on its own + # Start underlying servers concurrently. Otherwise 1 server cannot start on its own # since initializing nvhsmem world requires all GPUs. self.servers: List[PopenServer] = [] @@ -83,7 +91,7 @@ def start_server(i: int): threads = [] num_used_gpus = 0 - for i in range(self.num_endpoints): + for i in range(self.num_servers): thread = threading.Thread( target=start_server, args=[i], @@ -96,7 +104,7 @@ def start_server(i: int): self.tokenizer = Tokenizer(model) def terminate(self): - """Terminate the underlying endpoints""" + """Terminate the underlying servers""" for server in self.servers: server.terminate() @@ -142,14 +150,14 @@ async def _handle_completion_round_robin( endpoints with round-robin scheduling at a request level. """ # Round robin - cur_endpoint = self._pick_endpoint(range(self.num_endpoints)) + cur_endpoint = self._pick_endpoint(range(self.num_servers)) self.num_running_requests[cur_endpoint] += 1 payload = request.model_dump() async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=3 * 3600), trust_env=True ) as session: async with session.post( - self.endpoints[cur_endpoint], json=payload, headers=self.headers + self.server_urls[cur_endpoint], json=payload, headers=self.headers ) as response: assert response.status == 200, await response.text() completed = False @@ -210,7 +218,7 @@ async def _handle_completion_disagg( # pylint: disable=too-many-locals original_request.user = request_id # Arbitrarily determine server 0 is P, other servers are D prefill_server_id = 0 - decode_server_id = self._pick_endpoint(range(1, self.num_endpoints)) + decode_server_id = self._pick_endpoint(range(1, self.num_servers)) # Add a debugConfig if not present if original_request.debug_config is None: @@ -233,23 +241,17 @@ async def _handle_completion_disagg( # pylint: disable=too-many-locals completed = False while not completed: # 1. Ask D to prepare metadata - prepare_request = original_request.model_copy() - prepare_request.debug_config.disagg_config = debug_protocol.DisaggConfig( - kind="prepare_prefill", - kv_window_begin=0, # always zero for prepare_prefill - kv_window_end=kv_window_end, - ) - prepare_request.stream_options = openai_api_protocol.StreamOptions( - include_usage=True + prep_recv_request = microserving_entrypoints.PrepRecvRequest( + **original_request.model_dump(), kv_window_end=kv_window_end ) ( prompt_length, prefix_matched_length, kv_append_metadata_base64, - ) = await self.send_decode_prepare( + ) = await self.send_prepare_receive( session=session, - prepare_request=prepare_request, - decode_endpoint=self.endpoints[decode_server_id], + request=prep_recv_request, + server_url=self.server_urls[decode_server_id], ) kv_window_end = ( @@ -261,42 +263,36 @@ async def _handle_completion_disagg( # pylint: disable=too-many-locals # KV transfer has finished prefilling and transferring the KV of # prompt[prefix_matched_length:kv_window_end]. So D is ready to decode. if prefix_matched_length < kv_window_end: - prefill_request = original_request.model_copy() - prefill_request.stream_options = openai_api_protocol.StreamOptions( - include_usage=True - ) - prefill_request.debug_config.disagg_config = debug_protocol.DisaggConfig( - kind="remote_prefill", + remote_send_request = microserving_entrypoints.RemoteSendRequest( + **original_request.model_dump(), kv_window_begin=prefix_matched_length, kv_window_end=kv_window_end, kv_append_metadata=kv_append_metadata_base64, dst_group_offset=self.device_id_starts[decode_server_id], ) - await self.send_prefill( + await self.send_remote_send( session=session, - prefill_request=prefill_request, - prefill_endpoint=self.endpoints[prefill_server_id], + request=remote_send_request, + server_url=self.server_urls[prefill_server_id], ) # 3. Start decoding, receive and yield back response as a normal request # The kv window passed through denotes the range to prefill on the # decode server, which should be [-1:] here. - decode_request = original_request.model_copy() - decode_request.debug_config.disagg_config = debug_protocol.DisaggConfig( - kind="start_decode", + start_generate_request = microserving_entrypoints.StartGenerateRequest( + **original_request.model_dump(), kv_window_begin=kv_window_end, ) - async for response in self.send_decode( + async for response in self.send_start_generate( session=session, - decode_request=decode_request, - decode_endpoint=self.endpoints[decode_server_id], + request=start_generate_request, + server_url=self.server_urls[decode_server_id], ): - response_json = response.dict() - if response_json["choices"]: - reason = response_json["choices"][0]["finish_reason"] - if reason == "preempt": + if len(response.choices) > 0: + finish_reason = response.choices[0].finish_reason + if finish_reason == "preempt": break - if reason is not None: + if finish_reason is not None: completed = True yield response except Exception as e: @@ -304,11 +300,11 @@ async def _handle_completion_disagg( # pylint: disable=too-many-locals raise e self.num_running_requests[decode_server_id] -= 1 - async def send_decode_prepare( + async def send_prepare_receive( self, session: aiohttp.ClientSession, - prepare_request: openai_api_protocol.CompletionRequest, - decode_endpoint: str, + request: openai_api_protocol.CompletionRequest, + server_url: str, ) -> Tuple[int, int, str]: """ Performs step 1 of disaggregated serving: ask D to prepare metadata. @@ -319,85 +315,60 @@ async def send_decode_prepare( i.e. prompt[0:prefix_matched_length] is the matched prefix - kv_append_metadata_base64: str, info about KV append encoded in base64 string """ - # Send request to D and get metadata - async with session.post(decode_endpoint, json=prepare_request.model_dump()) as response: + # Send request to the decode server for receive preparation. + # Get the prompt length, matched prefix length and the KV metadata. + async with session.post( + server_url + "/microserving/prep_recv", + json=request.model_dump(), + headers=self.headers, + ) as response: assert response.status == 200, await response.text() - # Expect decode to only return a single usage chunk - data = None - async for chunk in response.content: - if prepare_request.stream: - chunk = chunk.strip() - if not chunk or chunk == b"\n": - continue - # Get rid of the prefix "data: " and suffix "\n" - raw_data = chunk[6:].strip() - if raw_data == b"[DONE]": - continue - assert data is None, ( - f"Expecting only one effective chunk response. " - f"data: {data}, current={json.loads(raw_data)}" - ) - data = json.loads(raw_data) - else: - data = await response.json() - - assert "extra" in data["usage"] - assert "prefix_matched_length" in data["usage"]["extra"] - assert "kv_append_metadata" in data["usage"]["extra"] + data = await response.json() return ( - data["usage"]["extra"]["prompt_length"], - data["usage"]["extra"]["prefix_matched_length"], - data["usage"]["extra"]["kv_append_metadata"], + data["prompt_length"], + data["prefix_matched_length"], + data["kv_append_metadata"], ) - async def send_prefill( + async def send_remote_send( self, session: aiohttp.ClientSession, - prefill_request: openai_api_protocol.CompletionRequest, - prefill_endpoint: str, + request: openai_api_protocol.CompletionRequest, + server_url: str, ) -> None: """ Performs step 2 of disaggregated serving: ask P to prefill and transfer KV to D. P returns an empty chunk to acknowledge completion. """ # Send request to P and get ack - async with session.post(prefill_endpoint, json=prefill_request.model_dump()) as response: + async with session.post( + server_url + "/microserving/remote_send", + json=request.model_dump(), + headers=self.headers, + ) as response: assert response.status == 200, await response.text() - # Expect decode to only return an empty chunk - data = None - async for chunk in response.content: - if prefill_request.stream: - chunk = chunk.strip() - if not chunk or chunk == b"\n": - continue - # Get rid of the prefix "data: " and suffix "\n" - raw_data = chunk[6:].strip() - if raw_data == b"[DONE]": - continue - assert data is None, "Expecting only one effective chunk response." - data = json.loads(raw_data) - else: - data = await response.json() - - assert "extra" in data["usage"] - return + await response.json() - async def send_decode( # pylint: disable=fixme + async def send_start_generate( self, session: aiohttp.ClientSession, - decode_request: openai_api_protocol.CompletionRequest, - decode_endpoint: str, + request: openai_api_protocol.CompletionRequest, + server_url: str, ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]: """ Performs step 3 of disaggregated serving: ask D to decode and return normal response. """ + # pylint: disable=fixme # Todo: return string directly to reduce str->json->str roundtrip overhead + # pylint: enable=fixme async with session.post( - decode_endpoint, json=decode_request.model_dump(), headers=self.headers + server_url + "/microserving/start_generate", + json=request.model_dump(), + headers=self.headers, ) as response: assert response.status == 200, await response.text() - if decode_request.stream: + if request.stream: async for chunk in response.content: # Convert raw bytes to CompletionResponse chunk = chunk.strip() diff --git a/python/mlc_llm/serve/entrypoints/__init__.py b/python/mlc_llm/serve/entrypoints/__init__.py index c846cefe15..3af7104b80 100644 --- a/python/mlc_llm/serve/entrypoints/__init__.py +++ b/python/mlc_llm/serve/entrypoints/__init__.py @@ -1,3 +1,8 @@ """The entrypoints for MLC LLM server.""" -from . import debug_entrypoints, metrics_entrypoints, openai_entrypoints +from . import ( + debug_entrypoints, + metrics_entrypoints, + microserving_entrypoints, + openai_entrypoints, +) diff --git a/python/mlc_llm/serve/entrypoints/microserving_entrypoints.py b/python/mlc_llm/serve/entrypoints/microserving_entrypoints.py new file mode 100644 index 0000000000..a9a062e57a --- /dev/null +++ b/python/mlc_llm/serve/entrypoints/microserving_entrypoints.py @@ -0,0 +1,75 @@ +"""MicroServing server entrypoints in MLC LLM""" + +import fastapi + +from mlc_llm.protocol.debug_protocol import DisaggConfig +from mlc_llm.protocol.microserving_protocol import ( + PrepRecvRequest, + PrepRecvResponse, + RemoteSendRequest, + StartGenerateRequest, +) +from mlc_llm.protocol.openai_api_protocol import StreamOptions + +from .openai_entrypoints import request_completion + +app = fastapi.APIRouter() + + +################ MicroServing Endpoints ################ + + +@app.post("/microserving/prep_recv") +async def prep_recv(request: PrepRecvRequest, raw_request: fastapi.Request) -> PrepRecvResponse: + """Handle the microserving request for receive preparation. + Match the prompt in the prefix cache (when enabled), + allocate entries in the KV cache to prepare receiving the KV data of the prompt. + Return the prompt length, matched prefix length and the allocated KV entry metadata. + """ + request.debug_config.disagg_config = DisaggConfig( + kind="prepare_receive", + kv_window_begin=0, # always zero for prepare_receive + kv_window_end=request.kv_window_end, + ) + request.stream_options = StreamOptions(include_usage=True) + request.stream = False + + response = await request_completion(request=request, raw_request=raw_request) + assert response.usage is not None + assert response.usage.extra is not None + assert "prompt_length" in response.usage.extra + assert "prefix_matched_length" in response.usage.extra + assert "kv_append_metadata" in response.usage.extra + return PrepRecvResponse( + prompt_length=response.usage.extra["prompt_length"], + prefix_matched_length=response.usage.extra["prefix_matched_length"], + kv_append_metadata=response.usage.extra["kv_append_metadata"], + ) + + +@app.post("/microserving/remote_send") +async def remote_send(request: RemoteSendRequest, raw_request: fastapi.Request): + """Compute and generate the KV data of the prompt in the specified KV window. + Send the KV data to the destination server.""" + request.debug_config.disagg_config = DisaggConfig( + kind="remote_send", + kv_window_begin=request.kv_window_begin, + kv_window_end=request.kv_window_end, + kv_append_metadata=request.kv_append_metadata, + dst_group_offset=request.dst_group_offset, + ) + request.stream_options = StreamOptions(include_usage=True) + request.stream = False + + await request_completion(request=request, raw_request=raw_request) + return {} + + +@app.post("/microserving/start_generate") +async def start_generate(request: StartGenerateRequest, raw_request: fastapi.Request): + """Prefill the prompt in the specified KV window, and start decode.""" + request.debug_config.disagg_config = DisaggConfig( + kind="start_generation", + kv_window_begin=request.kv_window_begin, + ) + return await request_completion(request=request, raw_request=raw_request) diff --git a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py index dd4c517d97..052ba32e36 100644 --- a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py @@ -113,7 +113,6 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: if choice.logprobs is not None: logprob_results[choice.index] = choice.logprobs - assert all(finish_reason is not None for finish_reason in finish_reasons) return engine_base.wrap_completion_response( request_id=request_id, model=request.model, diff --git a/python/mlc_llm/serve/server/popen_server.py b/python/mlc_llm/serve/server/popen_server.py index 5411828e31..230cbd05ae 100644 --- a/python/mlc_llm/serve/server/popen_server.py +++ b/python/mlc_llm/serve/server/popen_server.py @@ -122,7 +122,6 @@ def start( # pylint: disable=too-many-branches,too-many-statements final_env = os.environ.copy() for key, value in extra_env.items(): final_env[key] = value - print(f"cmd = {cmd}") self._proc = subprocess.Popen( # pylint: disable=consider-using-with cmd, cwd=process_path, env=final_env )