Skip to content

Commit 7a66f7e

Browse files
authored
Fix pyright for agents. (#38464)
* Fix as if python version is 3.8 * Fixed * Add custom config * Move line, ignoring import
1 parent d920d44 commit 7a66f7e

File tree

5 files changed

+80
-64
lines changed

5 files changed

+80
-64
lines changed

sdk/ai/azure-ai-projects/azure/ai/projects/models/_patch.py

Lines changed: 57 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@
7272
logger = logging.getLogger(__name__)
7373

7474

75+
StreamEventData = Union[MessageDeltaChunk, ThreadMessage, ThreadRun, RunStep, None]
76+
7577
def _filter_parameters(model_class: Type, parameters: Dict[str, Any]) -> Dict[str, Any]:
7678
"""
7779
Remove the parameters, non present in class public fields; return shallow copy of a dictionary.
@@ -94,7 +96,7 @@ def _filter_parameters(model_class: Type, parameters: Dict[str, Any]) -> Dict[st
9496
return new_params
9597

9698

97-
def _safe_instantiate(model_class: Type, parameters: Union[str, Dict[str, Any]]) -> Any:
99+
def _safe_instantiate(model_class: Type, parameters: Union[str, Dict[str, Any]]) -> Union[str, StreamEventData]:
98100
"""
99101
Instantiate class with the set of parameters from the server.
100102
@@ -104,7 +106,7 @@ def _safe_instantiate(model_class: Type, parameters: Union[str, Dict[str, Any]])
104106
"""
105107
if not isinstance(parameters, dict):
106108
return parameters
107-
return model_class(**_filter_parameters(model_class, parameters))
109+
return cast(StreamEventData, model_class(**_filter_parameters(model_class, parameters)))
108110

109111

110112
class ConnectionProperties:
@@ -928,10 +930,7 @@ async def on_unhandled_event(self, event_type: str, event_data: Any) -> None:
928930
"""Handle any unhandled event types."""
929931

930932

931-
StreamEventData = Union[MessageDeltaChunk, ThreadMessage, ThreadRun, RunStep, None]
932-
933-
934-
class AsyncAgentRunStream(AsyncIterator[Tuple[str, StreamEventData]]):
933+
class AsyncAgentRunStream(AsyncIterator[Tuple[str, Union[str, StreamEventData]]]):
935934
def __init__(
936935
self,
937936
response_iterator: AsyncIterator[bytes],
@@ -957,7 +956,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
957956
def __aiter__(self):
958957
return self
959958

960-
async def __anext__(self) -> Tuple[str, StreamEventData]:
959+
async def __anext__(self) -> Tuple[str, Union[str, StreamEventData]]:
961960
while True:
962961
try:
963962
chunk = await self.response_iterator.__anext__()
@@ -973,9 +972,9 @@ async def __anext__(self) -> Tuple[str, StreamEventData]:
973972
event_data_str, self.buffer = self.buffer.split("\n\n", 1)
974973
return await self._process_event(event_data_str)
975974

976-
def _parse_event_data(self, event_data_str: str) -> Tuple[str, StreamEventData, str]:
975+
def _parse_event_data(self, event_data_str: str) -> Tuple[str, Union[str, StreamEventData], str]:
977976
event_lines = event_data_str.strip().split("\n")
978-
event_type = None
977+
event_type: Optional[str] = None
979978
event_data = ""
980979
error_string = ""
981980

@@ -1001,44 +1000,44 @@ def _parse_event_data(self, event_data_str: str) -> Tuple[str, StreamEventData,
10011000

10021001
# Map to the appropriate class instance
10031002
if event_type in {
1004-
AgentStreamEvent.THREAD_RUN_CREATED,
1005-
AgentStreamEvent.THREAD_RUN_QUEUED,
1006-
AgentStreamEvent.THREAD_RUN_IN_PROGRESS,
1007-
AgentStreamEvent.THREAD_RUN_REQUIRES_ACTION,
1008-
AgentStreamEvent.THREAD_RUN_COMPLETED,
1009-
AgentStreamEvent.THREAD_RUN_FAILED,
1010-
AgentStreamEvent.THREAD_RUN_CANCELLING,
1011-
AgentStreamEvent.THREAD_RUN_CANCELLED,
1012-
AgentStreamEvent.THREAD_RUN_EXPIRED,
1003+
AgentStreamEvent.THREAD_RUN_CREATED.value,
1004+
AgentStreamEvent.THREAD_RUN_QUEUED.value,
1005+
AgentStreamEvent.THREAD_RUN_IN_PROGRESS.value,
1006+
AgentStreamEvent.THREAD_RUN_REQUIRES_ACTION.value,
1007+
AgentStreamEvent.THREAD_RUN_COMPLETED.value,
1008+
AgentStreamEvent.THREAD_RUN_FAILED.value,
1009+
AgentStreamEvent.THREAD_RUN_CANCELLING.value,
1010+
AgentStreamEvent.THREAD_RUN_CANCELLED.value,
1011+
AgentStreamEvent.THREAD_RUN_EXPIRED.value,
10131012
}:
10141013
event_data_obj = _safe_instantiate(ThreadRun, parsed_data)
10151014
elif event_type in {
1016-
AgentStreamEvent.THREAD_RUN_STEP_CREATED,
1017-
AgentStreamEvent.THREAD_RUN_STEP_IN_PROGRESS,
1018-
AgentStreamEvent.THREAD_RUN_STEP_COMPLETED,
1019-
AgentStreamEvent.THREAD_RUN_STEP_FAILED,
1020-
AgentStreamEvent.THREAD_RUN_STEP_CANCELLED,
1021-
AgentStreamEvent.THREAD_RUN_STEP_EXPIRED,
1015+
AgentStreamEvent.THREAD_RUN_STEP_CREATED.value,
1016+
AgentStreamEvent.THREAD_RUN_STEP_IN_PROGRESS.value,
1017+
AgentStreamEvent.THREAD_RUN_STEP_COMPLETED.value,
1018+
AgentStreamEvent.THREAD_RUN_STEP_FAILED.value,
1019+
AgentStreamEvent.THREAD_RUN_STEP_CANCELLED.value,
1020+
AgentStreamEvent.THREAD_RUN_STEP_EXPIRED.value,
10221021
}:
10231022
event_data_obj = _safe_instantiate(RunStep, parsed_data)
10241023
elif event_type in {
1025-
AgentStreamEvent.THREAD_MESSAGE_CREATED,
1026-
AgentStreamEvent.THREAD_MESSAGE_IN_PROGRESS,
1027-
AgentStreamEvent.THREAD_MESSAGE_COMPLETED,
1028-
AgentStreamEvent.THREAD_MESSAGE_INCOMPLETE,
1024+
AgentStreamEvent.THREAD_MESSAGE_CREATED.value,
1025+
AgentStreamEvent.THREAD_MESSAGE_IN_PROGRESS.value,
1026+
AgentStreamEvent.THREAD_MESSAGE_COMPLETED.value,
1027+
AgentStreamEvent.THREAD_MESSAGE_INCOMPLETE.value,
10291028
}:
10301029
event_data_obj = _safe_instantiate(ThreadMessage, parsed_data)
1031-
elif event_type == AgentStreamEvent.THREAD_MESSAGE_DELTA:
1030+
elif event_type == AgentStreamEvent.THREAD_MESSAGE_DELTA.value:
10321031
event_data_obj = _safe_instantiate(MessageDeltaChunk, parsed_data)
1033-
elif event_type == AgentStreamEvent.THREAD_RUN_STEP_DELTA:
1032+
elif event_type == AgentStreamEvent.THREAD_RUN_STEP_DELTA.value:
10341033
event_data_obj = _safe_instantiate(RunStepDeltaChunk, parsed_data)
10351034
else:
1036-
event_data_obj = parsed_data
1035+
event_data_obj = ""
10371036
error_string = str(parsed_data)
10381037

10391038
return event_type, event_data_obj, error_string
10401039

1041-
async def _process_event(self, event_data_str: str) -> Tuple[str, StreamEventData]:
1040+
async def _process_event(self, event_data_str: str) -> Tuple[str, Union[str, StreamEventData]]:
10421041
event_type, event_data_obj, error_string = self._parse_event_data(event_data_str)
10431042

10441043
if (
@@ -1082,7 +1081,7 @@ async def until_done(self) -> None:
10821081
pass
10831082

10841083

1085-
class AgentRunStream(Iterator[Tuple[str, StreamEventData]]):
1084+
class AgentRunStream(Iterator[Tuple[str, Union[str, StreamEventData]]]):
10861085
def __init__(
10871086
self,
10881087
response_iterator: Iterator[bytes],
@@ -1106,7 +1105,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
11061105
def __iter__(self):
11071106
return self
11081107

1109-
def __next__(self) -> Tuple[str, StreamEventData]:
1108+
def __next__(self) -> Tuple[str, Union[str, StreamEventData]]:
11101109
if self.done:
11111110
raise StopIteration
11121111
while True:
@@ -1124,7 +1123,7 @@ def __next__(self) -> Tuple[str, StreamEventData]:
11241123
event_data_str, self.buffer = self.buffer.split("\n\n", 1)
11251124
return self._process_event(event_data_str)
11261125

1127-
def _parse_event_data(self, event_data_str: str) -> Tuple[str, StreamEventData, str]:
1126+
def _parse_event_data(self, event_data_str: str) -> Tuple[str, Union[str, StreamEventData], str]:
11281127
event_lines = event_data_str.strip().split("\n")
11291128
event_type = None
11301129
event_data = ""
@@ -1150,44 +1149,44 @@ def _parse_event_data(self, event_data_str: str) -> Tuple[str, StreamEventData,
11501149

11511150
# Map to the appropriate class instance
11521151
if event_type in {
1153-
AgentStreamEvent.THREAD_RUN_CREATED,
1154-
AgentStreamEvent.THREAD_RUN_QUEUED,
1155-
AgentStreamEvent.THREAD_RUN_IN_PROGRESS,
1156-
AgentStreamEvent.THREAD_RUN_REQUIRES_ACTION,
1157-
AgentStreamEvent.THREAD_RUN_COMPLETED,
1158-
AgentStreamEvent.THREAD_RUN_FAILED,
1159-
AgentStreamEvent.THREAD_RUN_CANCELLING,
1160-
AgentStreamEvent.THREAD_RUN_CANCELLED,
1161-
AgentStreamEvent.THREAD_RUN_EXPIRED,
1152+
AgentStreamEvent.THREAD_RUN_CREATED.value,
1153+
AgentStreamEvent.THREAD_RUN_QUEUED.value,
1154+
AgentStreamEvent.THREAD_RUN_IN_PROGRESS.value,
1155+
AgentStreamEvent.THREAD_RUN_REQUIRES_ACTION.value,
1156+
AgentStreamEvent.THREAD_RUN_COMPLETED.value,
1157+
AgentStreamEvent.THREAD_RUN_FAILED.value,
1158+
AgentStreamEvent.THREAD_RUN_CANCELLING.value,
1159+
AgentStreamEvent.THREAD_RUN_CANCELLED.value,
1160+
AgentStreamEvent.THREAD_RUN_EXPIRED.value,
11621161
}:
11631162
event_data_obj = _safe_instantiate(ThreadRun, parsed_data)
11641163
elif event_type in {
1165-
AgentStreamEvent.THREAD_RUN_STEP_CREATED,
1166-
AgentStreamEvent.THREAD_RUN_STEP_IN_PROGRESS,
1167-
AgentStreamEvent.THREAD_RUN_STEP_COMPLETED,
1168-
AgentStreamEvent.THREAD_RUN_STEP_FAILED,
1169-
AgentStreamEvent.THREAD_RUN_STEP_CANCELLED,
1170-
AgentStreamEvent.THREAD_RUN_STEP_EXPIRED,
1164+
AgentStreamEvent.THREAD_RUN_STEP_CREATED.value,
1165+
AgentStreamEvent.THREAD_RUN_STEP_IN_PROGRESS.value,
1166+
AgentStreamEvent.THREAD_RUN_STEP_COMPLETED.value,
1167+
AgentStreamEvent.THREAD_RUN_STEP_FAILED.value,
1168+
AgentStreamEvent.THREAD_RUN_STEP_CANCELLED.value,
1169+
AgentStreamEvent.THREAD_RUN_STEP_EXPIRED.value,
11711170
}:
11721171
event_data_obj = _safe_instantiate(RunStep, parsed_data)
11731172
elif event_type in {
1174-
AgentStreamEvent.THREAD_MESSAGE_CREATED,
1175-
AgentStreamEvent.THREAD_MESSAGE_IN_PROGRESS,
1176-
AgentStreamEvent.THREAD_MESSAGE_COMPLETED,
1177-
AgentStreamEvent.THREAD_MESSAGE_INCOMPLETE,
1173+
AgentStreamEvent.THREAD_MESSAGE_CREATED.value,
1174+
AgentStreamEvent.THREAD_MESSAGE_IN_PROGRESS.value,
1175+
AgentStreamEvent.THREAD_MESSAGE_COMPLETED.value,
1176+
AgentStreamEvent.THREAD_MESSAGE_INCOMPLETE.value,
11781177
}:
11791178
event_data_obj = _safe_instantiate(ThreadMessage, parsed_data)
1180-
elif event_type == AgentStreamEvent.THREAD_MESSAGE_DELTA:
1179+
elif event_type == AgentStreamEvent.THREAD_MESSAGE_DELTA.value:
11811180
event_data_obj = _safe_instantiate(MessageDeltaChunk, parsed_data)
1182-
elif event_type == AgentStreamEvent.THREAD_RUN_STEP_DELTA:
1181+
elif event_type == AgentStreamEvent.THREAD_RUN_STEP_DELTA.value:
11831182
event_data_obj = _safe_instantiate(RunStepDeltaChunk, parsed_data)
11841183
else:
1185-
event_data_obj = parsed_data
1184+
event_data_obj = ""
11861185
error_string = str(parsed_data)
11871186

11881187
return event_type, event_data_obj, error_string
11891188

1190-
def _process_event(self, event_data_str: str) -> Tuple[str, StreamEventData]:
1189+
def _process_event(self, event_data_str: str) -> Tuple[str, Union[str, StreamEventData]]:
11911190
event_type, event_data_obj, error_string = self._parse_event_data(event_data_str)
11921191

11931192
if (

sdk/ai/azure-ai-projects/azure/ai/projects/operations/_patch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,14 +418,14 @@ def _get_log_exporter(destination: Union[TextIO, str, None]) -> Any:
418418
# See: https://opentelemetry-python.readthedocs.io/en/latest/sdk/trace.export.html#opentelemetry.sdk.trace.export.ConsoleSpanExporter
419419
try:
420420
from opentelemetry.sdk._logs.export import ConsoleLogExporter
421+
return ConsoleLogExporter()
421422
except ModuleNotFoundError as ex:
422423
# since OTel logging is still in beta in Python, we're going to swallow any errors
423424
# and just warn about them.
424425
logger.warning(
425426
"Failed to configure OpenTelemetry logging.", exc_info=ex
426427
)
427-
428-
return ConsoleLogExporter()
428+
return None
429429
else:
430430
raise ValueError("Only `sys.stdout` is supported at the moment for type `TextIO`")
431431

sdk/ai/azure-ai-projects/azure/ai/projects/telemetry/agents/_ai_agents_instrumentor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
# pylint: disable = no-name-in-module
6464
from opentelemetry.trace import Span, StatusCode
6565

66-
from azure.core.tracing import AbstractSpan, SpanKind # type: ignore
66+
from azure.core.tracing import AbstractSpan # type: ignore
6767

6868
_tracing_library_available = True
6969
except ModuleNotFoundError:
@@ -1643,8 +1643,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
16431643

16441644
if self.last_run and self.last_run.last_error:
16451645
self.span.set_status(
1646-
StatusCode.ERROR, self.last_run.last_error.message
1647-
) # pyright: ignore [reportPossiblyUnboundVariable]
1646+
StatusCode.ERROR, # pyright: ignore [reportPossiblyUnboundVariable]
1647+
self.last_run.last_error.message
1648+
)
16481649
self.span.add_attribute(ERROR_TYPE, self.last_run.last_error.code)
16491650

16501651
self.span.__exit__(exc_type, exc_val, exc_tb)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"reportTypeCommentUsage": true,
3+
"reportMissingImports": false,
4+
"pythonVersion": "3.11",
5+
"exclude": [
6+
"**/downloaded",
7+
"**/sample_agents_vector_store_batch_enterprise_file_search_async.py",
8+
"**/sample_agents_with_file_search_attachment.py",
9+
"**/sample_agents_with_code_interpreter_file_attachment.py",
10+
"**/sample_agents_code_interpreter_attachment_enterprise_search.py",
11+
"**/sample_agents_with_file_search_attachment_async.py",
12+
"**/sample_agents_code_interpreter_attachment_enterprise_search_async.py",
13+
"**/sample_agents_code_interpreter_attachment_enterprise_search_async.py",
14+
"**/sample_agents_code_interpreter_attachment_async.py"
15+
]
16+
}

sdk/ai/azure-ai-projects/samples/agents/async_samples/sample_agents_vector_store_batch_enterprise_file_search_async.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ async def main():
9292
)
9393
print(f"Created run, run ID: {run.id}")
9494

95-
await file_search_tool.remove_vector_store(vector_store.id)
95+
file_search_tool.remove_vector_store(vector_store.id)
9696
print(
9797
f"Removed vector store from file search, vector store ID: {vector_store.id}"
9898
)

0 commit comments

Comments
 (0)