diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml index 7aadbc5e935be..a9205d9bfbf21 100644 --- a/task-sdk/pyproject.toml +++ b/task-sdk/pyproject.toml @@ -50,6 +50,7 @@ classifiers = [ ] dependencies = [ "apache-airflow-core<3.3.0,>=3.2.0", + "asgiref>=2.3.0", "attrs>=24.2.0, !=25.2.0", "fsspec>=2023.10.0", "httpx>=0.27.0", diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index 570cd25d9a3ef..2107cd853ab63 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -172,13 +172,33 @@ def _get_connection(conn_id: str) -> Connection: async def _async_get_connection(conn_id: str) -> Connection: - # TODO: add async support for secrets backends + from asgiref.sync import sync_to_async from airflow.sdk.execution_time.comms import GetConnection + from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - msg = await SUPERVISOR_COMMS.asend(GetConnection(conn_id=conn_id)) + # TODO: check cache first + # enabled only if SecretCache.init() has been called first + # Try secrets backends first using async wrapper + backends = ensure_secrets_backend_loaded() + for secrets_backend in backends: + try: + conn = await sync_to_async(secrets_backend.get_connection)(conn_id) + if conn: + # TODO: this should probably be in get conn + if conn.password: + mask_secret(conn.password) + if conn.extra: + mask_secret(conn.extra) + return conn + except Exception: + # If one backend fails, try the next one + continue + + # If no secrets backend has the connection, fall back to API server + msg = await SUPERVISOR_COMMS.asend(GetConnection(conn_id=conn_id)) return _process_connection_result_conn(msg) diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index 54e2c66bee82f..39aecb9d95534 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -58,6 +58,7 @@ TriggeringAssetEventsAccessor, VariableAccessor, _AssetRefResolutionMixin, + _async_get_connection, _convert_variable_result_to_variable, _process_connection_result_conn, context_to_airflow_vars, @@ -730,3 +731,37 @@ def test_source_task_instance_xcom_pull(self, sample_inlet_evnets_accessor, mock map_index=0, ), ) + + +class TestAsyncGetConnection: + """Test async connection retrieval with secrets backends.""" + + @pytest.mark.asyncio + async def test_async_get_connection_from_secrets_backend(self, mock_supervisor_comms): + """Test that _async_get_connection successfully retrieves from secrets backend using sync_to_async.""" + sample_connection = Connection( + conn_id="test_conn", conn_type="postgres", host="localhost", port=5432, login="user" + ) + + class MockSecretsBackend: + """Simple mock secrets backend for testing.""" + + def __init__(self, connections: dict[str, Connection | None] | None = None): + self.connections = connections or {} + + def get_connection(self, conn_id: str) -> Connection | None: + return self.connections.get(conn_id) + + backend = MockSecretsBackend({"test_conn": sample_connection}) + + with patch( + "airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded", autospec=True + ) as mock_load: + mock_load.return_value = [backend] + + result = await _async_get_connection("test_conn") + + assert result == sample_connection + # Should not have tried SUPERVISOR_COMMS since secrets backend had the connection + mock_supervisor_comms.send.assert_not_called() + mock_supervisor_comms.asend.assert_not_called()