Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to support input validation to match OpenAI behavior. #65

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions serve/mlc_serve/engine/async_connector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import logging
from typing import AsyncIterator
from typing import AsyncIterator, Any

from .base import (
InferenceEngine,
Expand All @@ -15,7 +15,9 @@


class TextGenerationError(Exception):
    """Exception that carries an arbitrary error payload.

    The payload is stored on ``error`` for callers to inspect and is also
    forwarded to ``Exception`` so that it shows up in ``args`` and ``str()``.
    """

    def __init__(self, error: Any) -> None:
        self.error = error
        # Use zero-argument super(): `super(self)` raises TypeError because
        # its first argument must be a class, not an instance.
        super().__init__(error)


class AsyncEngineConnector:
Expand Down Expand Up @@ -98,6 +100,7 @@ async def _get_queue_item_until_stopped(self, queue: ResultQueue) -> RequestOutp

if wait_shutdown_task.done():
if self.engine_loop_exception is not None:
import pdb; pdb.set_trace()
raise RuntimeError(
f"InferenceEngine raised exception: {self.engine_loop_exception}"
)
Expand Down
22 changes: 18 additions & 4 deletions serve/mlc_serve/engine/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import Optional, List
from typing import List, Callable, Any, Optional
import json
import inspect
from .sampling_params import SamplingParams, SamplingType
Expand Down Expand Up @@ -41,10 +43,10 @@ def get_engine_config(dict_config, enable_check = True):
# assert engine_config.max_num_sequences > 0
# assert engine_config.max_num_sequences * engine_config.max_input_len == engine_config.max_num_batched_tokens

assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0)
assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0)
assert engine_config.max_decode_steps > engine_config.min_decode_steps
assert engine_config.prompt_allocate_ratio > 0

return engine_config

@dataclass
Expand Down Expand Up @@ -75,12 +77,23 @@ class FinishReason(Enum):
Length = "length"
Cancelled = "cancelled"

# A single token.
# NOTE(review): the alias is a list of ints rather than a single int — confirm
# this is the intended representation of one token.
Token = List[int]

@dataclass
class ValidationError:
    """Result object describing a failed validation."""

    # Human-readable description of what failed.
    msg: str

# The type signature of the token validation callback. It is called with the
# request and its prompt tokens and returns a ValidationError describing the
# problem, or None when there is nothing to report.
ValidateTokensCallback = Callable[["Request", List[Token]], Optional[ValidationError]]

@dataclass
class Request:
request_id: RequestId
messages: list[ChatMessage]

# Perform request validation post-tokenization, used by the HTTP layer to control validation.
validate_tokens: Optional[ValidateTokensCallback] = None
# Number of sequences to generate
num_sequences: int = 1
# TODO: should `best_of` be handled in the serving layer?
Expand Down Expand Up @@ -209,6 +222,7 @@ class RequestState:
stopping_criteria: StoppingCriteria
debug_options: DebugOptions
is_ended: bool = False
validation_err: Optional[ValidationError] = None

def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
if stopping_criteria.stop_sequences:
Expand All @@ -226,4 +240,4 @@ def check_stopping_sequences(stopping_criteria, output_text, delta, is_ended):
output_text = output_text[:output_text.find(t) + len(t)]
is_ended = True
break
return output_text, delta, is_ended
return output_text, delta, is_ended
11 changes: 9 additions & 2 deletions serve/mlc_serve/engine/staging_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import multiprocessing
import queue
from threading import Lock
from typing import Callable
from typing import Callable, Optional

from .base import (
InferenceStepResult,
Expand Down Expand Up @@ -84,6 +84,8 @@ def add(self, requests: list[Request]):
# TODO: verify that request id is unique
if req.num_sequences > 1:
raise RuntimeError("num_sequences > 1 is not supported for now")

# A request that fails token validation is not skipped here: the error is recorded on the returned state.
state = self._get_new_request_state(req)
new_request_states.append(state)

Expand Down Expand Up @@ -192,14 +194,18 @@ def step(self) -> InferenceStepResult:
def _is_ready_to_serve(self) -> bool:
return self.worker_process is not None and self.worker_process.is_alive()

def _get_new_request_state(self, request: Request) -> RequestState:
def _get_new_request_state(self, request: Request) -> Optional[RequestState]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return value is never Optional, is it?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes this is a piece of bit rot from my previous attempt. Let me change that.

if request.debug_options.prompt is not None:
prompt = request.debug_options.prompt
else:
prompt = self.conversation_template.apply(request.messages)

prompt_tokens = self.tokenizer.encode(prompt)

validation_err = None
if request.validate_tokens is not None:
validation_err = request.validate_tokens(request, prompt_tokens)

return RequestState(
request_id=request.request_id,
token_ids=prompt_tokens,
Expand All @@ -209,6 +215,7 @@ def _get_new_request_state(self, request: Request) -> RequestState:
stopping_criteria=request.stopping_criteria,
debug_options=request.debug_options,
output_text="",
validation_err=validation_err,
)

def _decode_last_output(self, state: RequestState) -> str:
Expand Down
22 changes: 19 additions & 3 deletions serve/mlc_serve/engine/staging_engine_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
from collections import deque
from dataclasses import dataclass
from threading import Condition, Lock, Thread
from typing import Callable, Optional, Union
from typing import Callable, Optional, Union, Any

from .base import FinishReason, RequestId, RequestState
from .model_module import DecodeRequest, ModelModule, PrefillRequest, SequenceId
import structlog

logger = logging.getLogger(__name__)

LOG = structlog.stdlib.get_logger(__name__)

@dataclass
class ShutdownCommand:
Expand Down Expand Up @@ -77,7 +79,16 @@ def __init__(

def add(self, request_states: list[RequestState]):
    """Queue new request states, diverting those that failed validation.

    States whose ``validation_err`` is set are never queued; they are put on
    ``cancelled_requests`` instead. All remaining states are appended to the
    queue, and waiters on ``has_new_requests`` are notified.
    """
    with self.queue_lock:
        # States which have been invalidated should never be added, directly
        # cancel them instead.
        valid_states = []
        for request_state in request_states:
            if request_state.validation_err is not None:
                self.cancelled_requests.append(request_state)
            else:
                valid_states.append(request_state)
        LOG.info("valid_states", valid_states=valid_states)
        # Queue only the filtered list. Also extending with the raw
        # `request_states` would enqueue invalidated states and queue the
        # valid ones twice.
        self.queue.extend(valid_states)
        self.has_new_requests.notify_all()

def cancel(self, request_id: RequestId):
Expand All @@ -102,7 +113,7 @@ def wait_for_request(self, timeout_seconds=None) -> bool:
)

def has_pending_requests(self) -> bool:
    """Return True when the queue, current batch, or cancelled list is non-empty.

    Cancelled requests count as pending so that they are reported as pending
    even when nothing is queued or in the current batch.
    """
    # Wrap in bool(): a bare `or` chain returns one of the containers
    # themselves, which contradicts the declared `-> bool` return type.
    return bool(self.queue or self.current_batch or self.cancelled_requests)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be why cancellation was taking a while: we won't actually step forward to cancel requests if we don't submit new work to the queue or the current batch.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense.


def step(self) -> GenerationLoopWorkerOutput:
logger.debug("Starting new inference step.")
Expand Down Expand Up @@ -130,12 +141,17 @@ def step(self) -> GenerationLoopWorkerOutput:
self._remove_request_from_batch(state.request_id)

for state in self.cancelled_requests:
err = None
if state.validation_err:
err = state.validation_err

outputs.append(
SequenceGenerationOutput(
# TODO: support multi-sequence
id=SequenceId(state.request_id, 0),
new_tokens=[],
finish_reason=FinishReason.Cancelled,
error = err
)
)
if state.request_id in self.current_batch:
Expand Down