From b2319fd3bf291fe855ad5439fb233060f844e77f Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Mon, 26 Jan 2026 10:24:50 -0800 Subject: [PATCH 1/2] Use common provider's get_async_connection in other providers --- .../amazon/src/airflow/providers/amazon/aws/hooks/s3.py | 5 +++-- .../livy/src/airflow/providers/apache/livy/hooks/livy.py | 4 ++-- .../airflow/providers/cncf/kubernetes/hooks/kubernetes.py | 4 ++-- .../dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py | 3 ++- providers/http/src/airflow/providers/http/hooks/http.py | 4 ++-- .../airflow/providers/microsoft/azure/hooks/data_factory.py | 6 +++--- .../src/airflow/providers/microsoft/azure/hooks/wasb.py | 4 ++-- .../airflow/providers/pagerduty/hooks/pagerduty_events.py | 4 ++-- providers/sftp/src/airflow/providers/sftp/hooks/sftp.py | 4 ++-- providers/ssh/src/airflow/providers/ssh/hooks/ssh.py | 4 ++-- 10 files changed, 22 insertions(+), 20 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py index 909423c06d3e9..db5ab1fc8d8d2 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py @@ -42,6 +42,8 @@ from urllib.parse import urlsplit from uuid import uuid4 +from airflow.providers.common.compat.connection import get_async_connection + if TYPE_CHECKING: from aiobotocore.client import AioBaseClient from mypy_boto3_s3.service_resource import ( @@ -52,7 +54,6 @@ from airflow.providers.amazon.version_compat import ArgNotSet -from asgiref.sync import sync_to_async from boto3.s3.transfer import S3Transfer, TransferConfig from botocore.exceptions import ClientError @@ -90,7 +91,7 @@ async def maybe_add_bucket_name(*args, **kwargs): if not bound_args.arguments.get("bucket_name"): self = args[0] if self.aws_conn_id: - connection = await sync_to_async(self.get_connection)(self.aws_conn_id) + connection = await get_async_connection(self.aws_conn_id) if connection.schema: bound_args.arguments["bucket_name"] = connection.schema return bound_args diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py index 5844995cb0529..9de3582de1109 100644 --- a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py +++ b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py @@ -26,8 +26,8 @@ import aiohttp import requests from aiohttp import ClientResponseError -from asgiref.sync import sync_to_async +from airflow.providers.common.compat.connection import get_async_connection from airflow.providers.common.compat.sdk import AirflowException from airflow.providers.http.hooks.http import HttpAsyncHook, HttpHook @@ -526,7 +526,7 @@ async def _do_api_call_async( auth = None if self.http_conn_id: - conn = await sync_to_async(self.get_connection)(self.http_conn_id) + conn = await get_async_connection(self.http_conn_id) self.base_url = self._generate_base_url(conn) # type: ignore[arg-type] if conn.login: diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index a76e9390db685..d39abe390c0ac 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -27,7 +27,6 @@ import aiofiles import requests -from asgiref.sync import sync_to_async from kubernetes import client, config, utils, watch from kubernetes.client.models import V1Deployment from kubernetes.config import ConfigException @@ -46,6 +45,7 @@ container_is_completed, container_is_running, ) +from airflow.providers.common.compat.connection import get_async_connection from airflow.providers.common.compat.sdk import AirflowException, AirflowNotFoundException, BaseHook from airflow.utils import yaml @@ -885,7 +885,7 @@ async def api_client_from_kubeconfig_file(_kubeconfig_path: str | None): async def get_conn_extras(self) -> dict: if self._extras is None: if self.conn_id: - connection = await sync_to_async(self.get_connection)(self.conn_id) + connection = await get_async_connection(self.conn_id) self._extras = connection.extra_dejson else: self._extras = {} diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py index c4dd0572f6ef8..ca20480abd270 100644 --- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py +++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py @@ -34,6 +34,7 @@ from requests.sessions import Session from tenacity import AsyncRetrying, RetryCallState, retry_if_exception, stop_after_attempt, wait_exponential +from airflow.providers.common.compat.connection import get_async_connection from airflow.providers.common.compat.sdk import AirflowException from airflow.providers.http.hooks.http import HttpHook @@ -161,7 +162,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: if bound_args.arguments.get("account_id") is None: self = args[0] if self.dbt_cloud_conn_id: - connection = await sync_to_async(self.get_connection)(self.dbt_cloud_conn_id) + connection = await get_async_connection(self.dbt_cloud_conn_id) default_account_id = connection.login if not default_account_id: raise AirflowException("Could not determine the dbt Cloud account.") diff --git a/providers/http/src/airflow/providers/http/hooks/http.py b/providers/http/src/airflow/providers/http/hooks/http.py index 815b7ffadc910..ed137a651c426 100644 --- a/providers/http/src/airflow/providers/http/hooks/http.py +++ b/providers/http/src/airflow/providers/http/hooks/http.py @@ -25,13 +25,13 @@ import aiohttp import tenacity from aiohttp import ClientResponseError -from asgiref.sync import sync_to_async from requests import PreparedRequest, Request, Response, Session from requests.auth import HTTPBasicAuth from requests.exceptions import ConnectionError, HTTPError from requests.models import DEFAULT_REDIRECT_LIMIT from requests_toolbelt.adapters.socket_options import TCPKeepAliveAdapter +from airflow.providers.common.compat.connection import get_async_connection from airflow.providers.common.compat.sdk import AirflowException, BaseHook from airflow.providers.http.exceptions import HttpErrorException, HttpMethodException @@ -461,7 +461,7 @@ async def run( auth = None if self.http_conn_id: - conn = await sync_to_async(self.get_connection)(self.http_conn_id) + conn = await get_async_connection(self.http_conn_id) if conn.host and "://" in conn.host: self.base_url = conn.host diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_factory.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_factory.py index 93c28e33300a5..44248847794e9 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_factory.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/data_factory.py @@ -39,7 +39,6 @@ from functools import wraps from typing import IO, TYPE_CHECKING, Any, TypeVar, cast -from asgiref.sync import sync_to_async from azure.identity import ClientSecretCredential, DefaultAzureCredential from azure.identity.aio import ( ClientSecretCredential as AsyncClientSecretCredential, @@ -48,6 +47,7 @@ from azure.mgmt.datafactory import DataFactoryManagementClient from azure.mgmt.datafactory.aio import DataFactoryManagementClient as AsyncDataFactoryManagementClient +from airflow.providers.common.compat.connection import get_async_connection from airflow.providers.common.compat.sdk import AirflowException, BaseHook from airflow.providers.microsoft.azure.utils import ( add_managed_identity_connection_widgets, @@ -1089,7 +1089,7 @@ async def bind_argument(arg: Any, default_key: str) -> None: # Check if arg was not included in the function signature or, if it is, the value is not provided. if arg not in bound_args.arguments or bound_args.arguments[arg] is None: self = args[0] - conn = await sync_to_async(self.get_connection)(self.conn_id) + conn = await get_async_connection(self.conn_id) extras = conn.extra_dejson default_value = extras.get(default_key) or extras.get( f"extra__azure_data_factory__{default_key}" @@ -1126,7 +1126,7 @@ async def get_async_conn(self) -> AsyncDataFactoryManagementClient: if self._async_conn is not None: return self._async_conn - conn = await sync_to_async(self.get_connection)(self.conn_id) + conn = await get_async_connection(self.conn_id) extras = conn.extra_dejson tenant = get_field(extras, "tenantId") diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py index ed92f48f1d3ba..dabd7280e3d2d 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/wasb.py @@ -30,7 +30,6 @@ import os from typing import TYPE_CHECKING, Any, cast -from asgiref.sync import sync_to_async from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceNotFoundError from azure.identity import ClientSecretCredential from azure.identity.aio import ( @@ -44,6 +43,7 @@ ContainerClient as AsyncContainerClient, ) +from airflow.providers.common.compat.connection import get_async_connection from airflow.providers.common.compat.sdk import AirflowException, BaseHook from airflow.providers.microsoft.azure.utils import ( add_managed_identity_connection_widgets, @@ -620,7 +620,7 @@ async def get_async_conn(self) -> AsyncBlobServiceClient: self._blob_service_client = cast("AsyncBlobServiceClient", self._blob_service_client) return self._blob_service_client - conn = await sync_to_async(self.get_connection)(self.conn_id) + conn = await get_async_connection(self.conn_id) extra = conn.extra_dejson or {} client_secret_auth_config = extra.pop("client_secret_auth_config", {}) diff --git a/providers/pagerduty/src/airflow/providers/pagerduty/hooks/pagerduty_events.py b/providers/pagerduty/src/airflow/providers/pagerduty/hooks/pagerduty_events.py index 710189bde68f9..99a0efefe082b 100644 --- a/providers/pagerduty/src/airflow/providers/pagerduty/hooks/pagerduty_events.py +++ b/providers/pagerduty/src/airflow/providers/pagerduty/hooks/pagerduty_events.py @@ -23,8 +23,8 @@ import aiohttp import pagerduty -from asgiref.sync import sync_to_async +from airflow.providers.common.compat.connection import get_async_connection from airflow.providers.common.compat.sdk import AirflowException, BaseHook from airflow.providers.http.hooks.http import HttpAsyncHook @@ -285,7 +285,7 @@ async def get_integration_key(self) -> str: return self.integration_key if self.pagerduty_events_conn_id is not None: - conn = await sync_to_async(self.get_connection)(self.pagerduty_events_conn_id) + conn = await get_async_connection(self.pagerduty_events_conn_id) self.integration_key = conn.password if self.integration_key: return self.integration_key diff --git a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py index bef8d725c335d..80ceb729082a9 100644 --- a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py +++ b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py @@ -33,10 +33,10 @@ from typing import IO, TYPE_CHECKING, Any, cast import asyncssh -from asgiref.sync import sync_to_async from paramiko.config import SSH_PORT from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.providers.common.compat.connection import get_async_connection from airflow.providers.common.compat.sdk import AirflowException, BaseHook, Connection from airflow.providers.sftp.exceptions import ConnectionNotOpenedException from airflow.providers.ssh.hooks.ssh import SSHHook @@ -756,7 +756,7 @@ async def _get_conn(self) -> asyncssh.SSHClientConnection: - known_hosts - passphrase """ - conn = await sync_to_async(self.get_connection)(self.sftp_conn_id) + conn = await get_async_connection(self.sftp_conn_id) if conn.extra is not None: self._parse_extras(conn) # type: ignore[arg-type] diff --git a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py index 493f3f9236959..4814569e4cffd 100644 --- a/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py +++ b/providers/ssh/src/airflow/providers/ssh/hooks/ssh.py @@ -33,6 +33,7 @@ from sshtunnel import SSHTunnelForwarder from tenacity import Retrying, stop_after_attempt, wait_fixed, wait_random +from airflow.providers.common.compat.connection import get_async_connection from airflow.providers.common.compat.sdk import AirflowException, BaseHook from airflow.utils.platform import getuser @@ -615,9 +616,8 @@ async def _get_conn(self): Returns an asyncssh SSHClientConnection that can be used to run commands. """ import asyncssh - from asgiref.sync import sync_to_async - conn = await sync_to_async(self.get_connection)(self.ssh_conn_id) + conn = await get_async_connection(self.ssh_conn_id) if conn.extra is not None: self._parse_extras(conn) From c33885a7065bfbebe4bd37d8fb3c23c988a15889 Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Mon, 26 Jan 2026 11:46:45 -0800 Subject: [PATCH 2/2] Fix sftp and livy unit tests --- .../tests/unit/apache/livy/hooks/test_livy.py | 12 ++++++------ .../sftp/tests/unit/sftp/hooks/test_sftp.py | 16 ++++++++-------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py b/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py index 8dac0bedf52db..3353d7969150a 100644 --- a/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py +++ b/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py @@ -611,7 +611,7 @@ async def test_run_method_error(self, mock_do_api_call_async): @pytest.mark.asyncio @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") - @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection") + @mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection") async def test_do_api_call_async_post_method_with_success(self, mock_get_connection, mock_session): """Asserts the _do_api_call_async for success response for POST method.""" @@ -634,7 +634,7 @@ async def mock_fun(arg1, arg2, arg3, arg4): @pytest.mark.asyncio @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") - @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection") + @mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection") async def test_do_api_call_async_get_method_with_success(self, mock_get_connection, mock_session): """Asserts the _do_api_call_async for GET method.""" @@ -659,7 +659,7 @@ async def mock_fun(arg1, arg2, arg3, arg4): @pytest.mark.asyncio @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") - @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection") + @mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection") async def test_do_api_call_async_patch_method_with_success(self, mock_get_connection, mock_session): """Asserts the _do_api_call_async for PATCH method.""" @@ -684,7 +684,7 @@ async def mock_fun(arg1, arg2, arg3, arg4): @pytest.mark.asyncio @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") - @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection") + @mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection") async def test_do_api_call_async_unexpected_method_error(self, mock_get_connection, mock_session): """Asserts the _do_api_call_async for unexpected method error""" GET_RUN_ENDPOINT = "api/jobs/runs/get" @@ -700,7 +700,7 @@ async def test_do_api_call_async_unexpected_method_error(self, mock_get_connecti @pytest.mark.asyncio @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") - @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection") + @mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection") async def test_do_api_call_async_with_type_error(self, mock_get_connection, mock_session): """Asserts the _do_api_call_async for TypeError.""" @@ -719,7 +719,7 @@ async def mock_fun(arg1, arg2, arg3, arg4): @pytest.mark.asyncio @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") - @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_connection") + @mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection") async def test_do_api_call_async_with_client_response_error(self, mock_get_connection, mock_session): """Asserts the _do_api_call_async for Client Response Error.""" diff --git a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py index dc3838f1a8764..835bfcc022122 100644 --- a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py +++ b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py @@ -734,7 +734,7 @@ def __init__(self): class TestSFTPHookAsync: @patch("asyncssh.connect", new_callable=AsyncMock) - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.get_async_connection") @pytest.mark.asyncio async def test_extra_dejson_fields_for_connection_building_known_hosts_none( self, mock_get_connection, mock_connect, caplog @@ -775,7 +775,7 @@ async def test_extra_dejson_fields_for_connection_building_known_hosts_none( ) @patch("asyncssh.connect", new_callable=AsyncMock) @patch("asyncssh.import_private_key") - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.get_async_connection") @pytest.mark.asyncio async def test_extra_dejson_fields_for_connection_with_host_key( self, @@ -799,7 +799,7 @@ async def test_extra_dejson_fields_for_connection_with_host_key( assert hook.known_hosts == f"localhost {mock_host_key}".encode() @patch("asyncssh.connect", new_callable=AsyncMock) - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.get_async_connection") @pytest.mark.asyncio async def test_extra_dejson_fields_for_connection_raises_valuerror( self, mock_get_connection, mock_connect @@ -820,7 +820,7 @@ async def test_extra_dejson_fields_for_connection_raises_valuerror( @patch("paramiko.SSHClient.connect") @patch("asyncssh.import_private_key") @patch("asyncssh.connect", new_callable=AsyncMock) - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.get_async_connection") @pytest.mark.asyncio async def test_no_host_key_check_set_logs_warning( self, mock_get_connection, mock_connect, mock_import_pkey, mock_ssh_connect, caplog @@ -833,7 +833,7 @@ async def test_no_host_key_check_set_logs_warning( assert "No Host Key Verification. This won't protect against Man-In-The-Middle attacks" in caplog.text @patch("asyncssh.connect", new_callable=AsyncMock) - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.get_async_connection") @pytest.mark.asyncio async def test_extra_dejson_fields_for_connection_building(self, mock_get_connection, mock_connect): """ @@ -861,7 +861,7 @@ async def test_extra_dejson_fields_for_connection_building(self, mock_get_connec @pytest.mark.asyncio @patch("asyncssh.connect", new_callable=AsyncMock) @patch("asyncssh.import_private_key") - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.get_async_connection") async def test_connection_private(self, mock_get_connection, mock_import_private_key, mock_connect): """ Assert that connection details with private key passed through the extra field in the Airflow connection @@ -888,7 +888,7 @@ async def test_connection_private(self, mock_get_connection, mock_import_private @pytest.mark.asyncio @patch("asyncssh.connect", new_callable=AsyncMock) - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.get_async_connection") async def test_connection_port_default_to_22(self, mock_get_connection, mock_connect): from unittest.mock import Mock, call @@ -917,7 +917,7 @@ async def test_connection_port_default_to_22(self, mock_get_connection, mock_con @pytest.mark.asyncio @patch("asyncssh.connect", new_callable=AsyncMock) - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.get_connection") + @patch("airflow.providers.sftp.hooks.sftp.get_async_connection") async def test_init_argument_not_ignored(self, mock_get_connection, mock_connect): from unittest.mock import Mock, call