7272logger = logging .getLogger (__name__ )
7373
7474
75+ StreamEventData = Union [MessageDeltaChunk , ThreadMessage , ThreadRun , RunStep , None ]
76+
7577def _filter_parameters (model_class : Type , parameters : Dict [str , Any ]) -> Dict [str , Any ]:
7678 """
7779 Remove the parameters, non present in class public fields; return shallow copy of a dictionary.
@@ -94,7 +96,7 @@ def _filter_parameters(model_class: Type, parameters: Dict[str, Any]) -> Dict[st
9496 return new_params
9597
9698
97- def _safe_instantiate (model_class : Type , parameters : Union [str , Dict [str , Any ]]) -> Any :
99+ def _safe_instantiate (model_class : Type , parameters : Union [str , Dict [str , Any ]]) -> Union [ str , StreamEventData ] :
98100 """
99101 Instantiate class with the set of parameters from the server.
100102
@@ -104,7 +106,7 @@ def _safe_instantiate(model_class: Type, parameters: Union[str, Dict[str, Any]])
104106 """
105107 if not isinstance (parameters , dict ):
106108 return parameters
107- return model_class (** _filter_parameters (model_class , parameters ))
109+ return cast ( StreamEventData , model_class (** _filter_parameters (model_class , parameters ) ))
108110
109111
110112class ConnectionProperties :
@@ -928,10 +930,7 @@ async def on_unhandled_event(self, event_type: str, event_data: Any) -> None:
928930 """Handle any unhandled event types."""
929931
930932
931- StreamEventData = Union [MessageDeltaChunk , ThreadMessage , ThreadRun , RunStep , None ]
932-
933-
934- class AsyncAgentRunStream (AsyncIterator [Tuple [str , StreamEventData ]]):
933+ class AsyncAgentRunStream (AsyncIterator [Tuple [str , Union [str , StreamEventData ]]]):
935934 def __init__ (
936935 self ,
937936 response_iterator : AsyncIterator [bytes ],
@@ -957,7 +956,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
957956 def __aiter__ (self ):
958957 return self
959958
960- async def __anext__ (self ) -> Tuple [str , StreamEventData ]:
959+ async def __anext__ (self ) -> Tuple [str , Union [ str , StreamEventData ] ]:
961960 while True :
962961 try :
963962 chunk = await self .response_iterator .__anext__ ()
@@ -973,9 +972,9 @@ async def __anext__(self) -> Tuple[str, StreamEventData]:
973972 event_data_str , self .buffer = self .buffer .split ("\n \n " , 1 )
974973 return await self ._process_event (event_data_str )
975974
976- def _parse_event_data (self , event_data_str : str ) -> Tuple [str , StreamEventData , str ]:
975+ def _parse_event_data (self , event_data_str : str ) -> Tuple [str , Union [ str , StreamEventData ] , str ]:
977976 event_lines = event_data_str .strip ().split ("\n " )
978- event_type = None
977+ event_type : Optional [ str ] = None
979978 event_data = ""
980979 error_string = ""
981980
@@ -1001,44 +1000,44 @@ def _parse_event_data(self, event_data_str: str) -> Tuple[str, StreamEventData,
10011000
10021001 # Map to the appropriate class instance
10031002 if event_type in {
1004- AgentStreamEvent .THREAD_RUN_CREATED ,
1005- AgentStreamEvent .THREAD_RUN_QUEUED ,
1006- AgentStreamEvent .THREAD_RUN_IN_PROGRESS ,
1007- AgentStreamEvent .THREAD_RUN_REQUIRES_ACTION ,
1008- AgentStreamEvent .THREAD_RUN_COMPLETED ,
1009- AgentStreamEvent .THREAD_RUN_FAILED ,
1010- AgentStreamEvent .THREAD_RUN_CANCELLING ,
1011- AgentStreamEvent .THREAD_RUN_CANCELLED ,
1012- AgentStreamEvent .THREAD_RUN_EXPIRED ,
1003+ AgentStreamEvent .THREAD_RUN_CREATED . value ,
1004+ AgentStreamEvent .THREAD_RUN_QUEUED . value ,
1005+ AgentStreamEvent .THREAD_RUN_IN_PROGRESS . value ,
1006+ AgentStreamEvent .THREAD_RUN_REQUIRES_ACTION . value ,
1007+ AgentStreamEvent .THREAD_RUN_COMPLETED . value ,
1008+ AgentStreamEvent .THREAD_RUN_FAILED . value ,
1009+ AgentStreamEvent .THREAD_RUN_CANCELLING . value ,
1010+ AgentStreamEvent .THREAD_RUN_CANCELLED . value ,
1011+ AgentStreamEvent .THREAD_RUN_EXPIRED . value ,
10131012 }:
10141013 event_data_obj = _safe_instantiate (ThreadRun , parsed_data )
10151014 elif event_type in {
1016- AgentStreamEvent .THREAD_RUN_STEP_CREATED ,
1017- AgentStreamEvent .THREAD_RUN_STEP_IN_PROGRESS ,
1018- AgentStreamEvent .THREAD_RUN_STEP_COMPLETED ,
1019- AgentStreamEvent .THREAD_RUN_STEP_FAILED ,
1020- AgentStreamEvent .THREAD_RUN_STEP_CANCELLED ,
1021- AgentStreamEvent .THREAD_RUN_STEP_EXPIRED ,
1015+ AgentStreamEvent .THREAD_RUN_STEP_CREATED . value ,
1016+ AgentStreamEvent .THREAD_RUN_STEP_IN_PROGRESS . value ,
1017+ AgentStreamEvent .THREAD_RUN_STEP_COMPLETED . value ,
1018+ AgentStreamEvent .THREAD_RUN_STEP_FAILED . value ,
1019+ AgentStreamEvent .THREAD_RUN_STEP_CANCELLED . value ,
1020+ AgentStreamEvent .THREAD_RUN_STEP_EXPIRED . value ,
10221021 }:
10231022 event_data_obj = _safe_instantiate (RunStep , parsed_data )
10241023 elif event_type in {
1025- AgentStreamEvent .THREAD_MESSAGE_CREATED ,
1026- AgentStreamEvent .THREAD_MESSAGE_IN_PROGRESS ,
1027- AgentStreamEvent .THREAD_MESSAGE_COMPLETED ,
1028- AgentStreamEvent .THREAD_MESSAGE_INCOMPLETE ,
1024+ AgentStreamEvent .THREAD_MESSAGE_CREATED . value ,
1025+ AgentStreamEvent .THREAD_MESSAGE_IN_PROGRESS . value ,
1026+ AgentStreamEvent .THREAD_MESSAGE_COMPLETED . value ,
1027+ AgentStreamEvent .THREAD_MESSAGE_INCOMPLETE . value ,
10291028 }:
10301029 event_data_obj = _safe_instantiate (ThreadMessage , parsed_data )
1031- elif event_type == AgentStreamEvent .THREAD_MESSAGE_DELTA :
1030+ elif event_type == AgentStreamEvent .THREAD_MESSAGE_DELTA . value :
10321031 event_data_obj = _safe_instantiate (MessageDeltaChunk , parsed_data )
1033- elif event_type == AgentStreamEvent .THREAD_RUN_STEP_DELTA :
1032+ elif event_type == AgentStreamEvent .THREAD_RUN_STEP_DELTA . value :
10341033 event_data_obj = _safe_instantiate (RunStepDeltaChunk , parsed_data )
10351034 else :
1036- event_data_obj = parsed_data
1035+ event_data_obj = ""
10371036 error_string = str (parsed_data )
10381037
10391038 return event_type , event_data_obj , error_string
10401039
1041- async def _process_event (self , event_data_str : str ) -> Tuple [str , StreamEventData ]:
1040+ async def _process_event (self , event_data_str : str ) -> Tuple [str , Union [ str , StreamEventData ] ]:
10421041 event_type , event_data_obj , error_string = self ._parse_event_data (event_data_str )
10431042
10441043 if (
@@ -1082,7 +1081,7 @@ async def until_done(self) -> None:
10821081 pass
10831082
10841083
1085- class AgentRunStream (Iterator [Tuple [str , StreamEventData ]]):
1084+ class AgentRunStream (Iterator [Tuple [str , Union [ str , StreamEventData ] ]]):
10861085 def __init__ (
10871086 self ,
10881087 response_iterator : Iterator [bytes ],
@@ -1106,7 +1105,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
11061105 def __iter__ (self ):
11071106 return self
11081107
1109- def __next__ (self ) -> Tuple [str , StreamEventData ]:
1108+ def __next__ (self ) -> Tuple [str , Union [ str , StreamEventData ] ]:
11101109 if self .done :
11111110 raise StopIteration
11121111 while True :
@@ -1124,7 +1123,7 @@ def __next__(self) -> Tuple[str, StreamEventData]:
11241123 event_data_str , self .buffer = self .buffer .split ("\n \n " , 1 )
11251124 return self ._process_event (event_data_str )
11261125
1127- def _parse_event_data (self , event_data_str : str ) -> Tuple [str , StreamEventData , str ]:
1126+ def _parse_event_data (self , event_data_str : str ) -> Tuple [str , Union [ str , StreamEventData ] , str ]:
11281127 event_lines = event_data_str .strip ().split ("\n " )
11291128 event_type = None
11301129 event_data = ""
@@ -1150,44 +1149,44 @@ def _parse_event_data(self, event_data_str: str) -> Tuple[str, StreamEventData,
11501149
11511150 # Map to the appropriate class instance
11521151 if event_type in {
1153- AgentStreamEvent .THREAD_RUN_CREATED ,
1154- AgentStreamEvent .THREAD_RUN_QUEUED ,
1155- AgentStreamEvent .THREAD_RUN_IN_PROGRESS ,
1156- AgentStreamEvent .THREAD_RUN_REQUIRES_ACTION ,
1157- AgentStreamEvent .THREAD_RUN_COMPLETED ,
1158- AgentStreamEvent .THREAD_RUN_FAILED ,
1159- AgentStreamEvent .THREAD_RUN_CANCELLING ,
1160- AgentStreamEvent .THREAD_RUN_CANCELLED ,
1161- AgentStreamEvent .THREAD_RUN_EXPIRED ,
1152+ AgentStreamEvent .THREAD_RUN_CREATED . value ,
1153+ AgentStreamEvent .THREAD_RUN_QUEUED . value ,
1154+ AgentStreamEvent .THREAD_RUN_IN_PROGRESS . value ,
1155+ AgentStreamEvent .THREAD_RUN_REQUIRES_ACTION . value ,
1156+ AgentStreamEvent .THREAD_RUN_COMPLETED . value ,
1157+ AgentStreamEvent .THREAD_RUN_FAILED . value ,
1158+ AgentStreamEvent .THREAD_RUN_CANCELLING . value ,
1159+ AgentStreamEvent .THREAD_RUN_CANCELLED . value ,
1160+ AgentStreamEvent .THREAD_RUN_EXPIRED . value ,
11621161 }:
11631162 event_data_obj = _safe_instantiate (ThreadRun , parsed_data )
11641163 elif event_type in {
1165- AgentStreamEvent .THREAD_RUN_STEP_CREATED ,
1166- AgentStreamEvent .THREAD_RUN_STEP_IN_PROGRESS ,
1167- AgentStreamEvent .THREAD_RUN_STEP_COMPLETED ,
1168- AgentStreamEvent .THREAD_RUN_STEP_FAILED ,
1169- AgentStreamEvent .THREAD_RUN_STEP_CANCELLED ,
1170- AgentStreamEvent .THREAD_RUN_STEP_EXPIRED ,
1164+ AgentStreamEvent .THREAD_RUN_STEP_CREATED . value ,
1165+ AgentStreamEvent .THREAD_RUN_STEP_IN_PROGRESS . value ,
1166+ AgentStreamEvent .THREAD_RUN_STEP_COMPLETED . value ,
1167+ AgentStreamEvent .THREAD_RUN_STEP_FAILED . value ,
1168+ AgentStreamEvent .THREAD_RUN_STEP_CANCELLED . value ,
1169+ AgentStreamEvent .THREAD_RUN_STEP_EXPIRED . value ,
11711170 }:
11721171 event_data_obj = _safe_instantiate (RunStep , parsed_data )
11731172 elif event_type in {
1174- AgentStreamEvent .THREAD_MESSAGE_CREATED ,
1175- AgentStreamEvent .THREAD_MESSAGE_IN_PROGRESS ,
1176- AgentStreamEvent .THREAD_MESSAGE_COMPLETED ,
1177- AgentStreamEvent .THREAD_MESSAGE_INCOMPLETE ,
1173+ AgentStreamEvent .THREAD_MESSAGE_CREATED . value ,
1174+ AgentStreamEvent .THREAD_MESSAGE_IN_PROGRESS . value ,
1175+ AgentStreamEvent .THREAD_MESSAGE_COMPLETED . value ,
1176+ AgentStreamEvent .THREAD_MESSAGE_INCOMPLETE . value ,
11781177 }:
11791178 event_data_obj = _safe_instantiate (ThreadMessage , parsed_data )
1180- elif event_type == AgentStreamEvent .THREAD_MESSAGE_DELTA :
1179+ elif event_type == AgentStreamEvent .THREAD_MESSAGE_DELTA . value :
11811180 event_data_obj = _safe_instantiate (MessageDeltaChunk , parsed_data )
1182- elif event_type == AgentStreamEvent .THREAD_RUN_STEP_DELTA :
1181+ elif event_type == AgentStreamEvent .THREAD_RUN_STEP_DELTA . value :
11831182 event_data_obj = _safe_instantiate (RunStepDeltaChunk , parsed_data )
11841183 else :
1185- event_data_obj = parsed_data
1184+ event_data_obj = ""
11861185 error_string = str (parsed_data )
11871186
11881187 return event_type , event_data_obj , error_string
11891188
1190- def _process_event (self , event_data_str : str ) -> Tuple [str , StreamEventData ]:
1189+ def _process_event (self , event_data_str : str ) -> Tuple [str , Union [ str , StreamEventData ] ]:
11911190 event_type , event_data_obj , error_string = self ._parse_event_data (event_data_str )
11921191
11931192 if (
0 commit comments