python/mlc_llm/serve/engine.py (124 changes: 73 additions & 51 deletions)
@@ -975,19 +975,26 @@ async def _chat_completion(  # pylint: disable=too-many-arguments,too-many-locals
         logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = (
             [[] for _ in range(n)] if logprobs else None
         )
-        async for response in chatcmpl_generator:
-            num_prompt_tokens = response.usage.prompt_tokens
-            num_completion_tokens = response.usage.completion_tokens
-            for choice in response.choices:
-                assert isinstance(choice.delta.content, str)
-                output_texts[choice.index] += choice.delta.content
-                if choice.finish_reason is not None and finish_reasons[choice.index] is None:
-                    finish_reasons[choice.index] = choice.finish_reason
-                if choice.logprobs is not None:
-                    assert logprob_results is not None
-                    logprob_results[  # pylint: disable=unsupported-assignment-operation
-                        choice.index
-                    ] += choice.logprobs.content
+        try:
+            async for response in chatcmpl_generator:
+                num_prompt_tokens = response.usage.prompt_tokens
+                num_completion_tokens = response.usage.completion_tokens
+                for choice in response.choices:
+                    assert isinstance(choice.delta.content, str)
+                    output_texts[choice.index] += choice.delta.content
+                    if choice.finish_reason is not None and finish_reasons[choice.index] is None:
+                        finish_reasons[choice.index] = choice.finish_reason
+                    if choice.logprobs is not None:
+                        assert logprob_results is not None
+                        logprob_results[  # pylint: disable=unsupported-assignment-operation
+                            choice.index
+                        ] += choice.logprobs.content
+        except (
+            Exception,
+            asyncio.CancelledError,
+        ) as err:  # pylint: disable=broad-exception-caught
+            logger.error("Error in chat completion with request ID %s: %s", request_id, err)
+            raise err
 
         assert all(finish_reason is not None for finish_reason in finish_reasons)
         use_function_calling, tool_calls_list = engine_base.process_function_call_output(
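Every hunk in this diff applies the same pattern: iteration over an async generator is wrapped in a `try`/`except` that logs the failure together with the request ID and then re-raises. Listing `asyncio.CancelledError` next to `Exception` is deliberate: since Python 3.8 it subclasses `BaseException`, so a bare `except Exception` would never observe a client disconnect or timeout cancellation. A minimal standalone sketch of the pattern (the names `stream_tokens` and `consume` are hypothetical, not code from this PR):

```python
import asyncio
import logging
from typing import AsyncIterator

logger = logging.getLogger(__name__)


async def stream_tokens() -> AsyncIterator[str]:
    """Stand-in for chatcmpl_generator / self._generate(...)."""
    for tok in ("Hello", ", ", "world"):
        await asyncio.sleep(0)  # simulate awaiting the engine
        yield tok


async def consume(request_id: str) -> str:
    text = ""
    try:
        async for tok in stream_tokens():
            text += tok
    except (Exception, asyncio.CancelledError) as err:
        # asyncio.CancelledError subclasses BaseException (Python >= 3.8),
        # so it must be listed explicitly; log with the request ID, then
        # re-raise so the caller still sees the failure/cancellation.
        logger.error("Error in stream for request %s: %s", request_id, err)
        raise
    return text


if __name__ == "__main__":
    print(asyncio.run(consume("req-0")))  # -> Hello, world
```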
@@ -1150,23 +1157,30 @@ async def _handle_chat_completion(
         finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)]
         num_completion_tokens = 0
         self.state.record_event(request_id, event="invoke generate")
-        async for delta_outputs in self._generate(
-            prompts, generation_cfg, request_id  # type: ignore
-        ):
-            response, num_completion_tokens = engine_base.process_chat_completion_stream_output(
-                delta_outputs,
-                request_id,
-                self.state,
-                request.model,
-                generation_cfg,
-                use_function_calling,
-                prompt_length,
-                finish_reasons,
-                num_completion_tokens,
-            )
-            if response is not None:
-                yield response
-        self.state.record_event(request_id, event="finish")
+        try:
+            async for delta_outputs in self._generate(
+                prompts, generation_cfg, request_id  # type: ignore
+            ):
+                response, num_completion_tokens = engine_base.process_chat_completion_stream_output(
+                    delta_outputs,
+                    request_id,
+                    self.state,
+                    request.model,
+                    generation_cfg,
+                    use_function_calling,
+                    prompt_length,
+                    finish_reasons,
+                    num_completion_tokens,
+                )
+                if response is not None:
+                    yield response
+            self.state.record_event(request_id, event="finish")
+        except (
+            Exception,
+            asyncio.CancelledError,
+        ) as err:  # pylint: disable=broad-exception-caught
+            logger.error("Error in _handle_chat_completion for request %s: %s", request_id, err)
+            raise err
 
     async def _handle_completion(
         self, request: openai_api_protocol.CompletionRequest, request_id: str
@@ -1204,28 +1218,35 @@ async def _handle_completion(
         num_completion_tokens = 0
         finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)]
         self.state.record_event(request_id, event="invoke generate")
-        async for delta_outputs in self._generate(
-            prompt, generation_cfg, request_id  # type: ignore
-        ):
-            response, num_completion_tokens = engine_base.process_completion_stream_output(
-                delta_outputs,
-                request_id,
-                self.state,
-                request.model,
-                generation_cfg,
-                prompt_length,
-                finish_reasons,
-                num_completion_tokens,
-            )
-            if response is not None:
-                yield response
+        try:
+            async for delta_outputs in self._generate(
+                prompt, generation_cfg, request_id  # type: ignore
+            ):
+                response, num_completion_tokens = engine_base.process_completion_stream_output(
+                    delta_outputs,
+                    request_id,
+                    self.state,
+                    request.model,
+                    generation_cfg,
+                    prompt_length,
+                    finish_reasons,
+                    num_completion_tokens,
+                )
+                if response is not None:
+                    yield response
 
-        suffix_response = engine_base.create_completion_suffix_response(
-            request, request_id, prompt_length, finish_reasons, num_completion_tokens
-        )
-        if suffix_response is not None:
-            yield suffix_response
-        self.state.record_event(request_id, event="finish")
+            suffix_response = engine_base.create_completion_suffix_response(
+                request, request_id, prompt_length, finish_reasons, num_completion_tokens
+            )
+            if suffix_response is not None:
+                yield suffix_response
+            self.state.record_event(request_id, event="finish")
+        except (
+            Exception,
+            asyncio.CancelledError,
+        ) as err:  # pylint: disable=broad-exception-caught
+            logger.error("Error in _handle_completion for request %s: %s", request_id, err)
+            raise err
 
     async def _generate(
         self,
@@ -1293,6 +1314,7 @@ async def _generate(
             Exception,
             asyncio.CancelledError,
         ) as exception:  # pylint: disable=broad-exception-caught
+            logger.error("Error in _generate for request %s: %s", request_id, exception)
             await self.abort(request_id)
             raise exception
 
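In `_generate` (the last hunk) the `except` block and the `await self.abort(request_id)` cleanup already existed; the change only adds the `logger.error` call ahead of them. A sketch of that log-then-abort-then-reraise shape, where `abort` is a hypothetical stand-in for `AsyncMLCEngine.abort`:

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


async def abort(request_id: str) -> None:
    """Hypothetical stand-in for AsyncMLCEngine.abort(request_id)."""
    logger.info("released engine state for %s", request_id)


async def generate(request_id: str):
    try:
        for step in range(3):
            await asyncio.sleep(0)  # simulate engine work
            yield step
    except (Exception, asyncio.CancelledError) as exception:
        # Log first (the line this PR adds), then free engine-side
        # resources, then propagate the original error/cancellation.
        logger.error("Error in _generate for request %s: %s", request_id, exception)
        await abort(request_id)
        raise


if __name__ == "__main__":
    async def main() -> None:
        async for step in generate("req-1"):
            print(step)

    asyncio.run(main())
```

One consequence of logging at every layer (`_generate`, `_handle_chat_completion`, `_handle_completion`, and the `_chat_completion` consumer) is that a single failure may be reported several times for the same request ID, but each entry pinpoints the stage at which the error surfaced.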