diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py index f142ea23624e1..958a3b8f437e7 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -18,6 +18,7 @@ from __future__ import annotations import json +import warnings from ast import literal_eval from contextlib import suppress from http import HTTPStatus @@ -27,6 +28,7 @@ from urllib.parse import quote, urljoin, urlparse import httpx +from asgiref.sync import sync_to_async from azure.identity import CertificateCredential, ClientSecretCredential from httpx import AsyncHTTPTransport, Response, Timeout from kiota_abstractions.api_error import APIError @@ -49,6 +51,7 @@ AirflowConfigException, AirflowException, AirflowNotFoundException, + AirflowProviderDeprecationWarning, ) from airflow.providers.microsoft.azure.version_compat import BaseHook @@ -114,7 +117,7 @@ class KiotaRequestAdapterHook(BaseHook): DEFAULT_HEADERS = {"Accept": "application/json;q=1"} DEFAULT_SCOPE = "https://graph.microsoft.com/.default" - cached_request_adapters: dict[str, tuple[APIVersion, RequestAdapter]] = {} + cached_request_adapters: dict[str, tuple[str, RequestAdapter]] = {} conn_type: str = "msgraph" conn_name_attr: str = "conn_id" default_conn_name: str = "msgraph_default" @@ -138,7 +141,7 @@ def __init__( self.scopes = [scopes] else: self.scopes = scopes or [self.DEFAULT_SCOPE] - self._api_version = self.resolve_api_version_from_value(api_version) + self.api_version = self.resolve_api_version_from_value(api_version) @classmethod def get_connection_form_widgets(cls) -> dict[str, Any]: @@ -186,11 +189,6 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: }, } - @property - def api_version(self) -> str | None: - self.get_conn() # Make sure config has been loaded through get_conn to have correct api version! - return self._api_version - @staticmethod def resolve_api_version_from_value( api_version: APIVersion | str, default: str | None = None @@ -200,7 +198,7 @@ def resolve_api_version_from_value( return api_version or default def get_api_version(self, config: dict) -> str: - return self._api_version or self.resolve_api_version_from_value( + return self.api_version or self.resolve_api_version_from_value( config.get("api_version"), APIVersion.v1.value ) # type: ignore @@ -209,6 +207,13 @@ def get_host(self, connection: Connection) -> str: return f"{connection.schema}://{connection.host}" return self.host + def get_base_url(self, host: str, api_version: str, config: dict) -> str: + base_url = config.get("base_url", urljoin(host, api_version)).strip() + + if not base_url.endswith("/"): + return f"{base_url}/" + return base_url + @staticmethod def format_no_proxy_url(url: str) -> str: if "://" not in url: @@ -242,83 +247,112 @@ def to_msal_proxies(self, authority: str | None, proxies: dict) -> dict | None: return None return proxies + def _build_request_adapter(self, connection) -> tuple[str, RequestAdapter]: + client_id = connection.login + client_secret = connection.password + config = connection.extra_dejson if connection.extra else {} + api_version = self.get_api_version(config) + host = self.get_host(connection) # type: ignore[arg-type] + base_url = self.get_base_url(host, api_version, config) + authority = config.get("authority") + proxies = self.get_proxies(config) + httpx_proxies = self.to_httpx_proxies(proxies=proxies) + scopes = config.get("scopes", self.scopes) + if isinstance(scopes, str): + scopes = scopes.split(",") + verify = config.get("verify", True) + trust_env = config.get("trust_env", False) + allowed_hosts = (config.get("allowed_hosts", authority) or "").split(",") + + self.log.info( + "Creating Microsoft Graph SDK client %s for conn_id: %s", + api_version, + self.conn_id, + ) + self.log.info("Host: %s", host) + self.log.info("Base URL: %s", base_url) + self.log.info("Client id: %s", client_id) + self.log.info("Client secret: %s", client_secret) + self.log.info("API version: %s", api_version) + self.log.info("Scope: %s", scopes) + self.log.info("Verify: %s", verify) + self.log.info("Timeout: %s", self.timeout) + self.log.info("Trust env: %s", trust_env) + self.log.info("Authority: %s", authority) + self.log.info("Allowed hosts: %s", allowed_hosts) + self.log.info("Proxies: %s", proxies) + self.log.info("HTTPX Proxies: %s", httpx_proxies) + credentials = self.get_credentials( + login=connection.login, + password=connection.password, + config=config, + authority=authority, + verify=verify, + proxies=proxies, + ) + http_client = GraphClientFactory.create_with_default_middleware( + api_version=api_version, + client=httpx.AsyncClient( + mounts=httpx_proxies, + timeout=Timeout(timeout=self.timeout), + verify=verify, + trust_env=trust_env, + base_url=base_url, + ), + host=host, + ) + auth_provider = AzureIdentityAuthenticationProvider( + credentials=credentials, + scopes=scopes, + allowed_hosts=allowed_hosts, + ) + parse_node_factory = ParseNodeFactoryRegistry() + parse_node_factory.CONTENT_TYPE_ASSOCIATED_FACTORIES["text/plain"] = TextParseNodeFactory() + parse_node_factory.CONTENT_TYPE_ASSOCIATED_FACTORIES["application/json"] = JsonParseNodeFactory() + request_adapter = HttpxRequestAdapter( + authentication_provider=auth_provider, + parse_node_factory=parse_node_factory, + http_client=http_client, + base_url=base_url, + ) + self.cached_request_adapters[self.conn_id] = (api_version, request_adapter) + return api_version, request_adapter + def get_conn(self) -> RequestAdapter: + """ + Initiate a new RequestAdapter connection. + + .. warning:: + This method is deprecated. + """ if not self.conn_id: raise AirflowException("Failed to create the KiotaRequestAdapterHook. No conn_id provided!") + warnings.warn( + "get_conn is deprecated, please use the async get_async_conn method!", + category=AirflowProviderDeprecationWarning, + stacklevel=2, + ) + api_version, request_adapter = self.cached_request_adapters.get(self.conn_id, (None, None)) if not request_adapter: connection = self.get_connection(conn_id=self.conn_id) - client_id = connection.login - client_secret = connection.password - config = connection.extra_dejson if connection.extra else {} - api_version = self.get_api_version(config) - host = self.get_host(connection) # type: ignore[arg-type] - base_url = config.get("base_url", urljoin(host, api_version)) - authority = config.get("authority") - proxies = self.get_proxies(config) - httpx_proxies = self.to_httpx_proxies(proxies=proxies) - scopes = config.get("scopes", self.scopes) - if isinstance(scopes, str): - scopes = scopes.split(",") - verify = config.get("verify", True) - trust_env = config.get("trust_env", False) - allowed_hosts = (config.get("allowed_hosts", authority) or "").split(",") - - self.log.info( - "Creating Microsoft Graph SDK client %s for conn_id: %s", - api_version, - self.conn_id, - ) - self.log.info("Host: %s", host) - self.log.info("Base URL: %s", base_url) - self.log.info("Client id: %s", client_id) - self.log.info("Client secret: %s", client_secret) - self.log.info("API version: %s", api_version) - self.log.info("Scope: %s", scopes) - self.log.info("Verify: %s", verify) - self.log.info("Timeout: %s", self.timeout) - self.log.info("Trust env: %s", trust_env) - self.log.info("Authority: %s", authority) - self.log.info("Allowed hosts: %s", allowed_hosts) - self.log.info("Proxies: %s", proxies) - self.log.info("HTTPX Proxies: %s", httpx_proxies) - credentials = self.get_credentials( - login=connection.login, - password=connection.password, - config=config, - authority=authority, - verify=verify, - proxies=proxies, - ) - http_client = GraphClientFactory.create_with_default_middleware( - api_version=api_version, - client=httpx.AsyncClient( - mounts=httpx_proxies, - timeout=Timeout(timeout=self.timeout), - verify=verify, - trust_env=trust_env, - base_url=base_url, - ), - host=host, - ) - auth_provider = AzureIdentityAuthenticationProvider( - credentials=credentials, - scopes=scopes, - allowed_hosts=allowed_hosts, - ) - parse_node_factory = ParseNodeFactoryRegistry() - parse_node_factory.CONTENT_TYPE_ASSOCIATED_FACTORIES["text/plain"] = TextParseNodeFactory() - parse_node_factory.CONTENT_TYPE_ASSOCIATED_FACTORIES["application/json"] = JsonParseNodeFactory() - request_adapter = HttpxRequestAdapter( - authentication_provider=auth_provider, - parse_node_factory=parse_node_factory, - http_client=http_client, - base_url=base_url, - ) - self.cached_request_adapters[self.conn_id] = (api_version, request_adapter) - self._api_version = api_version + api_version, request_adapter = self._build_request_adapter(connection) + self.api_version = api_version + return request_adapter + + async def get_async_conn(self) -> RequestAdapter: + """Initiate a new RequestAdapter connection asynchronously.""" + if not self.conn_id: + raise AirflowException("Failed to create the KiotaRequestAdapterHook. No conn_id provided!") + + api_version, request_adapter = self.cached_request_adapters.get(self.conn_id, (None, None)) + + if not request_adapter: + connection = await sync_to_async(self.get_connection)(conn_id=self.conn_id) + api_version, request_adapter = self._build_request_adapter(connection) + self.api_version = api_version return request_adapter def get_proxies(self, config: dict) -> dict: @@ -418,13 +452,15 @@ async def run( return response async def send_request(self, request_info: RequestInformation, response_type: str | None = None): + conn = await self.get_async_conn() + if response_type: - return await self.get_conn().send_primitive_async( + return await conn.send_primitive_async( request_info=request_info, response_type=response_type, error_map=self.error_mapping(), ) - return await self.get_conn().send_no_response_content_async( + return await conn.send_no_response_content_async( request_info=request_info, error_map=self.error_mapping(), ) @@ -468,7 +504,7 @@ def request_information( header_name=RequestInformation.CONTENT_TYPE_HEADER, header_value="application/json" ) request_information.content = json.dumps(data).encode("utf-8") - print("Request Information:", request_information.url) + self.log.debug("Request Information: %s", request_information.url) return request_information @staticmethod diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/msgraph.py index 177a0cc6aac28..94cef0eeaf6a2 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/msgraph.py @@ -23,6 +23,7 @@ from collections.abc import AsyncIterator, Sequence from contextlib import suppress from datetime import datetime +from functools import cached_property from json import JSONDecodeError from typing import ( TYPE_CHECKING, @@ -125,13 +126,11 @@ def __init__( serializer: type[ResponseSerializer] = ResponseSerializer, ): super().__init__() - self.hook = KiotaRequestAdapterHook( - conn_id=conn_id, - timeout=timeout, - proxies=proxies, - scopes=scopes, - api_version=api_version, - ) + self.conn_id = conn_id + self.timeout = timeout + self.proxies = proxies + self.scopes = scopes + self.api_version = api_version self.url = url self.response_type = response_type self.path_parameters = path_parameters @@ -158,7 +157,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "conn_id": self.conn_id, "timeout": self.timeout, "proxies": self.proxies, - "scopes": self.hook.scopes, + "scopes": self.scopes, "api_version": self.api_version, "serializer": f"{self.serializer.__class__.__module__}.{self.serializer.__class__.__name__}", "url": self.url, @@ -173,23 +172,23 @@ def serialize(self) -> tuple[str, dict[str, Any]]: ) def get_conn(self) -> RequestAdapter: - return self.hook.get_conn() - - @property - def conn_id(self) -> str: - return self.hook.conn_id + """ + Initiate a new RequestAdapter connection. - @property - def timeout(self) -> float | None: - return self.hook.timeout - - @property - def proxies(self) -> dict | None: - return self.hook.proxies + .. warning:: + This method is deprecated. + """ + return self.hook.get_conn() - @property - def api_version(self) -> APIVersion | str: - return self.hook.api_version + @cached_property + def hook(self) -> KiotaRequestAdapterHook: + return KiotaRequestAdapterHook( + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + scopes=self.scopes, + api_version=self.api_version, + ) async def run(self) -> AsyncIterator[TriggerEvent]: """Make a series of asynchronous HTTP calls via a KiotaRequestAdapterHook.""" diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/powerbi.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/powerbi.py index 042916f796637..c9198474cbb0b 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/powerbi.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/triggers/powerbi.py @@ -20,6 +20,7 @@ import asyncio import time from collections.abc import AsyncIterator +from functools import cached_property from typing import TYPE_CHECKING, Any import tenacity @@ -32,10 +33,56 @@ from airflow.triggers.base import BaseTrigger, TriggerEvent if TYPE_CHECKING: + from kiota_abstractions.request_adapter import RequestAdapter from msgraph_core import APIVersion -class PowerBITrigger(BaseTrigger): +class BasePowerBITrigger(BaseTrigger): + """ + Base class for all PowerBI related triggers. + + :param conn_id: The connection Id to connect to PowerBI. + :param timeout: The HTTP timeout being used by the `KiotaRequestAdapter` (default is None). + When no timeout is specified or set to None then there is no HTTP timeout on each request. + :param proxies: A dict defining the HTTP proxies to be used (default is None). + :param api_version: The API version of the Microsoft Graph API to be used (default is v1). + You can pass an enum named APIVersion which has 2 possible members v1 and beta, + or you can pass a string as `v1.0` or `beta`. + """ + + def __init__( + self, + conn_id: str, + timeout: float = 60 * 60 * 24 * 7, + proxies: dict | None = None, + api_version: APIVersion | str | None = None, + ): + super().__init__() + self.conn_id = conn_id + self.timeout = timeout + self.proxies = proxies + self.api_version = api_version + + def get_conn(self) -> RequestAdapter: + """ + Initiate a new RequestAdapter connection. + + .. warning:: + This method is deprecated. + """ + return self.hook.get_conn() + + @cached_property + def hook(self) -> PowerBIHook: + return PowerBIHook( + conn_id=self.conn_id, + timeout=self.timeout, + proxies=self.proxies, + api_version=self.api_version, + ) + + +class PowerBITrigger(BasePowerBITrigger): """ Triggers when Power BI dataset refresh is completed. @@ -69,11 +116,9 @@ def __init__( wait_for_termination: bool = True, request_body: dict[str, Any] | None = None, ): - super().__init__() - self.hook = PowerBIHook(conn_id=conn_id, proxies=proxies, api_version=api_version, timeout=timeout) + super().__init__(conn_id=conn_id, timeout=timeout, proxies=proxies, api_version=api_version) self.dataset_id = dataset_id self.dataset_refresh_id = dataset_refresh_id - self.timeout = timeout self.group_id = group_id self.check_interval = check_interval self.wait_for_termination = wait_for_termination @@ -82,7 +127,7 @@ def __init__( def serialize(self): """Serialize the trigger instance.""" return ( - "airflow.providers.microsoft.azure.triggers.powerbi.PowerBITrigger", + f"{self.__class__.__module__}.{self.__class__.__name__}", { "conn_id": self.conn_id, "proxies": self.proxies, @@ -97,18 +142,6 @@ def serialize(self): }, ) - @property - def conn_id(self) -> str: - return self.hook.conn_id - - @property - def proxies(self) -> dict | None: - return self.hook.proxies - - @property - def api_version(self) -> APIVersion | str: - return self.hook.api_version - async def run(self) -> AsyncIterator[TriggerEvent]: """Make async connection to the PowerBI and polls for the dataset refresh status.""" if not self.dataset_refresh_id: @@ -236,7 +269,7 @@ async def fetch_refresh_status_and_error() -> tuple[str, str]: ) -class PowerBIWorkspaceListTrigger(BaseTrigger): +class PowerBIWorkspaceListTrigger(BasePowerBITrigger): """ Triggers a call to the API to request the available workspace IDs. @@ -257,15 +290,13 @@ def __init__( proxies: dict | None = None, api_version: APIVersion | str | None = None, ): - super().__init__() - self.hook = PowerBIHook(conn_id=conn_id, proxies=proxies, api_version=api_version, timeout=timeout) - self.timeout = timeout + super().__init__(conn_id=conn_id, timeout=timeout, proxies=proxies, api_version=api_version) self.workspace_ids = workspace_ids def serialize(self): """Serialize the trigger instance.""" return ( - "airflow.providers.microsoft.azure.triggers.powerbi.PowerBIWorkspaceListTrigger", + f"{self.__class__.__module__}.{self.__class__.__name__}", { "conn_id": self.conn_id, "proxies": self.proxies, @@ -275,18 +306,6 @@ def serialize(self): }, ) - @property - def conn_id(self) -> str: - return self.hook.conn_id - - @property - def proxies(self) -> dict | None: - return self.hook.proxies - - @property - def api_version(self) -> APIVersion | str: - return self.hook.api_version - async def run(self) -> AsyncIterator[TriggerEvent]: """Make async connection to the PowerBI and polls for the list of workspace IDs.""" # Trigger the API to get the workspace list @@ -313,7 +332,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: return -class PowerBIDatasetListTrigger(BaseTrigger): +class PowerBIDatasetListTrigger(BasePowerBITrigger): """ Triggers a call to the API to request the available dataset IDs. @@ -336,16 +355,14 @@ def __init__( proxies: dict | None = None, api_version: APIVersion | str | None = None, ): - super().__init__() - self.hook = PowerBIHook(conn_id=conn_id, proxies=proxies, api_version=api_version, timeout=timeout) - self.timeout = timeout + super().__init__(conn_id=conn_id, timeout=timeout, proxies=proxies, api_version=api_version) self.group_id = group_id self.dataset_ids = dataset_ids def serialize(self): """Serialize the trigger instance.""" return ( - "airflow.providers.microsoft.azure.triggers.powerbi.PowerBIDatasetListTrigger", + f"{self.__class__.__module__}.{self.__class__.__name__}", { "conn_id": self.conn_id, "proxies": self.proxies, @@ -356,18 +373,6 @@ def serialize(self): }, ) - @property - def conn_id(self) -> str: - return self.hook.conn_id - - @property - def proxies(self) -> dict | None: - return self.hook.proxies - - @property - def api_version(self) -> APIVersion | str: - return self.hook.api_version - async def run(self) -> AsyncIterator[TriggerEvent]: """Make async connection to the PowerBI and polls for the list of dataset IDs.""" # Trigger the API to get the dataset list diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py index 6e06b61f931e5..affe190b35925 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py @@ -95,12 +95,27 @@ def test_get_conn(self): side_effect=get_airflow_connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") - actual = hook.get_conn() + + with pytest.warns(DeprecationWarning, match="get_conn is deprecated"): + actual = hook.get_conn() assert isinstance(actual, HttpxRequestAdapter) assert actual.base_url == "https://graph.microsoft.com/v1.0/" - def test_get_conn_with_custom_base_url(self): + @pytest.mark.asyncio + async def test_get_async_conn(self): + with patch( + f"{BASEHOOK_PATCH_PATH}.get_connection", + side_effect=get_airflow_connection, + ): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + actual = await hook.get_async_conn() + + assert isinstance(actual, HttpxRequestAdapter) + assert actual.base_url == "https://graph.microsoft.com/v1.0/" + + @pytest.mark.asyncio + async def test_get_async_conn_with_custom_base_url(self): connection = lambda conn_id: get_airflow_connection( conn_id=conn_id, host="api.fabric.microsoft.com", @@ -112,12 +127,13 @@ def test_get_conn_with_custom_base_url(self): side_effect=connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") - actual = hook.get_conn() + actual = await hook.get_async_conn() assert isinstance(actual, HttpxRequestAdapter) assert actual.base_url == "https://api.fabric.microsoft.com/v1/" - def test_get_conn_with_proxies_as_string(self): + @pytest.mark.asyncio + async def test_get_async_conn_with_proxies_as_string(self): connection = lambda conn_id: get_airflow_connection( conn_id=conn_id, host="api.fabric.microsoft.com", @@ -130,13 +146,14 @@ def test_get_conn_with_proxies_as_string(self): side_effect=connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") - actual = hook.get_conn() + actual = await hook.get_async_conn() assert isinstance(actual, HttpxRequestAdapter) assert actual._http_client._mounts.get(URLPattern("http://")) assert actual._http_client._mounts.get(URLPattern("https://")) - def test_get_conn_with_proxies_as_invalid_string(self): + @pytest.mark.asyncio + async def test_get_async_conn_with_proxies_as_invalid_string(self): connection = lambda conn_id: get_airflow_connection( conn_id=conn_id, host="api.fabric.microsoft.com", @@ -151,9 +168,10 @@ def test_get_conn_with_proxies_as_invalid_string(self): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") with pytest.raises(AirflowConfigException): - hook.get_conn() + await hook.get_async_conn() - def test_get_conn_with_proxies_as_json(self): + @pytest.mark.asyncio + async def test_get_async_conn_with_proxies_as_json(self): connection = lambda conn_id: get_airflow_connection( conn_id=conn_id, host="api.fabric.microsoft.com", @@ -166,7 +184,7 @@ def test_get_conn_with_proxies_as_json(self): side_effect=connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") - actual = hook.get_conn() + actual = await hook.get_async_conn() assert isinstance(actual, HttpxRequestAdapter) assert actual._http_client._mounts.get(URLPattern("http://")) @@ -208,10 +226,19 @@ def test_api_version(self): f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection, ): - hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + hook = KiotaRequestAdapterHook(conn_id="msgraph_api", api_version=APIVersion.v1.value) assert hook.api_version == APIVersion.v1.value + def test_api_version_when_none_is_explicitly_passed_as_api_version(self): + with patch( + f"{BASEHOOK_PATCH_PATH}.get_connection", + side_effect=get_airflow_connection, + ): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api", api_version=None) + + assert not hook.api_version + def test_get_api_version_when_empty_config_dict(self): with patch( f"{BASEHOOK_PATCH_PATH}.get_connection", @@ -264,17 +291,19 @@ def test_get_host_when_connection_has_no_scheme_or_host(self): assert actual == NationalClouds.Global.value - def test_tenant_id(self): + @pytest.mark.asyncio + async def test_tenant_id(self): with patch( f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=get_airflow_connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") - actual = hook.get_conn() + actual = await hook.get_async_conn() self.assert_tenant_id(actual, "tenant-id") - def test_azure_tenant_id(self): + @pytest.mark.asyncio + async def test_azure_tenant_id(self): airflow_connection = lambda conn_id: get_airflow_connection( conn_id=conn_id, azure_tenant_id="azure-tenant-id", @@ -285,7 +314,7 @@ def test_azure_tenant_id(self): side_effect=airflow_connection, ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") - actual = hook.get_conn() + actual = await hook.get_async_conn() self.assert_tenant_id(actual, "azure-tenant-id") @@ -296,7 +325,8 @@ def test_encoded_query_parameters(self): assert actual == {"%24expand": "reports,users,datasets,dataflows,dashboards", "%24top": 5000} - def test_request_information_with_custom_host(self): + @pytest.mark.asyncio + async def test_request_information_with_custom_host(self): connection = lambda conn_id: get_airflow_connection( conn_id=conn_id, host="api.fabric.microsoft.com", @@ -309,7 +339,7 @@ def test_request_information_with_custom_host(self): ): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") request_info = hook.request_information(url="myorg/admin/apps", query_parameters={"$top": 5000}) - request_adapter = hook.get_conn() + request_adapter = await hook.get_async_conn() request_adapter.set_base_url_for_request_information(request_info) assert isinstance(request_info, RequestInformation) @@ -330,7 +360,8 @@ async def test_throw_failed_responses_with_text_plain_content_type(self): response.is_success = False span = Mock(spec=Span) - actual = await hook.get_conn().get_root_parse_node(response, span, span) + conn = await hook.get_async_conn() + actual = await conn.get_root_parse_node(response, span, span) assert isinstance(actual, TextParseNode) assert actual.get_str_value() == "TenantThrottleThresholdExceeded" @@ -349,7 +380,8 @@ async def test_throw_failed_responses_with_application_json_content_type(self): response.is_success = False span = Mock(spec=Span) - actual = await hook.get_conn().get_root_parse_node(response, span, span) + conn = await hook.get_async_conn() + actual = await conn.get_root_parse_node(response, span, span) assert isinstance(actual, JsonParseNode) error_code = actual.get_child_node("error").get_child_node("code").get_str_value() diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_msgraph.py b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_msgraph.py index 978ad15892e2c..71010e5af0987 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_msgraph.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/triggers/test_msgraph.py @@ -118,7 +118,13 @@ def test_serialize(self): side_effect=get_airflow_connection, ): url = "https://graph.microsoft.com/v1.0/me/drive/items" - trigger = MSGraphTrigger(url, response_type="bytes", conn_id="msgraph_api") + trigger = MSGraphTrigger( + url, + response_type="bytes", + conn_id="msgraph_api", + scopes=[KiotaRequestAdapterHook.DEFAULT_SCOPE], + api_version=APIVersion.v1.value, + ) actual = trigger.serialize()