Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions providers/microsoft/azure/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ PIP package Version required
``azure-mgmt-datafactory`` ``>=2.0.0``
``azure-mgmt-containerregistry`` ``>=8.0.0``
``azure-mgmt-containerinstance`` ``>=10.1.0``
``msgraph-core`` ``>=1.0.0,!=1.1.8``
``microsoft-kiota-http`` ``>=1.3.0,!=1.3.4``
``microsoft-kiota-serialization-json`` ``==1.0.0``
``microsoft-kiota-serialization-text`` ``==1.0.0``
``microsoft-kiota-abstractions`` ``>=1.0.0,<1.4.0``
``msgraph-core`` ``>=1.3.3``
``microsoft-kiota-http`` ``>=1.8.0,<2.0.0``
``microsoft-kiota-serialization-json`` ``>=1.8.0``
``microsoft-kiota-serialization-text`` ``>=1.8.0``
``microsoft-kiota-abstractions`` ``>=1.8.0,<2.0.0``
``msal-extensions`` ``>=1.1.0``
====================================== ===================

Expand Down
10 changes: 5 additions & 5 deletions providers/microsoft/azure/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ dependencies = [
"azure-mgmt-containerinstance>=10.1.0",
# msgraph-core 1.1.8 has a bug which causes ABCMeta object is not subscriptable error
# See https://github.com/microsoftgraph/msgraph-sdk-python-core/issues/781
"msgraph-core>=1.0.0,!=1.1.8",
"msgraph-core>=1.3.3",
# msgraph-core has transient import failures with microsoft-kiota-http==1.3.4
# See https://github.com/microsoftgraph/msgraph-sdk-python-core/issues/706
"microsoft-kiota-http>=1.3.0,!=1.3.4",
"microsoft-kiota-serialization-json==1.0.0",
"microsoft-kiota-serialization-text==1.0.0",
"microsoft-kiota-http>=1.8.0,<2.0.0",
"microsoft-kiota-serialization-json>=1.8.0",
"microsoft-kiota-serialization-text>=1.8.0",
# microsoft-kiota-abstractions 1.4.0 breaks MyPy static checks on main
# see https://github.com/apache/airflow/issues/43036
"microsoft-kiota-abstractions<1.4.0,>=1.0.0",
"microsoft-kiota-abstractions>=1.8.0,<2.0.0",
"msal-extensions>=1.1.0",
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -482,11 +482,11 @@ def get_provider_info():
"azure-mgmt-datafactory>=2.0.0",
"azure-mgmt-containerregistry>=8.0.0",
"azure-mgmt-containerinstance>=10.1.0",
"msgraph-core>=1.0.0,!=1.1.8",
"microsoft-kiota-http>=1.3.0,!=1.3.4",
"microsoft-kiota-serialization-json==1.0.0",
"microsoft-kiota-serialization-text==1.0.0",
"microsoft-kiota-abstractions<1.4.0,>=1.0.0",
"msgraph-core>=1.3.3",
"microsoft-kiota-http>=1.8.0,<2.0.0",
"microsoft-kiota-serialization-json>=1.8.0",
"microsoft-kiota-serialization-text>=1.8.0",
"microsoft-kiota-abstractions>=1.8.0,<2.0.0",
"msal-extensions>=1.1.0",
],
"optional-dependencies": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from http import HTTPStatus
from io import BytesIO
from json import JSONDecodeError
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
from urllib.parse import quote, urljoin, urlparse

import httpx
from azure.identity import CertificateCredential, ClientSecretCredential
from httpx import AsyncHTTPTransport, Timeout
from httpx import AsyncHTTPTransport, Response, Timeout
from kiota_abstractions.api_error import APIError
from kiota_abstractions.method import Method
from kiota_abstractions.request_information import RequestInformation
Expand All @@ -55,10 +55,8 @@
if TYPE_CHECKING:
from azure.identity._internal.client_credential_base import ClientCredentialBase
from kiota_abstractions.request_adapter import RequestAdapter
from kiota_abstractions.request_information import QueryParams
from kiota_abstractions.response_handler import NativeResponseType
from kiota_abstractions.serialization import ParsableFactory
from kiota_http.httpx_request_adapter import ResponseType

from airflow.models import Connection

Expand All @@ -67,7 +65,7 @@ class DefaultResponseHandler(ResponseHandler):
"""DefaultResponseHandler returns JSON payload or content in bytes or response headers."""

@staticmethod
def get_value(response: NativeResponseType) -> Any:
def get_value(response: Response) -> Any:
with suppress(JSONDecodeError):
return response.json()
content = response.content
Expand All @@ -76,18 +74,19 @@ def get_value(response: NativeResponseType) -> Any:
return content

async def handle_response_async(
self, response: NativeResponseType, error_map: dict[str, ParsableFactory | None] | None = None
self, response: NativeResponseType, error_map: dict[str, ParsableFactory] | None
) -> Any:
"""
Invoke this callback method when a response is received.

param response: The type of the native response object.
param error_map: The error dict to use in case of a failed request.
"""
value = self.get_value(response)
if response.status_code not in {200, 201, 202, 204, 302}:
message = value or response.reason_phrase
status_code = HTTPStatus(response.status_code)
resp: Response = cast("Response", response)
value = self.get_value(resp)
if resp.status_code not in {200, 201, 202, 204, 302}:
message = value or resp.reason_phrase
status_code = HTTPStatus(resp.status_code)
if status_code == HTTPStatus.BAD_REQUEST:
raise AirflowBadRequest(message)
elif status_code == HTTPStatus.NOT_FOUND:
Expand Down Expand Up @@ -391,16 +390,16 @@ def test_connection(self):
async def run(
self,
url: str = "",
response_type: ResponseType | None = None,
response_type: str | None = None,
path_parameters: dict[str, Any] | None = None,
method: str = "GET",
query_parameters: dict[str, QueryParams] | None = None,
query_parameters: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
data: dict[str, Any] | str | BytesIO | None = None,
):
self.log.info("Executing url '%s' as '%s'", url, method)

response = await self.get_conn().send_primitive_async(
response = await self.send_request(
request_info=self.request_information(
url=url,
response_type=response_type,
Expand All @@ -411,20 +410,31 @@ async def run(
data=data,
),
response_type=response_type,
error_map=self.error_mapping(),
)

self.log.debug("response: %s", response)

return response

async def send_request(self, request_info: RequestInformation, response_type: str | None = None):
if response_type:
return await self.get_conn().send_primitive_async(
request_info=request_info,
response_type=response_type,
error_map=self.error_mapping(),
)
return await self.get_conn().send_no_response_content_async(
request_info=request_info,
error_map=self.error_mapping(),
)

def request_information(
self,
url: str,
response_type: ResponseType | None = None,
response_type: str | None = None,
path_parameters: dict[str, Any] | None = None,
method: str = "GET",
query_parameters: dict[str, QueryParams] | None = None,
query_parameters: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
data: dict[str, Any] | str | BytesIO | None = None,
) -> RequestInformation:
Expand All @@ -446,8 +456,12 @@ def request_information(
headers = {**self.DEFAULT_HEADERS, **headers} if headers else self.DEFAULT_HEADERS
for header_name, header_value in headers.items():
request_information.headers.try_add(header_name=header_name, header_value=header_value)
if isinstance(data, BytesIO) or isinstance(data, bytes) or isinstance(data, str):
if isinstance(data, BytesIO):
request_information.content = data.read()
elif isinstance(data, bytes):
request_information.content = data
elif isinstance(data, str):
request_information.content = data.encode("utf-8")
elif data:
request_information.headers.try_add(
header_name=RequestInformation.CONTENT_TYPE_HEADER, header_value="application/json"
Expand All @@ -468,8 +482,8 @@ def encoded_query_parameters(query_parameters) -> dict:
return {}

@staticmethod
def error_mapping() -> dict[str, ParsableFactory | None]:
def error_mapping() -> dict[str, type[ParsableFactory]]:
return {
"4XX": APIError,
"5XX": APIError,
"4XX": APIError, # type: ignore
"5XX": APIError, # type: ignore
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
if TYPE_CHECKING:
from io import BytesIO

from kiota_abstractions.request_adapter import ResponseType
from kiota_abstractions.request_information import QueryParams
from msgraph_core import APIVersion

from airflow.utils.context import Context
Expand Down Expand Up @@ -118,11 +116,11 @@ def __init__(
self,
*,
url: str,
response_type: ResponseType | None = None,
response_type: str | None = None,
path_parameters: dict[str, Any] | None = None,
url_template: str | None = None,
method: str = "GET",
query_parameters: dict[str, QueryParams] | None = None,
query_parameters: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
data: dict[str, Any] | str | BytesIO | None = None,
conn_id: str = KiotaRequestAdapterHook.default_conn_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
from datetime import timedelta
from io import BytesIO

from kiota_abstractions.request_information import QueryParams
from kiota_http.httpx_request_adapter import ResponseType
from msgraph_core import APIVersion

from airflow.utils.context import Context
Expand Down Expand Up @@ -76,11 +74,11 @@ class MSGraphSensor(BaseSensorOperator):
def __init__(
self,
url: str,
response_type: ResponseType | None = None,
response_type: str | None = None,
path_parameters: dict[str, Any] | None = None,
url_template: str | None = None,
method: str = "GET",
query_parameters: dict[str, QueryParams] | None = None,
query_parameters: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
data: dict[str, Any] | str | BytesIO | None = None,
conn_id: str = KiotaRequestAdapterHook.default_conn_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@
from io import BytesIO

from kiota_abstractions.request_adapter import RequestAdapter
from kiota_abstractions.request_information import QueryParams
from kiota_http.httpx_request_adapter import ResponseType
from msgraph_core import APIVersion


Expand Down Expand Up @@ -112,11 +110,11 @@ class MSGraphTrigger(BaseTrigger):
def __init__(
self,
url: str,
response_type: ResponseType | None = None,
response_type: str | None = None,
path_parameters: dict[str, Any] | None = None,
url_template: str | None = None,
method: str = "GET",
query_parameters: dict[str, QueryParams] | None = None,
query_parameters: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
data: dict[str, Any] | str | BytesIO | None = None,
conn_id: str = KiotaRequestAdapterHook.default_conn_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import inspect
from json import JSONDecodeError
from os.path import dirname
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast
from unittest.mock import Mock, patch

import pytest
Expand Down Expand Up @@ -55,14 +55,28 @@
)

if TYPE_CHECKING:
from azure.identity._internal.msal_credentials import MsalCredential
from kiota_abstractions.authentication import BaseBearerTokenAuthenticationProvider
from kiota_abstractions.request_adapter import RequestAdapter
from kiota_authentication_azure.azure_identity_access_token_provider import (
AzureIdentityAccessTokenProvider,
)


class TestKiotaRequestAdapterHook:
@staticmethod
def assert_tenant_id(request_adapter: RequestAdapter, expected_tenant_id: str):
assert isinstance(request_adapter, HttpxRequestAdapter)
tenant_id = request_adapter._authentication_provider.access_token_provider._credentials._tenant_id
adapter: HttpxRequestAdapter = cast("HttpxRequestAdapter", request_adapter)
auth_provider: BaseBearerTokenAuthenticationProvider = cast(
"BaseBearerTokenAuthenticationProvider",
adapter._authentication_provider,
)
access_token_provider: AzureIdentityAccessTokenProvider = cast(
"AzureIdentityAccessTokenProvider",
auth_provider.access_token_provider,
)
credentials: MsalCredential = cast("MsalCredential", access_token_provider._credentials)
tenant_id = credentials._tenant_id
assert tenant_id == expected_tenant_id

def test_get_conn(self):
Expand Down