Changes to support input validation to match OpenAI behavior. (#65)
* Add the ability to control token validation

* Remove debugging

* Fix

* Address issue with staging engine
jroesch authored Nov 16, 2023
1 parent f8609cd commit 43c8ebb
Showing 4 changed files with 55 additions and 14 deletions.
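
As context for the diffs that follow: base.py gains a ValidateTokensCallback type (a callable taking a Request plus the tokenized prompt and returning a ValidationError) and Request gains an optional validate_tokens field that the HTTP layer can set. A minimal sketch of such a callback is shown below; it is illustrative only and not part of this commit, and the max_context_length limit is a made-up parameter. Note that while the type alias declares a ValidationError return, the staging engine treats None as "validation passed", so the sketch returns Optional[ValidationError].

from typing import List, Optional

# ValidationError is added to serve/mlc_serve/engine/base.py in this commit;
# the import path assumes mlc_serve is installed as shown by the file paths.
from mlc_serve.engine.base import ValidationError


def make_length_validator(max_context_length: int):
    """Build a ValidateTokensCallback-style callable (illustrative sketch only)."""

    def validate_tokens(request, prompt_tokens: List[int]) -> Optional[ValidationError]:
        # Reject prompts whose tokenization exceeds the (hypothetical) limit,
        # mirroring the kind of input check OpenAI performs.
        if len(prompt_tokens) > max_context_length:
            return ValidationError(
                msg=f"Prompt has {len(prompt_tokens)} tokens, which exceeds the "
                f"maximum context length of {max_context_length}."
            )
        # None signals that the request passed validation.
        return None

    return validate_tokens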
6 changes: 4 additions & 2 deletions serve/mlc_serve/engine/async_connector.py
@@ -1,6 +1,6 @@
import asyncio
import logging
from typing import AsyncIterator
from typing import AsyncIterator, Any

from .base import (
InferenceEngine,
@@ -15,7 +15,9 @@


class TextGenerationError(Exception):
pass
def __init__(self, error: Any) -> None:
self.error = error
super().__init__(error)


class AsyncEngineConnector:
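
The TextGenerationError change above keeps the original error object on self.error instead of flattening it to a string. A hedged sketch of caller-side handling follows; the raise site and the AsyncEngineConnector.generate usage are not shown in this truncated diff, so treat them as assumptions.

from mlc_serve.engine.async_connector import AsyncEngineConnector, TextGenerationError
from mlc_serve.engine.base import Request, ValidationError


async def run_request(connector: AsyncEngineConnector, request: Request) -> None:
    # Illustrative caller; only the TextGenerationError.error attribute is from
    # this commit. connector.generate(request) yielding outputs asynchronously
    # is an assumption based on the AsyncIterator import above.
    try:
        async for output in connector.generate(request):
            pass  # a real caller would stream each output chunk to the client
    except TextGenerationError as exc:
        # exc.error may now be the ValidationError defined in base.py (below)
        # rather than a plain string, so callers can branch on it.
        if isinstance(exc.error, ValidationError):
            print(f"request rejected: {exc.error.msg}")
        else:
            print(f"generation failed: {exc.error}")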
33 changes: 26 additions & 7 deletions serve/mlc_serve/engine/base.py
@@ -1,6 +1,8 @@
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, List
from typing import List, Callable, Any, Optional
import json
import inspect
from .sampling_params import SamplingParams, SamplingType
@@ -41,10 +43,10 @@ def get_engine_config(dict_config, enable_check = True):
# assert engine_config.max_num_sequences > 0
# assert engine_config.max_num_sequences * engine_config.max_input_len == engine_config.max_num_batched_tokens

assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0)
assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0)
assert engine_config.max_decode_steps > engine_config.min_decode_steps
assert engine_config.prompt_allocate_ratio > 0

return engine_config

@dataclass
@@ -75,20 +77,32 @@ class FinishReason(Enum):
Length = "length"
Cancelled = "cancelled"

# A single token.
Token = List[int]

@dataclass
class ValidationError:
msg: str

# The type signature of the token validation callback.
ValidateTokensCallback = Callable[["Request", List[Token]], ValidationError]

@dataclass
class Request:
request_id: RequestId
messages: list[ChatMessage]

# Number of sequences to generate
num_sequences: int = 1
# TODO: should `best_of` be handled in the serving layer?
best_of: int = None

# Options for sampling.
sampling_params: SamplingParams = field(default_factory=SamplingParams)
# Options for stopping.
stopping_criteria: StoppingCriteria = field(default_factory=StoppingCriteria)
# Options for debugging.
debug_options: DebugOptions = field(default_factory=DebugOptions)
# Perform request validation post-tokenization, used by the HTTP layer to control validation.
validate_tokens: Optional[ValidateTokensCallback] = None

def __post_init__(self):
if self.best_of is None:
@@ -132,7 +146,11 @@ class RequestOutput:
# TODO: reconsider the place to put this number
# Only set for outputs with valid sequence outputs
num_prompt_tokens: Optional[int] = None

# TODO(@jroesch): We should generalize the type here so we are allowed to return more structured information
# for logging/user output.
#
# Right now I am abusing dynamic typing by putting the ValidationError in here.
# I would prefer to unblock ourselves rather than figure this one out right now.
error: Optional[str] = None

@property
@@ -209,6 +227,7 @@ class RequestState:
stopping_criteria: StoppingCriteria
debug_options: DebugOptions
is_ended: bool = False
validation_err: Optional[ValidationError] = None

def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
if stopping_criteria.stop_sequences:
@@ -226,4 +245,4 @@ def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
output_text = output_text[:output_text.find(t) + len(t)]
is_ended = True
break
return output_text, delta, is_ended
return output_text, delta, is_ended
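
To connect the pieces in base.py above: the HTTP layer constructs a Request and attaches a callback through the new validate_tokens field, and the staging engine then runs that callback on the tokenized prompt. A hedged construction sketch follows, assuming RequestId is a string and that ChatMessage takes role/content fields; ChatMessage's definition and import location are not part of this diff.

from typing import List, Optional

# ChatMessage's location is assumed; only Request and ValidationError appear in this diff.
from mlc_serve.engine.base import ChatMessage, Request, ValidationError


def reject_empty_prompt(request: Request, prompt_tokens: List[int]) -> Optional[ValidationError]:
    # Trivial illustrative callback: an empty tokenization is treated as invalid.
    if not prompt_tokens:
        return ValidationError(msg="Prompt tokenized to zero tokens.")
    return None


request = Request(
    request_id="request-0",                                 # assumes RequestId is a str
    messages=[ChatMessage(role="user", content="Hello!")],  # ChatMessage fields are assumed
    validate_tokens=reject_empty_prompt,
)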
9 changes: 8 additions & 1 deletion serve/mlc_serve/engine/staging_engine.py
@@ -5,7 +5,7 @@
import multiprocessing
import queue
from threading import Lock
from typing import Callable
from typing import Callable, Optional

from .base import (
InferenceStepResult,
@@ -84,6 +86,8 @@ def add(self, requests: list[Request]):
# TODO: verify that request id is unique
if req.num_sequences > 1:
raise RuntimeError("num_sequences > 1 is not supported for now")

# If the request fails token validation, the returned state carries a
# validation error; the worker cancels it instead of scheduling it.
state = self._get_new_request_state(req)
new_request_states.append(state)

@@ -200,6 +202,10 @@ def _get_new_request_state(self, request: Request) -> RequestState:

prompt_tokens = self.tokenizer.encode(prompt)

validation_err = None
if request.validate_tokens is not None:
validation_err = request.validate_tokens(request, prompt_tokens)

return RequestState(
request_id=request.request_id,
token_ids=prompt_tokens,
@@ -209,6 +215,7 @@ def _get_new_request_state(self, request: Request) -> RequestState:
stopping_criteria=request.stopping_criteria,
debug_options=request.debug_options,
output_text="",
validation_err=validation_err,
)

def _decode_last_output(self, state: RequestState) -> str:
21 changes: 17 additions & 4 deletions serve/mlc_serve/engine/staging_engine_worker.py
@@ -7,14 +7,14 @@
from collections import deque
from dataclasses import dataclass
from threading import Condition, Lock, Thread
from typing import Callable, Optional, Union
from typing import Callable, Optional, Union, Any

from .base import FinishReason, RequestId, RequestState
from .model_module import DecodeRequest, ModelModule, PrefillRequest, SequenceId
import structlog

logger = logging.getLogger(__name__)


@dataclass
class ShutdownCommand:
pass
@@ -77,7 +15,7 @@ def __init__(

def add(self, request_states: list[RequestState]):
with self.queue_lock:
self.queue.extend(request_states)
# States which have been invalidated should never be added;
# cancel them directly instead.
valid_states = []
for request_state in request_states:
if request_state.validation_err is not None:
self.cancelled_requests.append(request_state)
else:
valid_states.append(request_state)
self.queue.extend(valid_states)
self.has_new_requests.notify_all()

def cancel(self, request_id: RequestId):
@@ -102,7 +110,7 @@ def wait_for_request(self, timeout_seconds=None) -> bool:
)

def has_pending_requests(self) -> bool:
return self.queue or self.current_batch
return self.queue or self.current_batch or self.cancelled_requests

def step(self) -> GenerationLoopWorkerOutput:
logger.debug("Starting new inference step.")
@@ -130,12 +138,17 @@ def step(self) -> GenerationLoopWorkerOutput:
self._remove_request_from_batch(state.request_id)

for state in self.cancelled_requests:
err = None
if state.validation_err:
err = state.validation_err

outputs.append(
SequenceGenerationOutput(
# TODO: support multi-sequence
id=SequenceId(state.request_id, 0),
new_tokens=[],
finish_reason=FinishReason.Cancelled,
error=err,
)
)
if state.request_id in self.current_batch:
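
End to end, a request that fails validation never reaches the model: the worker above routes it to cancelled_requests and emits a SequenceGenerationOutput with FinishReason.Cancelled and the error attached. Since the commit's stated goal is matching OpenAI behavior, an HTTP layer would presumably translate that ValidationError into an OpenAI-style error payload; the sketch below is hypothetical, and the field layout follows OpenAI's public error format rather than anything defined in this commit.

import json

from mlc_serve.engine.base import ValidationError


def openai_style_error_body(err: ValidationError) -> str:
    # Hypothetical mapping from a ValidationError surfaced on a cancelled
    # sequence to an OpenAI-compatible error response body.
    return json.dumps(
        {
            "error": {
                "message": err.msg,
                "type": "invalid_request_error",
                "param": None,
                "code": None,
            }
        }
    )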