diff --git a/providers/microsoft/azure/README.rst b/providers/microsoft/azure/README.rst index 6ed58b2bdda50..0efaa9aaef721 100644 --- a/providers/microsoft/azure/README.rst +++ b/providers/microsoft/azure/README.rst @@ -75,11 +75,11 @@ PIP package Version required ``azure-mgmt-datafactory`` ``>=2.0.0`` ``azure-mgmt-containerregistry`` ``>=8.0.0`` ``azure-mgmt-containerinstance`` ``>=10.1.0`` -``msgraph-core`` ``>=1.0.0,!=1.1.8`` -``microsoft-kiota-http`` ``>=1.3.0,!=1.3.4`` -``microsoft-kiota-serialization-json`` ``==1.0.0`` -``microsoft-kiota-serialization-text`` ``==1.0.0`` -``microsoft-kiota-abstractions`` ``>=1.0.0,<1.4.0`` +``msgraph-core`` ``>=1.3.3`` +``microsoft-kiota-http`` ``>=1.8.0,<2.0.0`` +``microsoft-kiota-serialization-json`` ``>=1.8.0`` +``microsoft-kiota-serialization-text`` ``>=1.8.0`` +``microsoft-kiota-abstractions`` ``>=1.8.0,<2.0.0`` ``msal-extensions`` ``>=1.1.0`` ====================================== =================== diff --git a/providers/microsoft/azure/pyproject.toml b/providers/microsoft/azure/pyproject.toml index 406170bba3e31..a4b9959e24e5f 100644 --- a/providers/microsoft/azure/pyproject.toml +++ b/providers/microsoft/azure/pyproject.toml @@ -82,15 +82,15 @@ dependencies = [ "azure-mgmt-containerinstance>=10.1.0", # msgraph-core 1.1.8 has a bug which causes ABCMeta object is not subscriptable error # See https://github.com/microsoftgraph/msgraph-sdk-python-core/issues/781 - "msgraph-core>=1.0.0,!=1.1.8", + "msgraph-core>=1.3.3", # msgraph-core has transient import failures with microsoft-kiota-http==1.3.4 # See https://github.com/microsoftgraph/msgraph-sdk-python-core/issues/706 - "microsoft-kiota-http>=1.3.0,!=1.3.4", - "microsoft-kiota-serialization-json==1.0.0", - "microsoft-kiota-serialization-text==1.0.0", + "microsoft-kiota-http>=1.8.0,<2.0.0", + "microsoft-kiota-serialization-json>=1.8.0", + "microsoft-kiota-serialization-text>=1.8.0", # microsoft-kiota-abstractions 1.4.0 breaks MyPy static checks on main # see https://github.com/apache/airflow/issues/43036 - "microsoft-kiota-abstractions<1.4.0,>=1.0.0", + "microsoft-kiota-abstractions>=1.8.0,<2.0.0", "msal-extensions>=1.1.0", ] diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py index 4bb0d38e449f5..7838747194b10 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/get_provider_info.py @@ -482,11 +482,11 @@ def get_provider_info(): "azure-mgmt-datafactory>=2.0.0", "azure-mgmt-containerregistry>=8.0.0", "azure-mgmt-containerinstance>=10.1.0", - "msgraph-core>=1.0.0,!=1.1.8", - "microsoft-kiota-http>=1.3.0,!=1.3.4", - "microsoft-kiota-serialization-json==1.0.0", - "microsoft-kiota-serialization-text==1.0.0", - "microsoft-kiota-abstractions<1.4.0,>=1.0.0", + "msgraph-core>=1.3.3", + "microsoft-kiota-http>=1.8.0,<2.0.0", + "microsoft-kiota-serialization-json>=1.8.0", + "microsoft-kiota-serialization-text>=1.8.0", + "microsoft-kiota-abstractions>=1.8.0,<2.0.0", "msal-extensions>=1.1.0", ], "optional-dependencies": { diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py index 45a3c8f2c6f95..e0007570d2067 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -23,12 +23,12 @@ from http import HTTPStatus from io import BytesIO from json import JSONDecodeError -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from urllib.parse import quote, urljoin, urlparse import httpx from azure.identity import CertificateCredential, ClientSecretCredential -from httpx import AsyncHTTPTransport, Timeout +from httpx import AsyncHTTPTransport, Response, Timeout from kiota_abstractions.api_error import APIError from kiota_abstractions.method import Method from kiota_abstractions.request_information import RequestInformation @@ -55,10 +55,8 @@ if TYPE_CHECKING: from azure.identity._internal.client_credential_base import ClientCredentialBase from kiota_abstractions.request_adapter import RequestAdapter - from kiota_abstractions.request_information import QueryParams from kiota_abstractions.response_handler import NativeResponseType from kiota_abstractions.serialization import ParsableFactory - from kiota_http.httpx_request_adapter import ResponseType from airflow.models import Connection @@ -67,7 +65,7 @@ class DefaultResponseHandler(ResponseHandler): """DefaultResponseHandler returns JSON payload or content in bytes or response headers.""" @staticmethod - def get_value(response: NativeResponseType) -> Any: + def get_value(response: Response) -> Any: with suppress(JSONDecodeError): return response.json() content = response.content @@ -76,7 +74,7 @@ def get_value(response: NativeResponseType) -> Any: return content async def handle_response_async( - self, response: NativeResponseType, error_map: dict[str, ParsableFactory | None] | None = None + self, response: NativeResponseType, error_map: dict[str, ParsableFactory] | None ) -> Any: """ Invoke this callback method when a response is received. @@ -84,10 +82,11 @@ async def handle_response_async( param response: The type of the native response object. param error_map: The error dict to use in case of a failed request. """ - value = self.get_value(response) - if response.status_code not in {200, 201, 202, 204, 302}: - message = value or response.reason_phrase - status_code = HTTPStatus(response.status_code) + resp: Response = cast("Response", response) + value = self.get_value(resp) + if resp.status_code not in {200, 201, 202, 204, 302}: + message = value or resp.reason_phrase + status_code = HTTPStatus(resp.status_code) if status_code == HTTPStatus.BAD_REQUEST: raise AirflowBadRequest(message) elif status_code == HTTPStatus.NOT_FOUND: @@ -391,16 +390,16 @@ def test_connection(self): async def run( self, url: str = "", - response_type: ResponseType | None = None, + response_type: str | None = None, path_parameters: dict[str, Any] | None = None, method: str = "GET", - query_parameters: dict[str, QueryParams] | None = None, + query_parameters: dict[str, Any] | None = None, headers: dict[str, str] | None = None, data: dict[str, Any] | str | BytesIO | None = None, ): self.log.info("Executing url '%s' as '%s'", url, method) - response = await self.get_conn().send_primitive_async( + response = await self.send_request( request_info=self.request_information( url=url, response_type=response_type, @@ -411,20 +410,31 @@ async def run( data=data, ), response_type=response_type, - error_map=self.error_mapping(), ) self.log.debug("response: %s", response) return response + async def send_request(self, request_info: RequestInformation, response_type: str | None = None): + if response_type: + return await self.get_conn().send_primitive_async( + request_info=request_info, + response_type=response_type, + error_map=self.error_mapping(), + ) + return await self.get_conn().send_no_response_content_async( + request_info=request_info, + error_map=self.error_mapping(), + ) + def request_information( self, url: str, - response_type: ResponseType | None = None, + response_type: str | None = None, path_parameters: dict[str, Any] | None = None, method: str = "GET", - query_parameters: dict[str, QueryParams] | None = None, + query_parameters: dict[str, Any] | None = None, headers: dict[str, str] | None = None, data: dict[str, Any] | str | BytesIO | None = None, ) -> RequestInformation: @@ -446,8 +456,12 @@ def request_information( headers = {**self.DEFAULT_HEADERS, **headers} if headers else self.DEFAULT_HEADERS for header_name, header_value in headers.items(): request_information.headers.try_add(header_name=header_name, header_value=header_value) - if isinstance(data, BytesIO) or isinstance(data, bytes) or isinstance(data, str): + if isinstance(data, BytesIO): + request_information.content = data.read() + elif isinstance(data, bytes): request_information.content = data + elif isinstance(data, str): + request_information.content = data.encode("utf-8") elif data: request_information.headers.try_add( header_name=RequestInformation.CONTENT_TYPE_HEADER, header_value="application/json" @@ -468,8 +482,8 @@ def encoded_query_parameters(query_parameters) -> dict: return {} @staticmethod - def error_mapping() -> dict[str, ParsableFactory | None]: + def error_mapping() -> dict[str, type[ParsableFactory]]: return { - "4XX": APIError, - "5XX": APIError, + "4XX": APIError, # type: ignore + "5XX": APIError, # type: ignore } diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py index 459683a69f4ae..e25fbc2f46314 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/operators/msgraph.py @@ -38,8 +38,6 @@ if TYPE_CHECKING: from io import BytesIO - from kiota_abstractions.request_adapter import ResponseType - from kiota_abstractions.request_information import QueryParams from msgraph_core import APIVersion from airflow.utils.context import Context @@ -118,11 +116,11 @@ def __init__( self, *, url: str, - response_type: ResponseType | None = None, + response_type: str | None = None, path_parameters: dict[str, Any] | None = None, url_template: str | None = None, method: str = "GET", - query_parameters: dict[str, QueryParams] | None = None, + query_parameters: dict[str, Any] | None = None, headers: dict[str, str] | None = None, data: dict[str, Any] | str | BytesIO | None = None, conn_id: str = KiotaRequestAdapterHook.default_conn_name, diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py index 39e728a9de1d7..dd1cc17f4c637 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/sensors/msgraph.py @@ -31,8 +31,6 @@ from datetime import timedelta from io import BytesIO - from kiota_abstractions.request_information import QueryParams - from kiota_http.httpx_request_adapter import ResponseType from msgraph_core import APIVersion from airflow.utils.context import Context @@ -76,11 +74,11 @@ class MSGraphSensor(BaseSensorOperator): def __init__( self, url: str, - response_type: ResponseType | None = None, + response_type: str | None = None, path_parameters: dict[str, Any] | None = None, url_template: str | None = None, method: str = "GET", - query_parameters: dict[str, QueryParams] | None = None, + query_parameters: dict[str, Any] | None = None, headers: dict[str, str] | None = None, data: dict[str, Any] | str | BytesIO | None = None, conn_id: str = KiotaRequestAdapterHook.default_conn_name, diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/msgraph.py index 4006ee6c3c0bb..177a0cc6aac28 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/msgraph.py @@ -40,8 +40,6 @@ from io import BytesIO from kiota_abstractions.request_adapter import RequestAdapter - from kiota_abstractions.request_information import QueryParams - from kiota_http.httpx_request_adapter import ResponseType from msgraph_core import APIVersion @@ -112,11 +110,11 @@ class MSGraphTrigger(BaseTrigger): def __init__( self, url: str, - response_type: ResponseType | None = None, + response_type: str | None = None, path_parameters: dict[str, Any] | None = None, url_template: str | None = None, method: str = "GET", - query_parameters: dict[str, QueryParams] | None = None, + query_parameters: dict[str, Any] | None = None, headers: dict[str, str] | None = None, data: dict[str, Any] | str | BytesIO | None = None, conn_id: str = KiotaRequestAdapterHook.default_conn_name, diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py index 9dc3a45452f81..6485b37ea64ea 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py @@ -20,7 +20,7 @@ import inspect from json import JSONDecodeError from os.path import dirname -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from unittest.mock import Mock, patch import pytest @@ -55,14 +55,28 @@ ) if TYPE_CHECKING: + from azure.identity._internal.msal_credentials import MsalCredential + from kiota_abstractions.authentication import BaseBearerTokenAuthenticationProvider from kiota_abstractions.request_adapter import RequestAdapter + from kiota_authentication_azure.azure_identity_access_token_provider import ( + AzureIdentityAccessTokenProvider, + ) class TestKiotaRequestAdapterHook: @staticmethod def assert_tenant_id(request_adapter: RequestAdapter, expected_tenant_id: str): - assert isinstance(request_adapter, HttpxRequestAdapter) - tenant_id = request_adapter._authentication_provider.access_token_provider._credentials._tenant_id + adapter: HttpxRequestAdapter = cast("HttpxRequestAdapter", request_adapter) + auth_provider: BaseBearerTokenAuthenticationProvider = cast( + "BaseBearerTokenAuthenticationProvider", + adapter._authentication_provider, + ) + access_token_provider: AzureIdentityAccessTokenProvider = cast( + "AzureIdentityAccessTokenProvider", + auth_provider.access_token_provider, + ) + credentials: MsalCredential = cast("MsalCredential", access_token_provider._credentials) + tenant_id = credentials._tenant_id assert tenant_id == expected_tenant_id def test_get_conn(self):