diff --git a/portkey_ai/api_resources/apis/threads.py b/portkey_ai/api_resources/apis/threads.py index 5f82683..b1b46dd 100644 --- a/portkey_ai/api_resources/apis/threads.py +++ b/portkey_ai/api_resources/apis/threads.py @@ -1,7 +1,10 @@ import json -from typing import Any, Iterable, Literal, Optional, Union +from typing import Any, AsyncIterator, Iterable, Iterator, Literal, Optional, Union import typing +from portkey_ai._vendor.openai.types.beta.assistant_stream_event import ( + AssistantStreamEvent, +) from portkey_ai.api_resources.apis.api_resource import APIResource, AsyncAPIResource from portkey_ai.api_resources.client import AsyncPortkey, Portkey from portkey_ai.api_resources.types.thread_message_type import ( @@ -105,7 +108,25 @@ def delete( return data - def create_and_run(self, assistant_id, **kwargs) -> Run: + def stream_create_and_run( # type: ignore[return] + self, assistant_id, **kwargs + ) -> Union[Run, Iterator[AssistantStreamEvent]]: + with self.openai_client.with_streaming_response.beta.threads.create_and_run( + assistant_id=assistant_id, stream=True, extra_body=kwargs + ) as streaming: + for line in streaming.iter_lines(): + json_string = line.replace("data: ", "") + json_string = json_string.strip().rstrip("\n") + if json_string == "[DONE]": + break + elif json_string == "": + continue + elif json_string != "": + yield json_string + else: + return "" + + def normal_create_and_run(self, assistant_id, **kwargs) -> Run: response = self.openai_client.with_raw_response.beta.threads.create_and_run( assistant_id=assistant_id, extra_body=kwargs ) @@ -113,6 +134,14 @@ def create_and_run(self, assistant_id, **kwargs) -> Run: data._headers = response.headers return data + def create_and_run( + self, assistant_id, stream: Union[bool, NotGiven] = NOT_GIVEN, **kwargs + ) -> Union[Run, Iterator[AssistantStreamEvent]]: + if stream is True: + return self.stream_create_and_run(assistant_id, **kwargs) + else: + return self.normal_create_and_run(assistant_id, **kwargs) + def create_and_run_poll( self, *, @@ -307,15 +336,60 @@ def __init__(self, client: Portkey) -> None: self.openai_client = client.openai_client self.steps = Steps(client) - def create(self, thread_id: str, *, assistant_id: str, **kwargs) -> Run: + def stream_create( # type: ignore[return] + self, + thread_id, + assistant_id, + **kwargs, + ) -> Union[Run, Iterator[AssistantStreamEvent]]: + with self.openai_client.with_streaming_response.beta.threads.runs.create( + thread_id=thread_id, + assistant_id=assistant_id, + stream=True, + extra_body=kwargs, + ) as streaming: + for line in streaming.iter_lines(): + json_string = line.replace("data: ", "") + json_string = json_string.strip().rstrip("\n") + if json_string == "[DONE]": + break + elif json_string == "": + continue + elif json_string != "": + yield json_string + else: + return "" + + def normal_create( + self, + thread_id, + assistant_id, + **kwargs, + ) -> Run: response = self.openai_client.with_raw_response.beta.threads.runs.create( thread_id=thread_id, assistant_id=assistant_id, extra_body=kwargs ) data = Run(**json.loads(response.text)) data._headers = response.headers - return data + def create( + self, + thread_id: str, + *, + assistant_id: str, + stream: Union[bool, NotGiven] = NOT_GIVEN, + **kwargs, + ) -> Union[Run, Iterator[AssistantStreamEvent]]: + if stream is True: + return self.stream_create( + thread_id=thread_id, assistant_id=assistant_id, **kwargs + ) + else: + return self.normal_create( + thread_id=thread_id, assistant_id=assistant_id, **kwargs + ) + def retrieve(self, thread_id, run_id, **kwargs) -> Run: response = self.openai_client.with_raw_response.beta.threads.runs.retrieve( thread_id=thread_id, run_id=run_id, extra_body=kwargs @@ -681,7 +755,25 @@ async def delete( return data - async def create_and_run(self, assistant_id, **kwargs) -> Run: + async def stream_create_and_run( + self, assistant_id, **kwargs + ) -> Union[Run, AsyncIterator[AssistantStreamEvent]]: + async with self.openai_client.with_streaming_response.beta.threads.create_and_run( # noqa: E501 + assistant_id=assistant_id, stream=True, extra_body=kwargs + ) as streaming: + async for line in streaming.iter_lines(): + json_string = line.replace("data: ", "") + json_string = json_string.strip().rstrip("\n") + if json_string == "[DONE]": + break + elif json_string == "": + continue + elif json_string != "": + yield json_string + else: + pass + + async def normal_create_and_run(self, assistant_id, **kwargs) -> Run: response = ( await self.openai_client.with_raw_response.beta.threads.create_and_run( assistant_id=assistant_id, extra_body=kwargs @@ -691,6 +783,14 @@ async def create_and_run(self, assistant_id, **kwargs) -> Run: data._headers = response.headers return data + async def create_and_run( + self, assistant_id, stream: Union[bool, NotGiven] = NOT_GIVEN, **kwargs + ) -> Union[Run, AsyncIterator[AssistantStreamEvent]]: + if stream is True: + return self.stream_create_and_run(assistant_id=assistant_id, **kwargs) + else: + return await self.normal_create_and_run(assistant_id=assistant_id, **kwargs) + async def create_and_run_poll( self, *, @@ -897,15 +997,58 @@ def __init__(self, client: AsyncPortkey) -> None: self.openai_client = client.openai_client self.steps = AsyncSteps(client) - async def create(self, thread_id: str, *, assistant_id: str, **kwargs) -> Run: + async def stream_create( + self, + thread_id, + assistant_id, + **kwargs, + ) -> Union[Run, AsyncIterator[AssistantStreamEvent]]: + async with self.openai_client.with_streaming_response.beta.threads.runs.create( # noqa: E501 + thread_id=thread_id, + assistant_id=assistant_id, + stream=True, + extra_body=kwargs, + ) as response: + async for line in response.iter_lines(): + json_string = line.replace("data: ", "") + json_string = json_string.strip().rstrip("\n") + if json_string == "[DONE]": + break + elif json_string == "": + continue + elif json_string != "": + yield json_string + else: + pass + + async def normal_create( + self, + thread_id, + assistant_id, + **kwargs, + ) -> Run: response = await self.openai_client.with_raw_response.beta.threads.runs.create( - thread_id=thread_id, assistant_id=assistant_id, extra_body=kwargs + thread_id=thread_id, + assistant_id=assistant_id, + extra_body=kwargs, ) data = Run(**json.loads(response.text)) data._headers = response.headers - return data + async def create( + self, + thread_id: str, + *, + assistant_id: str, + stream: Union[bool, NotGiven] = NOT_GIVEN, + **kwargs, + ) -> Union[Run, AsyncIterator[AssistantStreamEvent]]: + if stream is True: + return self.stream_create(thread_id, assistant_id, **kwargs) + else: + return await self.normal_create(thread_id, assistant_id, **kwargs) + async def retrieve(self, thread_id, run_id, **kwargs) -> Run: response = ( await self.openai_client.with_raw_response.beta.threads.runs.retrieve(