Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions task-sdk/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ classifiers = [
]
dependencies = [
"apache-airflow-core<3.3.0,>=3.2.0",
"asgiref>=2.3.0",
"attrs>=24.2.0, !=25.2.0",
"fsspec>=2023.10.0",
"httpx>=0.27.0",
Expand Down
24 changes: 22 additions & 2 deletions task-sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,33 @@ def _get_connection(conn_id: str) -> Connection:


async def _async_get_connection(conn_id: str) -> Connection:
# TODO: add async support for secrets backends
from asgiref.sync import sync_to_async

from airflow.sdk.execution_time.comms import GetConnection
from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

msg = await SUPERVISOR_COMMS.asend(GetConnection(conn_id=conn_id))
# TODO: check cache first
# enabled only if SecretCache.init() has been called first

# Try secrets backends first using async wrapper
backends = ensure_secrets_backend_loaded()
for secrets_backend in backends:
try:
conn = await sync_to_async(secrets_backend.get_connection)(conn_id)
if conn:
# TODO: this should probably be in get conn
if conn.password:
mask_secret(conn.password)
if conn.extra:
mask_secret(conn.extra)
return conn
except Exception:
# If one backend fails, try the next one
continue

# If no secrets backend has the connection, fall back to API server
msg = await SUPERVISOR_COMMS.asend(GetConnection(conn_id=conn_id))
return _process_connection_result_conn(msg)


Expand Down
35 changes: 35 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
TriggeringAssetEventsAccessor,
VariableAccessor,
_AssetRefResolutionMixin,
_async_get_connection,
_convert_variable_result_to_variable,
_process_connection_result_conn,
context_to_airflow_vars,
Expand Down Expand Up @@ -730,3 +731,37 @@ def test_source_task_instance_xcom_pull(self, sample_inlet_evnets_accessor, mock
map_index=0,
),
)


class TestAsyncGetConnection:
"""Test async connection retrieval with secrets backends."""

@pytest.mark.asyncio
async def test_async_get_connection_from_secrets_backend(self, mock_supervisor_comms):
"""Test that _async_get_connection successfully retrieves from secrets backend using sync_to_async."""
sample_connection = Connection(
conn_id="test_conn", conn_type="postgres", host="localhost", port=5432, login="user"
)

class MockSecretsBackend:
"""Simple mock secrets backend for testing."""

def __init__(self, connections: dict[str, Connection | None] | None = None):
self.connections = connections or {}

def get_connection(self, conn_id: str) -> Connection | None:
return self.connections.get(conn_id)

backend = MockSecretsBackend({"test_conn": sample_connection})

with patch(
"airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded", autospec=True
) as mock_load:
mock_load.return_value = [backend]

result = await _async_get_connection("test_conn")

assert result == sample_connection
# Should not have tried SUPERVISOR_COMMS since secrets backend had the connection
mock_supervisor_comms.send.assert_not_called()
mock_supervisor_comms.asend.assert_not_called()