Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
ce7e430
refactor: Made get_conn method of KiotaRequestAdapterHook async to fi…
dabla Aug 18, 2025
b83e34d
Merge branch 'main' into feature/msgraph-async-connection
dabla Aug 18, 2025
aa45e8f
refactor: Fixed static checks
dabla Aug 18, 2025
181dc56
Merge branch 'main' into feature/msgraph-async-connection
dabla Aug 18, 2025
f8ab73e
refactor: Re-added missing import HttpxRequestAdapter
dabla Aug 18, 2025
fea6885
refactor: Reformatted imports
dabla Aug 18, 2025
926de00
Merge branch 'main' into feature/msgraph-async-connection
dabla Aug 18, 2025
e350586
refactor: Fixed imports
dabla Aug 18, 2025
58f7e8a
refactor: Added missing AirflowProviderDeprecationWarning
dabla Aug 18, 2025
100f96e
Merge branch 'main' into feature/msgraph-async-connection
dabla Aug 18, 2025
3016455
refactor: Fixed test_serialize
dabla Aug 18, 2025
318abc5
refactor: Fixed test_throw_failed_responses_with_text_plain_content_t…
dabla Aug 18, 2025
89c59cb
Merge branch 'main' into feature/msgraph-async-connection
dabla Aug 19, 2025
b5e57b0
refactor: Renamed async get_conn method to get_async_conn in KiotaReq…
dabla Aug 19, 2025
01d1c96
refactor: Introduced async get_async_conn and keep original synced ge…
dabla Aug 19, 2025
1c60d9c
refactor: Introduced async get_async_conn and keep original synced ge…
dabla Aug 19, 2025
038adb9
refactor: Refactor PowerBI triggers like MSGraph
dabla Aug 19, 2025
f715566
refactor: Generate class names dynamically in PowerBI triggers
dabla Aug 19, 2025
9c1ee6c
refactor: Fixed some mypy issues in PowerBITrigger
dabla Aug 19, 2025
ffe503b
refactor: Fixed test_serialize
dabla Aug 19, 2025
92ef9c3
refactor: Reformatted test_get_conn
dabla Aug 19, 2025
06ecea7
refactor: Added docstring to BasePowerBITrigger
dabla Aug 19, 2025
ad3b63e
refactor: Replaced DeprecationWarning with AirflowProviderDeprecation…
dabla Aug 19, 2025
b1a434f
refactor: Fixed patching of 2 remaining tests
dabla Aug 19, 2025
715df1e
Merge branch 'main' into feature/msgraph-async-connection
dabla Aug 19, 2025
13f7026
refactor: Re-added deprecated get_conn methods in triggers
dabla Aug 20, 2025
da6ee26
refactor: Fixed static checks
dabla Aug 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import json
import warnings
from ast import literal_eval
from contextlib import suppress
from http import HTTPStatus
Expand All @@ -27,6 +28,7 @@
from urllib.parse import quote, urljoin, urlparse

import httpx
from asgiref.sync import sync_to_async
from azure.identity import CertificateCredential, ClientSecretCredential
from httpx import AsyncHTTPTransport, Response, Timeout
from kiota_abstractions.api_error import APIError
Expand All @@ -49,6 +51,7 @@
AirflowConfigException,
AirflowException,
AirflowNotFoundException,
AirflowProviderDeprecationWarning,
)
from airflow.providers.microsoft.azure.version_compat import BaseHook

Expand Down Expand Up @@ -114,7 +117,7 @@ class KiotaRequestAdapterHook(BaseHook):

DEFAULT_HEADERS = {"Accept": "application/json;q=1"}
DEFAULT_SCOPE = "https://graph.microsoft.com/.default"
cached_request_adapters: dict[str, tuple[APIVersion, RequestAdapter]] = {}
cached_request_adapters: dict[str, tuple[str, RequestAdapter]] = {}
conn_type: str = "msgraph"
conn_name_attr: str = "conn_id"
default_conn_name: str = "msgraph_default"
Expand All @@ -138,7 +141,7 @@ def __init__(
self.scopes = [scopes]
else:
self.scopes = scopes or [self.DEFAULT_SCOPE]
self._api_version = self.resolve_api_version_from_value(api_version)
self.api_version = self.resolve_api_version_from_value(api_version)

@classmethod
def get_connection_form_widgets(cls) -> dict[str, Any]:
Expand Down Expand Up @@ -186,11 +189,6 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
},
}

@property
def api_version(self) -> str | None:
self.get_conn() # Make sure config has been loaded through get_conn to have correct api version!
return self._api_version

@staticmethod
def resolve_api_version_from_value(
api_version: APIVersion | str, default: str | None = None
Expand All @@ -200,7 +198,7 @@ def resolve_api_version_from_value(
return api_version or default

def get_api_version(self, config: dict) -> str:
return self._api_version or self.resolve_api_version_from_value(
return self.api_version or self.resolve_api_version_from_value(
config.get("api_version"), APIVersion.v1.value
) # type: ignore

Expand All @@ -209,6 +207,13 @@ def get_host(self, connection: Connection) -> str:
return f"{connection.schema}://{connection.host}"
return self.host

def get_base_url(self, host: str, api_version: str, config: dict) -> str:
base_url = config.get("base_url", urljoin(host, api_version)).strip()

if not base_url.endswith("/"):
return f"{base_url}/"
return base_url

@staticmethod
def format_no_proxy_url(url: str) -> str:
if "://" not in url:
Expand Down Expand Up @@ -242,83 +247,112 @@ def to_msal_proxies(self, authority: str | None, proxies: dict) -> dict | None:
return None
return proxies

def _build_request_adapter(self, connection) -> tuple[str, RequestAdapter]:
client_id = connection.login
client_secret = connection.password
config = connection.extra_dejson if connection.extra else {}
api_version = self.get_api_version(config)
host = self.get_host(connection) # type: ignore[arg-type]
base_url = self.get_base_url(host, api_version, config)
authority = config.get("authority")
proxies = self.get_proxies(config)
httpx_proxies = self.to_httpx_proxies(proxies=proxies)
scopes = config.get("scopes", self.scopes)
if isinstance(scopes, str):
scopes = scopes.split(",")
verify = config.get("verify", True)
trust_env = config.get("trust_env", False)
allowed_hosts = (config.get("allowed_hosts", authority) or "").split(",")

self.log.info(
"Creating Microsoft Graph SDK client %s for conn_id: %s",
api_version,
self.conn_id,
)
self.log.info("Host: %s", host)
self.log.info("Base URL: %s", base_url)
self.log.info("Client id: %s", client_id)
self.log.info("Client secret: %s", client_secret)
self.log.info("API version: %s", api_version)
self.log.info("Scope: %s", scopes)
self.log.info("Verify: %s", verify)
self.log.info("Timeout: %s", self.timeout)
self.log.info("Trust env: %s", trust_env)
self.log.info("Authority: %s", authority)
self.log.info("Allowed hosts: %s", allowed_hosts)
self.log.info("Proxies: %s", proxies)
self.log.info("HTTPX Proxies: %s", httpx_proxies)
credentials = self.get_credentials(
login=connection.login,
password=connection.password,
config=config,
authority=authority,
verify=verify,
proxies=proxies,
)
http_client = GraphClientFactory.create_with_default_middleware(
api_version=api_version,
client=httpx.AsyncClient(
mounts=httpx_proxies,
timeout=Timeout(timeout=self.timeout),
verify=verify,
trust_env=trust_env,
base_url=base_url,
),
host=host,
)
auth_provider = AzureIdentityAuthenticationProvider(
credentials=credentials,
scopes=scopes,
allowed_hosts=allowed_hosts,
)
parse_node_factory = ParseNodeFactoryRegistry()
parse_node_factory.CONTENT_TYPE_ASSOCIATED_FACTORIES["text/plain"] = TextParseNodeFactory()
parse_node_factory.CONTENT_TYPE_ASSOCIATED_FACTORIES["application/json"] = JsonParseNodeFactory()
request_adapter = HttpxRequestAdapter(
authentication_provider=auth_provider,
parse_node_factory=parse_node_factory,
http_client=http_client,
base_url=base_url,
)
self.cached_request_adapters[self.conn_id] = (api_version, request_adapter)
return api_version, request_adapter

def get_conn(self) -> RequestAdapter:
"""
Initiate a new RequestAdapter connection.

.. warning::
This method is deprecated.
"""
if not self.conn_id:
raise AirflowException("Failed to create the KiotaRequestAdapterHook. No conn_id provided!")

warnings.warn(
"get_conn is deprecated, please use the async get_async_conn method!",
category=AirflowProviderDeprecationWarning,
stacklevel=2,
)

api_version, request_adapter = self.cached_request_adapters.get(self.conn_id, (None, None))

if not request_adapter:
connection = self.get_connection(conn_id=self.conn_id)
client_id = connection.login
client_secret = connection.password
config = connection.extra_dejson if connection.extra else {}
api_version = self.get_api_version(config)
host = self.get_host(connection) # type: ignore[arg-type]
base_url = config.get("base_url", urljoin(host, api_version))
authority = config.get("authority")
proxies = self.get_proxies(config)
httpx_proxies = self.to_httpx_proxies(proxies=proxies)
scopes = config.get("scopes", self.scopes)
if isinstance(scopes, str):
scopes = scopes.split(",")
verify = config.get("verify", True)
trust_env = config.get("trust_env", False)
allowed_hosts = (config.get("allowed_hosts", authority) or "").split(",")

self.log.info(
"Creating Microsoft Graph SDK client %s for conn_id: %s",
api_version,
self.conn_id,
)
self.log.info("Host: %s", host)
self.log.info("Base URL: %s", base_url)
self.log.info("Client id: %s", client_id)
self.log.info("Client secret: %s", client_secret)
self.log.info("API version: %s", api_version)
self.log.info("Scope: %s", scopes)
self.log.info("Verify: %s", verify)
self.log.info("Timeout: %s", self.timeout)
self.log.info("Trust env: %s", trust_env)
self.log.info("Authority: %s", authority)
self.log.info("Allowed hosts: %s", allowed_hosts)
self.log.info("Proxies: %s", proxies)
self.log.info("HTTPX Proxies: %s", httpx_proxies)
credentials = self.get_credentials(
login=connection.login,
password=connection.password,
config=config,
authority=authority,
verify=verify,
proxies=proxies,
)
http_client = GraphClientFactory.create_with_default_middleware(
api_version=api_version,
client=httpx.AsyncClient(
mounts=httpx_proxies,
timeout=Timeout(timeout=self.timeout),
verify=verify,
trust_env=trust_env,
base_url=base_url,
),
host=host,
)
auth_provider = AzureIdentityAuthenticationProvider(
credentials=credentials,
scopes=scopes,
allowed_hosts=allowed_hosts,
)
parse_node_factory = ParseNodeFactoryRegistry()
parse_node_factory.CONTENT_TYPE_ASSOCIATED_FACTORIES["text/plain"] = TextParseNodeFactory()
parse_node_factory.CONTENT_TYPE_ASSOCIATED_FACTORIES["application/json"] = JsonParseNodeFactory()
request_adapter = HttpxRequestAdapter(
authentication_provider=auth_provider,
parse_node_factory=parse_node_factory,
http_client=http_client,
base_url=base_url,
)
self.cached_request_adapters[self.conn_id] = (api_version, request_adapter)
self._api_version = api_version
api_version, request_adapter = self._build_request_adapter(connection)
self.api_version = api_version
return request_adapter

async def get_async_conn(self) -> RequestAdapter:
"""Initiate a new RequestAdapter connection asynchronously."""
if not self.conn_id:
raise AirflowException("Failed to create the KiotaRequestAdapterHook. No conn_id provided!")

api_version, request_adapter = self.cached_request_adapters.get(self.conn_id, (None, None))

if not request_adapter:
connection = await sync_to_async(self.get_connection)(conn_id=self.conn_id)
api_version, request_adapter = self._build_request_adapter(connection)
self.api_version = api_version
return request_adapter

def get_proxies(self, config: dict) -> dict:
Expand Down Expand Up @@ -418,13 +452,15 @@ async def run(
return response

async def send_request(self, request_info: RequestInformation, response_type: str | None = None):
conn = await self.get_async_conn()

if response_type:
return await self.get_conn().send_primitive_async(
return await conn.send_primitive_async(
request_info=request_info,
response_type=response_type,
error_map=self.error_mapping(),
)
return await self.get_conn().send_no_response_content_async(
return await conn.send_no_response_content_async(
request_info=request_info,
error_map=self.error_mapping(),
)
Expand Down Expand Up @@ -468,7 +504,7 @@ def request_information(
header_name=RequestInformation.CONTENT_TYPE_HEADER, header_value="application/json"
)
request_information.content = json.dumps(data).encode("utf-8")
print("Request Information:", request_information.url)
self.log.debug("Request Information: %s", request_information.url)
return request_information

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from collections.abc import AsyncIterator, Sequence
from contextlib import suppress
from datetime import datetime
from functools import cached_property
from json import JSONDecodeError
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -125,13 +126,11 @@ def __init__(
serializer: type[ResponseSerializer] = ResponseSerializer,
):
super().__init__()
self.hook = KiotaRequestAdapterHook(
conn_id=conn_id,
timeout=timeout,
proxies=proxies,
scopes=scopes,
api_version=api_version,
)
self.conn_id = conn_id
self.timeout = timeout
self.proxies = proxies
self.scopes = scopes
self.api_version = api_version
self.url = url
self.response_type = response_type
self.path_parameters = path_parameters
Expand All @@ -158,7 +157,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"conn_id": self.conn_id,
"timeout": self.timeout,
"proxies": self.proxies,
"scopes": self.hook.scopes,
"scopes": self.scopes,
"api_version": self.api_version,
"serializer": f"{self.serializer.__class__.__module__}.{self.serializer.__class__.__name__}",
"url": self.url,
Expand All @@ -173,23 +172,23 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
)

def get_conn(self) -> RequestAdapter:
return self.hook.get_conn()

@property
def conn_id(self) -> str:
return self.hook.conn_id
"""
Initiate a new RequestAdapter connection.

@property
def timeout(self) -> float | None:
return self.hook.timeout

@property
def proxies(self) -> dict | None:
return self.hook.proxies
.. warning::
This method is deprecated.
"""
return self.hook.get_conn()

@property
def api_version(self) -> APIVersion | str:
return self.hook.api_version
@cached_property
def hook(self) -> KiotaRequestAdapterHook:
return KiotaRequestAdapterHook(
conn_id=self.conn_id,
timeout=self.timeout,
proxies=self.proxies,
scopes=self.scopes,
api_version=self.api_version,
)

async def run(self) -> AsyncIterator[TriggerEvent]:
"""Make a series of asynchronous HTTP calls via a KiotaRequestAdapterHook."""
Expand Down
Loading
Loading