diff --git a/examples/lang/chat_stream_async.py b/examples/lang/chat_stream_async.py index dd37c31..e8c1c15 100644 --- a/examples/lang/chat_stream_async.py +++ b/examples/lang/chat_stream_async.py @@ -6,15 +6,13 @@ from glide import AsyncGlideClient from glide.lang.schemas import ( - ChatStreamError, - StreamChatRequest, + ChatStreamRequest, ChatMessage, - StreamResponse, - ChatStreamChunk, + ChatStreamMessage, ) router_id: str = "default" # defined in Glide config (see glide.config.yaml) -question = "What are the kosher species?" +question = "What is the most complicated theory discovered by humanity?" async def chat_stream() -> None: @@ -23,44 +21,58 @@ async def chat_stream() -> None: print(f"💬Question: {question}") print("💬Answer: ", end="") - last_chunk: Optional[StreamResponse] = None - chat_req = StreamChatRequest(message=ChatMessage(role="user", content=question)) + last_msg: Optional[ChatStreamMessage] = None + chat_req = ChatStreamRequest(message=ChatMessage(role="user", content=question)) started_at = time.perf_counter() first_chunk_recv_at: Optional[float] = None async with glide_client.lang.stream_client(router_id) as client: - async for chunk in client.chat_stream(chat_req): - if not first_chunk_recv_at: - first_chunk_recv_at = time.perf_counter() - - if isinstance(chunk, ChatStreamError): - print(f"💥err: {chunk.message} (code: {chunk.err_code})") - continue - - print(chunk.model_response.message.content, end="") - last_chunk = chunk - - if last_chunk: - if isinstance(last_chunk, ChatStreamChunk): - if reason := last_chunk.model_response.finish_reason: - print(f"\n✅ Generation is done (reason: {reason.value})") - - if isinstance(last_chunk, ChatStreamError): - print(f"\n💥 Generation ended up with error (reason: {last_chunk.message})") - - first_chunk_duration_ms: float = 0 - - if first_chunk_recv_at: - first_chunk_duration_ms = (first_chunk_recv_at - started_at) * 1_000 - print(f"\n⏱️First Response Chunk: {first_chunk_duration_ms:.2f}ms") - - chat_duration_ms = (time.perf_counter() - started_at) * 1_000 - - print( - f"⏱️Chat Duration: {chat_duration_ms:.2f}ms " - f"({(chat_duration_ms - first_chunk_duration_ms):.2f}ms after the first chunk)" - ) + try: + async for message in client.chat_stream(chat_req): + if not first_chunk_recv_at: + first_chunk_recv_at = time.perf_counter() + + last_msg = message + + if message.chunk: + print(message.content_chunk, end="", flush=True) + continue + + if err := message.error: + print(f"💥ERR: {err.message} (code: {err.err_code})") + print("🧹 Restarting the stream") + continue + + print(f"😮Unknown message type: {message}") + except Exception as e: + print(f"💥Stream interrupted by ERR: {e}") + + if last_msg and last_msg.chunk and last_msg.finish_reason: + # LLM gen context + provider_name = last_msg.chunk.provider_name + model_name = last_msg.chunk.model_name + finish_reason = last_msg.finish_reason + + print( + f"\n\n✅ Generation is done " + f"(provider: {provider_name}, model: {model_name}, reason: {finish_reason.value})" + ) + + print( + f"👀Glide Context (router_id: {last_msg.router_id}, model_id: {last_msg.chunk.model_id})" + ) + + if first_chunk_recv_at: + first_chunk_duration_ms = (first_chunk_recv_at - started_at) * 1_000 + print(f"\n⏱️First Response Chunk: {first_chunk_duration_ms:.2f}ms") + + chat_duration_ms = (time.perf_counter() - started_at) * 1_000 + + print( + f"⏱️Chat Duration: {chat_duration_ms:.2f}ms " + f"({(chat_duration_ms - first_chunk_duration_ms):.2f}ms after the first chunk)" + ) if __name__ == "__main__": diff --git a/src/glide/exceptions.py b/src/glide/exceptions.py index bd09682..3c7d435 100644 --- a/src/glide/exceptions.py +++ b/src/glide/exceptions.py @@ -2,19 +2,34 @@ # SPDX-License-Identifier: APACHE-2.0 -class GlideUnavailable(Exception): +class GlideError(Exception): + """The base exception for all Glide server errors""" + + +class GlideUnavailable(GlideError): """ Occurs when Glide API is not available """ -class GlideClientError(Exception): +class GlideClientError(GlideError): """ Occurs when there is an issue with sending a Glide request """ -class GlideClientMismatch(Exception): +class GlideClientMismatch(GlideError): """ Occurs when there is a sign of possible compatibility issues between Glide API and the client version """ + + +class GlideChatStreamError(GlideError): + """ + Occurs when chat stream ends with an error + """ + + def __init__(self, message: str, err_code: str) -> None: + super().__init__(message) + + self.err_code = err_code diff --git a/src/glide/lang/router_async.py b/src/glide/lang/router_async.py index 488a7ef..1d1b15b 100644 --- a/src/glide/lang/router_async.py +++ b/src/glide/lang/router_async.py @@ -13,9 +13,15 @@ from websockets import WebSocketClientProtocol -from glide.exceptions import GlideUnavailable, GlideClientError, GlideClientMismatch +from glide.exceptions import ( + GlideUnavailable, + GlideClientError, + GlideClientMismatch, + GlideChatStreamError, +) from glide.lang import schemas -from glide.lang.schemas import StreamChatRequest, StreamResponse, ChatRequestId +from glide.lang.schemas import ChatStreamRequest, ChatStreamMessage, ChatRequestId +from glide.logging import logger from glide.typing import RouterId @@ -42,9 +48,11 @@ def __init__( self._handlers = handlers - self.requests: asyncio.Queue[StreamChatRequest] = asyncio.Queue() - self.response_chunks: asyncio.Queue[StreamResponse] = asyncio.Queue() - self._response_streams: Dict[ChatRequestId, asyncio.Queue[StreamResponse]] = {} + self.requests: asyncio.Queue[ChatStreamRequest] = asyncio.Queue() + self.response_chunks: asyncio.Queue[ChatStreamMessage] = asyncio.Queue() + self._response_streams: Dict[ + ChatRequestId, asyncio.Queue[ChatStreamMessage] + ] = {} self._sender_task: Optional[asyncio.Task] = None self._receiver_task: Optional[asyncio.Task] = None @@ -54,27 +62,36 @@ def __init__( self._ping_interval = ping_interval self._close_timeout = close_timeout - def request_chat(self, chat_request: StreamChatRequest) -> None: + def request_chat(self, chat_request: ChatStreamRequest) -> None: self.requests.put_nowait(chat_request) async def chat_stream( - self, req: StreamChatRequest - ) -> AsyncGenerator[StreamResponse, None]: - chunk_buffer: asyncio.Queue[StreamResponse] = asyncio.Queue() - self._response_streams[req.id] = chunk_buffer + self, + req: ChatStreamRequest, + # TODO: add timeout + ) -> AsyncGenerator[ChatStreamMessage, None]: + msg_buffer: asyncio.Queue[ChatStreamMessage] = asyncio.Queue() + self._response_streams[req.id] = msg_buffer self.request_chat(req) - while True: - chunk = await chunk_buffer.get() - - yield chunk + try: + while True: + message = await msg_buffer.get() + + if err := message.ended_with_err: + # fail only on fatal errors that indicate stream stop + raise GlideChatStreamError( + f"Chat stream {req.id} ended with an error: {err.message} (code: {err.err_code})", + err.err_code, + ) - # TODO: handle stream end on error - if chunk.model_response.finish_reason: - break + yield message # returns content chunk and some error messages - self._response_streams.pop(req.id, None) + if message.finish_reason: + break + finally: + self._response_streams.pop(req.id, None) async def start(self) -> None: self._ws_client = await websockets.connect( @@ -90,39 +107,41 @@ async def start(self) -> None: self._receiver_task = asyncio.create_task(self._receiver()) async def _sender(self) -> None: - try: - while self._ws_client and self._ws_client.open: + while self._ws_client and self._ws_client.open: + try: chat_request = await self.requests.get() await self._ws_client.send(chat_request.json()) - except asyncio.CancelledError: - # TODO: log - ... + except asyncio.CancelledError: + # TODO: log + break async def _receiver(self) -> None: - try: - while self._ws_client and self._ws_client.open: - try: - raw_chunk = await self._ws_client.recv() - chunk: StreamResponse = pydantic.parse_obj_as( - StreamResponse, - json.loads(raw_chunk), - ) - - if chunk_buffer := self._response_streams.get(chunk.id): - chunk_buffer.put_nowait(chunk) - continue - - self.response_chunks.put_nowait(chunk) - except pydantic.ValidationError as e: - raise GlideClientMismatch( - "Failed to validate Glide API response. " - "Please make sure Glide API and client versions are compatible" - ) from e - except asyncio.CancelledError: - ... + while self._ws_client and self._ws_client.open: + try: + raw_chunk = await self._ws_client.recv() + message: ChatStreamMessage = ChatStreamMessage(**json.loads(raw_chunk)) + + logger.debug("received chat stream message", extra={"message": message}) + + if msg_buffer := self._response_streams.get(message.id): + msg_buffer.put_nowait(message) + continue + + self.response_chunks.put_nowait(message) + except pydantic.ValidationError: + logger.error( + "Failed to validate Glide API response. " + "Please make sure Glide API and client versions are compatible", + exc_info=True, + ) + except asyncio.CancelledError: + break + except Exception as e: + logger.exception(e) async def stop(self) -> None: + # TODO: allow to timeout shutdown too if self._sender_task: self._sender_task.cancel() await self._sender_task diff --git a/src/glide/lang/schemas.py b/src/glide/lang/schemas.py index 93cbbc6..e32a632 100644 --- a/src/glide/lang/schemas.py +++ b/src/glide/lang/schemas.py @@ -3,7 +3,7 @@ import uuid from datetime import datetime from enum import Enum -from typing import List, Optional, Dict, Any, Union +from typing import List, Optional, Dict, Any from pydantic import Field @@ -18,7 +18,10 @@ class FinishReason(str, Enum): # generation is finished successfully without interruptions COMPLETE = "complete" # generation is interrupted because of the length of the response text - LENGTH = "length" + MAX_TOKENS = "max_tokens" + CONTENT_FILTERED = "content_filtered" + ERROR = "error" + OTHER = "other" class LangRouter(Schema): ... @@ -67,7 +70,7 @@ class ChatResponse(Schema): model_response: ModelResponse -class StreamChatRequest(Schema): +class ChatStreamRequest(Schema): id: ChatRequestId = Field(default_factory=lambda: str(uuid.uuid4())) message: ChatMessage message_history: List[ChatMessage] = Field(default_factory=list) @@ -78,7 +81,6 @@ class StreamChatRequest(Schema): class ModelChunkResponse(Schema): metadata: Optional[Metadata] = None message: ChatMessage - finish_reason: Optional[FinishReason] = None class ChatStreamChunk(Schema): @@ -86,23 +88,59 @@ class ChatStreamChunk(Schema): A response chunk of a streaming chat """ - id: ChatRequestId - # TODO: should be required, needs to fix on the Glide side - created: Optional[datetime] = None - provider: Optional[ProviderName] = None - router: Optional[RouterId] = None - model: Optional[ModelName] = None - model_id: str - metadata: Optional[Metadata] = None + + provider_name: ProviderName + model_name: ModelName + model_response: ModelChunkResponse + finish_reason: Optional[FinishReason] = None class ChatStreamError(Schema): id: ChatRequestId err_code: str message: str + finish_reason: Optional[FinishReason] = None + + +class ChatStreamMessage(Schema): + id: ChatRequestId + created_at: datetime metadata: Optional[Metadata] = None + router_id: RouterId + + chunk: Optional[ChatStreamChunk] = None + error: Optional[ChatStreamError] = None + + @property + def finish_reason(self) -> Optional[FinishReason]: + if self.chunk and self.chunk.finish_reason: + return self.chunk.finish_reason + + if self.error and self.error.finish_reason: + return self.error.finish_reason + + return None + + @property + def ended_with_err(self) -> Optional[ChatStreamError]: + if self.error and self.error.finish_reason: + return self.error + + return None + + @property + def content_chunk(self) -> Optional[str]: + """ + Returns received text generation chunk. + + Be careful with using this method to see if there is a chunk (rather than an error), + because content can be an empty string with some providers like OpenAI. + Better check for `self.chunk` in that case. + """ + if not self.chunk: + return None -StreamResponse = Union[ChatStreamChunk, ChatStreamError] + return self.chunk.model_response.message.content