-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes to support input validation to match OpenAI behavior. #65
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,13 +7,15 @@ | |
from collections import deque | ||
from dataclasses import dataclass | ||
from threading import Condition, Lock, Thread | ||
from typing import Callable, Optional, Union | ||
from typing import Callable, Optional, Union, Any | ||
|
||
from .base import FinishReason, RequestId, RequestState | ||
from .model_module import DecodeRequest, ModelModule, PrefillRequest, SequenceId | ||
import structlog | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
LOG = structlog.stdlib.get_logger(__name__) | ||
|
||
@dataclass | ||
class ShutdownCommand: | ||
|
@@ -77,7 +79,16 @@ def __init__( | |
|
||
def add(self, request_states: list[RequestState]): | ||
with self.queue_lock: | ||
self.queue.extend(request_states) | ||
# States which have been invalidated should never be added, directly | ||
# cancel them instead. | ||
valid_states = [] | ||
for request_state in request_states: | ||
if request_state.validation_err is not None: | ||
self.cancelled_requests.append(request_state) | ||
else: | ||
valid_states.append(request_state) | ||
LOG.info("valid_states", valid_states=valid_states) | ||
self.queue.extend(valid_states) | ||
self.has_new_requests.notify_all() | ||
|
||
def cancel(self, request_id: RequestId): | ||
|
@@ -102,7 +113,7 @@ def wait_for_request(self, timeout_seconds=None) -> bool: | |
) | ||
|
||
def has_pending_requests(self) -> bool: | ||
return self.queue or self.current_batch | ||
return self.queue or self.current_batch or self.cancelled_requests | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This might be why cancelation was taking a while, we wont' actually step forward to cancel requests if we don't submit new work to the queue or the current batch. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense. |
||
|
||
def step(self) -> GenerationLoopWorkerOutput: | ||
logger.debug("Starting new inference step.") | ||
|
@@ -130,12 +141,17 @@ def step(self) -> GenerationLoopWorkerOutput: | |
self._remove_request_from_batch(state.request_id) | ||
|
||
for state in self.cancelled_requests: | ||
err = None | ||
if state.validation_err: | ||
err = state.validation_err | ||
|
||
outputs.append( | ||
SequenceGenerationOutput( | ||
# TODO: support multi-sequence | ||
id=SequenceId(state.request_id, 0), | ||
new_tokens=[], | ||
finish_reason=FinishReason.Cancelled, | ||
error = err | ||
) | ||
) | ||
if state.request_id in self.current_batch: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The return value is never
Optional
, is it?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah yes this is a piece of bit rot from my previous attempt. Let me change that.