diff --git a/serve/mlc_serve/engine/async_connector.py b/serve/mlc_serve/engine/async_connector.py
index 21ed3bdec5..888cd81447 100644
--- a/serve/mlc_serve/engine/async_connector.py
+++ b/serve/mlc_serve/engine/async_connector.py
@@ -1,6 +1,6 @@
 import asyncio
 import logging
-from typing import AsyncIterator
+from typing import AsyncIterator, Any
 
 from .base import (
     InferenceEngine,
@@ -15,7 +15,9 @@
 
 
 class TextGenerationError(Exception):
-    pass
+    def __init__(self, error: Any) -> None:
+        self.error = error
+        super().__init__(error)
 
 
 class AsyncEngineConnector:
diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py
index 7be2f6e63c..ab58db118a 100644
--- a/serve/mlc_serve/engine/base.py
+++ b/serve/mlc_serve/engine/base.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Optional, List
+from typing import List, Callable, Any, Optional
 import json
 import inspect
 from .sampling_params import SamplingParams, SamplingType
@@ -41,10 +43,10 @@ def get_engine_config(dict_config, enable_check = True):
     # assert engine_config.max_num_sequences > 0
     # assert engine_config.max_num_sequences * engine_config.max_input_len == engine_config.max_num_batched_tokens
 
-    assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0) 
+    assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0)
     assert engine_config.max_decode_steps > engine_config.min_decode_steps
     assert engine_config.prompt_allocate_ratio > 0
-    
+
     return engine_config
 
 @dataclass
@@ -75,20 +77,32 @@ class FinishReason(Enum):
     Length = "length"
     Cancelled = "cancelled"
 
+# A single token.
+Token = List[int]
+
+@dataclass
+class ValidationError:
+    msg: str
+
+# The type signature of the token validation callback.
+ValidateTokensCallback = Callable[["Request", List[Token]], ValidationError]
 
 @dataclass
 class Request:
     request_id: RequestId
     messages: list[ChatMessage]
-
     # Number of sequences to generate
     num_sequences: int = 1
     # TODO: should `best_of` be handled in the serving layer?
     best_of: int = None
-
+    # Options for sampling.
     sampling_params: SamplingParams = field(default_factory=SamplingParams)
+    # Options for stopping.
     stopping_criteria: StoppingCriteria = field(default_factory=StoppingCriteria)
+    # Options for debugging.
     debug_options: DebugOptions = field(default_factory=DebugOptions)
+    # Perform request validation post-tokenization, used by the HTTP layer to control validation.
+    validate_tokens: Optional[ValidateTokensCallback] = None
 
     def __post_init__(self):
         if self.best_of is None:
@@ -132,7 +146,11 @@ class RequestOutput:
    # TODO: reconsider the place to put this number
    # Only set for outputs with valid sequence otuputs
    num_prompt_tokens: Optional[int] = None
-
+    # TODO(@jroesch): We should generalize the type here so we are allowed to return more structured information
+    # for logging/user output.
+    #
+    # Right now I am abusing dynamic typing by putting the ValidationError in here.
+    # I would prefer to unblock ourselves than figure this one out right now.
    error: Optional[str] = None
 
    @property
@@ -209,6 +227,7 @@ class RequestState:
    stopping_criteria: StoppingCriteria
    debug_options: DebugOptions
    is_ended: bool = False
+    validation_err: Optional[ValidationError] = None
 
 def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
    if stopping_criteria.stop_sequences:
@@ -226,4 +245,4 @@ def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
                 output_text = output_text[:output_text.find(t) + len(t)]
                 is_ended = True
                 break
-    return output_text, delta, is_ended
\ No newline at end of file
+    return output_text, delta, is_ended
diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py
index c26f0a90d0..bfe3310ffd 100644
--- a/serve/mlc_serve/engine/staging_engine.py
+++ b/serve/mlc_serve/engine/staging_engine.py
@@ -5,7 +5,7 @@
 import multiprocessing
 import queue
 from threading import Lock
-from typing import Callable
+from typing import Callable, Optional
 
 from .base import (
     InferenceStepResult,
@@ -84,6 +84,8 @@ def add(self, requests: list[Request]):
             # TODO: verify that request id is unique
             if req.num_sequences > 1:
                 raise RuntimeError("num_sequences > 1 is not supported for now")
+
+            # If the request fails validation, the returned state carries the error and the worker cancels it.
             state = self._get_new_request_state(req)
             new_request_states.append(state)
 
@@ -200,6 +202,10 @@ def _get_new_request_state(self, request: Request) -> RequestState:
 
         prompt_tokens = self.tokenizer.encode(prompt)
 
+        validation_err = None
+        if request.validate_tokens is not None:
+            validation_err = request.validate_tokens(request, prompt_tokens)
+
         return RequestState(
             request_id=request.request_id,
             token_ids=prompt_tokens,
@@ -209,6 +215,7 @@ def _get_new_request_state(self, request: Request) -> RequestState:
             stopping_criteria=request.stopping_criteria,
             debug_options=request.debug_options,
             output_text="",
+            validation_err=validation_err,
         )
 
     def _decode_last_output(self, state: RequestState) -> str:
diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py
index 06eab994ee..82b0b6a8a8 100644
--- a/serve/mlc_serve/engine/staging_engine_worker.py
+++ b/serve/mlc_serve/engine/staging_engine_worker.py
@@ -7,14 +7,14 @@
 from collections import deque
 from dataclasses import dataclass
 from threading import Condition, Lock, Thread
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Union, Any
 
 from .base import FinishReason, RequestId, RequestState
 from .model_module import DecodeRequest, ModelModule, PrefillRequest, SequenceId
+import structlog
 
 logger = logging.getLogger(__name__)
 
-
 @dataclass
 class ShutdownCommand:
     pass
@@ -77,7 +77,15 @@ def __init__(
 
     def add(self, request_states: list[RequestState]):
         with self.queue_lock:
-            self.queue.extend(request_states)
+            # States which have been invalidated should never be queued;
+            # cancel them directly instead.
+            valid_states = []
+            for request_state in request_states:
+                if request_state.validation_err is not None:
+                    self.cancelled_requests.append(request_state)
+                else:
+                    valid_states.append(request_state)
+            self.queue.extend(valid_states)
             self.has_new_requests.notify_all()
 
     def cancel(self, request_id: RequestId):
@@ -102,7 +110,7 @@ def wait_for_request(self, timeout_seconds=None) -> bool:
         )
 
     def has_pending_requests(self) -> bool:
-        return self.queue or self.current_batch
+        return self.queue or self.current_batch or self.cancelled_requests
 
     def step(self) -> GenerationLoopWorkerOutput:
         logger.debug("Starting new inference step.")
@@ -130,12 +138,17 @@ def step(self) -> GenerationLoopWorkerOutput:
                 self._remove_request_from_batch(state.request_id)
 
         for state in self.cancelled_requests:
+            err = None
+            if state.validation_err:
+                err = state.validation_err
+
             outputs.append(
                 SequenceGenerationOutput(
                     # TODO: support multi-sequence
                     id=SequenceId(state.request_id, 0),
                     new_tokens=[],
                     finish_reason=FinishReason.Cancelled,
+                    error=err,
                 )
             )
             if state.request_id in self.current_batch:
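
Usage sketch (not part of the patch): the snippet below shows one way the HTTP layer could implement the `validate_tokens` hook added in `base.py`. The `MAX_CONTEXT_LENGTH` constant and the `reject_long_prompts` name are assumptions made for illustration; the only requirement is that the callback matches the `ValidateTokensCallback` shape and returns a `ValidationError` (or `None`) for the tokenized prompt.

```python
from typing import List, Optional

# Assumes the serve/ directory is importable so mlc_serve resolves as a package.
from mlc_serve.engine.base import Request, ValidationError

# Hypothetical limit used only for this example.
MAX_CONTEXT_LENGTH = 4096


def reject_long_prompts(
    request: Request, prompt_tokens: List[int]
) -> Optional[ValidationError]:
    """Reject prompts that exceed the assumed context window.

    The engine calls this after tokenization (see _get_new_request_state);
    returning a ValidationError makes the worker cancel the request instead
    of scheduling it, while returning None lets it through.
    """
    if len(prompt_tokens) > MAX_CONTEXT_LENGTH:
        return ValidationError(
            msg=f"prompt has {len(prompt_tokens)} tokens; the limit is {MAX_CONTEXT_LENGTH}"
        )
    return None


# Attached where the HTTP layer builds the request, e.g.:
#   Request(request_id=..., messages=..., validate_tokens=reject_long_prompts)
```

Note that `_get_new_request_state` passes the raw output of `tokenizer.encode` (presumably a flat `List[int]`) and treats a `None` result as "no error", so the callback's effective signature is closer to `Callable[[Request, List[int]], Optional[ValidationError]]` than the stricter `ValidateTokensCallback` alias declared in `base.py`.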