Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔧 #7 Made the stream chunk schema match the actual API #8

Merged
merged 6 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 50 additions & 38 deletions examples/lang/chat_stream_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@

from glide import AsyncGlideClient
from glide.lang.schemas import (
ChatStreamError,
StreamChatRequest,
ChatStreamRequest,
ChatMessage,
StreamResponse,
ChatStreamChunk,
ChatStreamMessage,
)

router_id: str = "default" # defined in Glide config (see glide.config.yaml)
question = "What are the kosher species?"
question = "What is the most complicated theory discovered by humanity?"


async def chat_stream() -> None:
Expand All @@ -23,44 +21,58 @@ async def chat_stream() -> None:
print(f"💬Question: {question}")
print("💬Answer: ", end="")

last_chunk: Optional[StreamResponse] = None
chat_req = StreamChatRequest(message=ChatMessage(role="user", content=question))
last_msg: Optional[ChatStreamMessage] = None
chat_req = ChatStreamRequest(message=ChatMessage(role="user", content=question))

started_at = time.perf_counter()
first_chunk_recv_at: Optional[float] = None

async with glide_client.lang.stream_client(router_id) as client:
async for chunk in client.chat_stream(chat_req):
if not first_chunk_recv_at:
first_chunk_recv_at = time.perf_counter()

if isinstance(chunk, ChatStreamError):
print(f"💥err: {chunk.message} (code: {chunk.err_code})")
continue

print(chunk.model_response.message.content, end="")
last_chunk = chunk

if last_chunk:
if isinstance(last_chunk, ChatStreamChunk):
if reason := last_chunk.model_response.finish_reason:
print(f"\n✅ Generation is done (reason: {reason.value})")

if isinstance(last_chunk, ChatStreamError):
print(f"\n💥 Generation ended up with error (reason: {last_chunk.message})")

first_chunk_duration_ms: float = 0

if first_chunk_recv_at:
first_chunk_duration_ms = (first_chunk_recv_at - started_at) * 1_000
print(f"\n⏱️First Response Chunk: {first_chunk_duration_ms:.2f}ms")

chat_duration_ms = (time.perf_counter() - started_at) * 1_000

print(
f"⏱️Chat Duration: {chat_duration_ms:.2f}ms "
f"({(chat_duration_ms - first_chunk_duration_ms):.2f}ms after the first chunk)"
)
try:
async for message in client.chat_stream(chat_req):
if not first_chunk_recv_at:
first_chunk_recv_at = time.perf_counter()

last_msg = message

if message.chunk:
print(message.content_chunk, end="", flush=True)
continue

if err := message.error:
print(f"💥ERR: {err.message} (code: {err.err_code})")
print("🧹 Restarting the stream")
continue

print(f"😮Unknown message type: {message}")
except Exception as e:
print(f"💥Stream interrupted by ERR: {e}")

if last_msg and last_msg.chunk and last_msg.finish_reason:
# LLM gen context
provider_name = last_msg.chunk.provider_name
model_name = last_msg.chunk.model_name
finish_reason = last_msg.finish_reason

print(
f"\n\n✅ Generation is done "
f"(provider: {provider_name}, model: {model_name}, reason: {finish_reason.value})"
)

print(
f"👀Glide Context (router_id: {last_msg.router_id}, model_id: {last_msg.chunk.model_id})"
)

if first_chunk_recv_at:
first_chunk_duration_ms = (first_chunk_recv_at - started_at) * 1_000
print(f"\n⏱️First Response Chunk: {first_chunk_duration_ms:.2f}ms")

chat_duration_ms = (time.perf_counter() - started_at) * 1_000

print(
f"⏱️Chat Duration: {chat_duration_ms:.2f}ms "
f"({(chat_duration_ms - first_chunk_duration_ms):.2f}ms after the first chunk)"
)


if __name__ == "__main__":
Expand Down
21 changes: 18 additions & 3 deletions src/glide/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,34 @@
# SPDX-License-Identifier: APACHE-2.0


class GlideUnavailable(Exception):
class GlideError(Exception):
    """Base class that every Glide server error derives from.

    Catch this type to handle any error raised by the Glide client in one place.
    """

class GlideUnavailable(GlideError):
    """Raised when the Glide API cannot be reached."""


class GlideClientError(Exception):
class GlideClientError(GlideError):
    """Raised when a Glide request could not be sent successfully."""


class GlideClientMismatch(Exception):
class GlideClientMismatch(GlideError):
    """Raised on signs of a compatibility problem between the Glide API
    and this client version."""


class GlideChatStreamError(GlideError):
    """Raised when a chat stream terminates with an error.

    Besides the human-readable message (stored via the standard
    ``Exception`` machinery), this carries the server-side error code
    so callers can branch on it programmatically.
    """

    def __init__(self, message: str, err_code: str) -> None:
        # Preserve normal Exception behavior for str()/args.
        super().__init__(message)
        # Machine-readable error code reported by the Glide API.
        self.err_code = err_code
107 changes: 63 additions & 44 deletions src/glide/lang/router_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,15 @@

from websockets import WebSocketClientProtocol

from glide.exceptions import GlideUnavailable, GlideClientError, GlideClientMismatch
from glide.exceptions import (
GlideUnavailable,
GlideClientError,
GlideClientMismatch,
GlideChatStreamError,
)
from glide.lang import schemas
from glide.lang.schemas import StreamChatRequest, StreamResponse, ChatRequestId
from glide.lang.schemas import ChatStreamRequest, ChatStreamMessage, ChatRequestId
from glide.logging import logger
from glide.typing import RouterId


Expand All @@ -42,9 +48,11 @@ def __init__(

self._handlers = handlers

self.requests: asyncio.Queue[StreamChatRequest] = asyncio.Queue()
self.response_chunks: asyncio.Queue[StreamResponse] = asyncio.Queue()
self._response_streams: Dict[ChatRequestId, asyncio.Queue[StreamResponse]] = {}
self.requests: asyncio.Queue[ChatStreamRequest] = asyncio.Queue()
self.response_chunks: asyncio.Queue[ChatStreamMessage] = asyncio.Queue()
self._response_streams: Dict[
ChatRequestId, asyncio.Queue[ChatStreamMessage]
] = {}

self._sender_task: Optional[asyncio.Task] = None
self._receiver_task: Optional[asyncio.Task] = None
Expand All @@ -54,27 +62,36 @@ def __init__(
self._ping_interval = ping_interval
self._close_timeout = close_timeout

def request_chat(self, chat_request: StreamChatRequest) -> None:
def request_chat(self, chat_request: ChatStreamRequest) -> None:
self.requests.put_nowait(chat_request)

async def chat_stream(
self, req: StreamChatRequest
) -> AsyncGenerator[StreamResponse, None]:
chunk_buffer: asyncio.Queue[StreamResponse] = asyncio.Queue()
self._response_streams[req.id] = chunk_buffer
self,
req: ChatStreamRequest,
# TODO: add timeout
) -> AsyncGenerator[ChatStreamMessage, None]:
msg_buffer: asyncio.Queue[ChatStreamMessage] = asyncio.Queue()
self._response_streams[req.id] = msg_buffer

self.request_chat(req)

while True:
chunk = await chunk_buffer.get()

yield chunk
try:
while True:
message = await msg_buffer.get()

if err := message.ended_with_err:
# fail only on fatal errors that indicate stream stop
raise GlideChatStreamError(
f"Chat stream {req.id} ended with an error: {err.message} (code: {err.err_code})",
err.err_code,
)

# TODO: handle stream end on error
if chunk.model_response.finish_reason:
break
yield message # returns content chunk and some error messages

self._response_streams.pop(req.id, None)
if message.finish_reason:
break
finally:
self._response_streams.pop(req.id, None)

async def start(self) -> None:
self._ws_client = await websockets.connect(
Expand All @@ -90,39 +107,41 @@ async def start(self) -> None:
self._receiver_task = asyncio.create_task(self._receiver())

async def _sender(self) -> None:
try:
while self._ws_client and self._ws_client.open:
while self._ws_client and self._ws_client.open:
try:
chat_request = await self.requests.get()

await self._ws_client.send(chat_request.json())
except asyncio.CancelledError:
# TODO: log
...
except asyncio.CancelledError:
# TODO: log
break

async def _receiver(self) -> None:
try:
while self._ws_client and self._ws_client.open:
try:
raw_chunk = await self._ws_client.recv()
chunk: StreamResponse = pydantic.parse_obj_as(
StreamResponse,
json.loads(raw_chunk),
)

if chunk_buffer := self._response_streams.get(chunk.id):
chunk_buffer.put_nowait(chunk)
continue

self.response_chunks.put_nowait(chunk)
except pydantic.ValidationError as e:
raise GlideClientMismatch(
"Failed to validate Glide API response. "
"Please make sure Glide API and client versions are compatible"
) from e
except asyncio.CancelledError:
...
while self._ws_client and self._ws_client.open:
try:
raw_chunk = await self._ws_client.recv()
message: ChatStreamMessage = ChatStreamMessage(**json.loads(raw_chunk))

logger.debug("received chat stream message", extra={"message": message})

if msg_buffer := self._response_streams.get(message.id):
msg_buffer.put_nowait(message)
continue

self.response_chunks.put_nowait(message)
except pydantic.ValidationError:
logger.error(
"Failed to validate Glide API response. "
"Please make sure Glide API and client versions are compatible",
exc_info=True,
)
except asyncio.CancelledError:
break
except Exception as e:
logger.exception(e)

async def stop(self) -> None:
# TODO: allow to timeout shutdown too
if self._sender_task:
self._sender_task.cancel()
await self._sender_task
Expand Down
Loading