diff --git a/airflow-core/docs/administration-and-deployment/logging-monitoring/callbacks.rst b/airflow-core/docs/administration-and-deployment/logging-monitoring/callbacks.rst index c2201921cd135..00eca8bd0e819 100644 --- a/airflow-core/docs/administration-and-deployment/logging-monitoring/callbacks.rst +++ b/airflow-core/docs/administration-and-deployment/logging-monitoring/callbacks.rst @@ -20,19 +20,19 @@ Callbacks ========= -A valuable component of logging and monitoring is the use of task callbacks to act upon changes in state of a given DAG or task, or across all tasks in a given DAG. -For example, you may wish to alert when certain tasks have failed, or invoke a callback when your DAG succeeds. +A valuable component of logging and monitoring is the use of task callbacks to act upon changes in state of a given Dag or task, or across all tasks in a given Dag. +For example, you may wish to alert when certain tasks have failed, or invoke a callback when your Dag succeeds. There are three different places where callbacks can be defined. -- Callbacks set in the DAG definition will be applied at the DAG level. -- Using ``default_args``, callbacks can be set for each task in a DAG. +- Callbacks set in the Dag definition will be applied at the Dag level. +- Using ``default_args``, callbacks can be set for each task in a Dag. - Individual callbacks can be set for a task by setting that callback within the task definition itself. .. note:: - Callback functions are only invoked when the DAG or task state changes due to execution by a worker. - As such, DAG and task changes set by the command line interface (:doc:`CLI <../../howto/usage-cli>`) or user interface (:doc:`UI <../../ui>`) do not + Callback functions are only invoked when the Dag or task state changes due to execution by a worker. + As such, Dag and task changes set by the command line interface (:doc:`CLI <../../howto/usage-cli>`) or user interface (:doc:`UI <../../ui>`) do not execute callback functions. .. warning:: @@ -42,6 +42,12 @@ There are three different places where callbacks can be defined. By default, scheduler logs do not show up in the UI and instead can be found in ``$AIRFLOW_HOME/logs/scheduler/latest/DAG_FILE.py.log`` +.. note:: + As of Airflow 2.6.0, callbacks now supports a list of callback functions, allowing users to specify multiple functions + to be executed in the desired event. Simply pass a list of callback functions to the callback args when defining your Dag/task + callbacks: e.g ``on_failure_callback=[callback_func_1, callback_func_2]`` + + Callback Types -------------- @@ -50,33 +56,33 @@ There are six types of events that can trigger a callback: =========================================== ================================================================ Name Description =========================================== ================================================================ -``on_success_callback`` Invoked when the :ref:`DAG succeeds ` or :ref:`task succeeds `. - Available at the DAG or task level. +``on_success_callback`` Invoked when the :ref:`Dag succeeds ` or :ref:`task succeeds `. + Available at the Dag or task level. ``on_failure_callback`` Invoked when the task :ref:`fails `. - Available at the DAG or task level. + Available at the Dag or task level. ``on_retry_callback`` Invoked when the task is :ref:`up for retry `. Available only at the task level. ``on_execute_callback`` Invoked right before the task begins executing. Available only at the task level. ``on_skipped_callback`` Invoked when the task is :ref:`running ` and AirflowSkipException raised. Explicitly it is NOT called if a task is not started to be executed because of a preceding branching - decision in the DAG or a trigger rule which causes execution to skip so that the task execution + decision in the Dag or a trigger rule which causes execution to skip so that the task execution is never scheduled. Available only at the task level. =========================================== ================================================================ -Example -------- +Examples +-------- + +Using Custom Callback Methods +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -In the following example, failures in ``task1`` call the ``task_failure_alert`` function, and success at DAG level calls the ``dag_success_alert`` function. +In the following example, failures in ``task1`` call the ``task_failure_alert`` function, and success at Dag level calls the ``dag_success_alert`` function. Before each task begins to execute, the ``task_execute_callback`` function will be called: .. code-block:: python - import datetime - import pendulum - from airflow.sdk import DAG from airflow.providers.standard.operators.empty import EmptyOperator @@ -90,27 +96,48 @@ Before each task begins to execute, the ``task_execute_callback`` function will def dag_success_alert(context): - print(f"DAG has succeeded, run_id: {context['run_id']}") + print(f"Dag has succeeded, run_id: {context['run_id']}") with DAG( dag_id="example_callback", - schedule=None, - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - dagrun_timeout=datetime.timedelta(minutes=60), - catchup=False, on_success_callback=dag_success_alert, default_args={"on_execute_callback": task_execute_callback}, - tags=["example"], ): task1 = EmptyOperator(task_id="task1", on_failure_callback=[task_failure_alert]) task2 = EmptyOperator(task_id="task2") task3 = EmptyOperator(task_id="task3") task1 >> task2 >> task3 -.. note:: - As of Airflow 2.6.0, callbacks now supports a list of callback functions, allowing users to specify multiple functions - to be executed in the desired event. Simply pass a list of callback functions to the callback args when defining your DAG/task - callbacks: e.g ``on_failure_callback=[callback_func_1, callback_func_2]`` - Full list of variables available in ``context`` in :doc:`docs <../../templates-ref>` and `code `_. + + +Using Notifiers +^^^^^^^^^^^^^^^ + +You can use Notifiers in your Dag definition by passing it as an argument to the ``on_*_callbacks``. +For example, you can use it with ``on_success_callback`` or ``on_failure_callback`` to send notifications based +on the status of a task or a Dag run. + +Here's an example of using a custom notifier: + +.. code-block:: python + + from airflow.sdk import DAG + from airflow.providers.standard.operators.bash import BashOperator + + from myprovider.notifier import MyNotifier + + with DAG( + dag_id="example_notifier", + on_success_callback=MyNotifier(message="Success!"), + on_failure_callback=MyNotifier(message="Failure!"), + ): + task = BashOperator( + task_id="example_task", + bash_command="exit 1", + on_success_callback=MyNotifier(message="Task Succeeded!"), + ) + +For a list of community-managed Notifiers, see :doc:`apache-airflow-providers:core-extensions/notifications`. +For more information on writing a custom Notifier, see the :doc:`Notifiers <../../howto/notifications>` how-to page. diff --git a/airflow-core/docs/howto/notifications.rst b/airflow-core/docs/howto/notifications.rst index 38a545bb4d8d9..1705ccf25dc10 100644 --- a/airflow-core/docs/howto/notifications.rst +++ b/airflow-core/docs/howto/notifications.rst @@ -15,8 +15,9 @@ specific language governing permissions and limitations under the License. -Creating a notifier +Creating a Notifier =================== + The :class:`~airflow.sdk.definitions.notifier.BaseNotifier` is an abstract class that provides a basic structure for sending notifications in Airflow using the various ``on_*__callback``. It is intended for providers to extend and customize for their specific needs. @@ -32,49 +33,29 @@ Here's an example of how you can create a Notifier class: .. code-block:: python from airflow.sdk import BaseNotifier - from my_provider import send_message + from my_provider import async_send_message, send_message class MyNotifier(BaseNotifier): template_fields = ("message",) - def __init__(self, message): + def __init__(self, message: str): self.message = message - def notify(self, context): - # Send notification here, below is an example + def notify(self, context: Context) -> None: + # Send notification here. For example: title = f"Task {context['task_instance'].task_id} failed" send_message(title, self.message) -Using a notifier ----------------- -Once you have a notifier implementation, you can use it in your ``DAG`` definition by passing it as an argument to -the ``on_*_callbacks``. For example, you can use it with ``on_success_callback`` or ``on_failure_callback`` to send -notifications based on the status of a task or a DAG run. - -Here's an example of using the above notifier: - -.. code-block:: python - - from datetime import datetime + async def async_notify(self, context: Context) -> None: + # Only required if your Notifier is going to support asynchronous code. For example: + title = f"Task {context['task_instance'].task_id} failed" + await async_send_message(title, self.message) - from airflow.sdk import DAG - from airflow.providers.standard.operators.bash import BashOperator - from myprovider.notifier import MyNotifier +For a list of community-managed notifiers, see :doc:`apache-airflow-providers:core-extensions/notifications`. - with DAG( - dag_id="example_notifier", - start_date=datetime(2022, 1, 1), - schedule=None, - on_success_callback=MyNotifier(message="Success!"), - on_failure_callback=MyNotifier(message="Failure!"), - ): - task = BashOperator( - task_id="example_task", - bash_command="exit 1", - on_success_callback=MyNotifier(message="Task Succeeded!"), - ) +Using Notifiers +=============== -For a list of community-managed notifiers, see -:doc:`apache-airflow-providers:core-extensions/notifications`. +For using Notifiers in event-based DAG callbacks, see :doc:`../administration-and-deployment/logging-monitoring/callbacks`. diff --git a/airflow-core/src/airflow/triggers/deadline.py b/airflow-core/src/airflow/triggers/deadline.py index 8b70015c76a19..bcff27fd1b2a3 100644 --- a/airflow-core/src/airflow/triggers/deadline.py +++ b/airflow-core/src/airflow/triggers/deadline.py @@ -51,7 +51,11 @@ async def run(self) -> AsyncIterator[TriggerEvent]: try: callback = import_string(self.callback_path) yield TriggerEvent({PAYLOAD_STATUS_KEY: DeadlineCallbackState.RUNNING}) - result = await callback(**self.callback_kwargs) + + # TODO: get airflow context + context: dict = {} + + result = await callback(**self.callback_kwargs, context=context) log.info("Deadline callback completed with return value: %s", result) yield TriggerEvent({PAYLOAD_STATUS_KEY: DeadlineCallbackState.SUCCESS, PAYLOAD_BODY_KEY: result}) except Exception as e: diff --git a/airflow-core/tests/unit/models/test_deadline.py b/airflow-core/tests/unit/models/test_deadline.py index 981adf196bab1..73b68d76d0c45 100644 --- a/airflow-core/tests/unit/models/test_deadline.py +++ b/airflow-core/tests/unit/models/test_deadline.py @@ -21,7 +21,6 @@ import pytest import time_machine -from sqlalchemy import select from sqlalchemy.exc import SQLAlchemyError from airflow.models import DagRun, Trigger @@ -36,9 +35,8 @@ from unit.models import DEFAULT_DATE DAG_ID = "dag_id_1" -RUN_ID = 1 INVALID_DAG_ID = "invalid_dag_id" -INVALID_RUN_ID = 2 +INVALID_RUN_ID = -1 REFERENCE_TYPES = [ pytest.param(DeadlineReference.DAGRUN_LOGICAL_DATE, id="logical_date"), @@ -77,6 +75,18 @@ def dagrun(session, dag_maker): return session.query(DagRun).one() +@pytest.fixture +def deadline_orm(dagrun, session): + deadline = Deadline( + deadline_time=DEFAULT_DATE, + callback=AsyncCallback(TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS), + dagrun_id=dagrun.id, + ) + session.add(deadline) + session.flush() + return deadline + + @pytest.mark.db_test class TestDeadline: @staticmethod @@ -87,42 +97,32 @@ def setup_method(): def teardown_method(): _clean_db() - def test_add_deadline(self, dagrun, session): - assert session.query(Deadline).count() == 0 - deadline_orm = Deadline( - deadline_time=DEFAULT_DATE, - callback=TEST_ASYNC_CALLBACK, - dagrun_id=dagrun.id, - ) - - session.add(deadline_orm) - session.flush() - - assert session.query(Deadline).count() == 1 - - result = session.scalars(select(Deadline)).first() - assert result.dagrun_id == deadline_orm.dagrun_id - assert result.deadline_time == deadline_orm.deadline_time - assert result.callback == deadline_orm.callback - @pytest.mark.parametrize( "conditions", [ pytest.param({}, id="empty_conditions"), - pytest.param({Deadline.dagrun_id: INVALID_RUN_ID}, id="no_matches"), - pytest.param({Deadline.dagrun_id: RUN_ID}, id="single_condition"), + pytest.param({Deadline.dagrun_id: -1}, id="no_matches"), + pytest.param({Deadline.dagrun_id: "valid_placeholder"}, id="single_condition"), pytest.param( - {Deadline.dagrun_id: RUN_ID, Deadline.deadline_time: datetime.now() + timedelta(days=365)}, + { + Deadline.dagrun_id: "valid_placeholder", + Deadline.deadline_time: datetime.now() + timedelta(days=365), + }, id="multiple_conditions", ), pytest.param( - {Deadline.dagrun_id: RUN_ID, Deadline.callback_state: "invalid"}, id="mixed_conditions" + {Deadline.dagrun_id: "valid_placeholder", Deadline.callback_state: "invalid"}, + id="mixed_conditions", ), ], ) @mock.patch("sqlalchemy.orm.Session") - def test_prune_deadlines(self, mock_session, conditions): + def test_prune_deadlines(self, mock_session, conditions, dagrun): """Test deadline resolution with various conditions.""" + if Deadline.dagrun_id in conditions: + if conditions[Deadline.dagrun_id] == "valid_placeholder": + conditions[Deadline.dagrun_id] = dagrun.id + expected_result = 1 if conditions else 0 # Set up the query chain to return a list of (Deadline, DagRun) pairs mock_dagrun = mock.Mock(spec=DagRun, end_date=datetime.now()) @@ -142,32 +142,13 @@ def test_prune_deadlines(self, mock_session, conditions): else: mock_session.query.assert_not_called() - def test_orm(self): - deadline_orm = Deadline( - deadline_time=DEFAULT_DATE, - callback=TEST_ASYNC_CALLBACK, - dagrun_id=RUN_ID, - ) - - assert deadline_orm.deadline_time == DEFAULT_DATE - assert deadline_orm.callback == TEST_ASYNC_CALLBACK - assert deadline_orm.dagrun_id == RUN_ID - - def test_repr_with_callback_kwargs(self, dagrun, session): - deadline_orm = Deadline( - deadline_time=DEFAULT_DATE, - callback=TEST_ASYNC_CALLBACK, - dagrun_id=dagrun.id, - ) - session.add(deadline_orm) - session.flush() - + def test_repr_with_callback_kwargs(self, deadline_orm, dagrun): assert ( repr(deadline_orm) == f"[DagRun Deadline] Dag: {DAG_ID} Run: {dagrun.id} needed by " - f"{deadline_orm.deadline_time} or run: {TEST_CALLBACK_PATH}({TEST_CALLBACK_KWARGS})" + f"{DEFAULT_DATE} or run: {TEST_CALLBACK_PATH}({TEST_CALLBACK_KWARGS})" ) - def test_repr_without_callback_kwargs(self, dagrun, session): + def test_repr_without_callback_kwargs(self, deadline_orm, dagrun, session): deadline_orm = Deadline( deadline_time=DEFAULT_DATE, callback=AsyncCallback(TEST_CALLBACK_PATH), @@ -179,19 +160,11 @@ def test_repr_without_callback_kwargs(self, dagrun, session): assert deadline_orm.callback.kwargs is None assert ( repr(deadline_orm) == f"[DagRun Deadline] Dag: {DAG_ID} Run: {dagrun.id} needed by " - f"{deadline_orm.deadline_time} or run: {TEST_CALLBACK_PATH}()" + f"{DEFAULT_DATE} or run: {TEST_CALLBACK_PATH}()" ) @pytest.mark.db_test - def test_handle_miss_async_callback(self, dagrun, session): - deadline_orm = Deadline( - deadline_time=DEFAULT_DATE, - callback=TEST_ASYNC_CALLBACK, - dagrun_id=dagrun.id, - ) - session.add(deadline_orm) - session.flush() - + def test_handle_miss_async_callback(self, dagrun, deadline_orm, session): deadline_orm.handle_miss(session=session) session.flush() @@ -248,15 +221,7 @@ def test_handle_miss_sync_callback(self, dagrun, session): pytest.param(TriggerEvent({PAYLOAD_STATUS_KEY: "unknown_state"}), False, id="unknown_event"), ], ) - def test_handle_callback_event(self, dagrun, session, event, none_trigger_expected): - deadline_orm = Deadline( - deadline_time=DEFAULT_DATE, - callback=TEST_ASYNC_CALLBACK, - dagrun_id=dagrun.id, - ) - session.add(deadline_orm) - session.flush() - + def test_handle_callback_event(self, dagrun, deadline_orm, session, event, none_trigger_expected): deadline_orm.handle_miss(session=session) session.flush() @@ -271,6 +236,26 @@ def test_handle_callback_event(self, dagrun, session, event, none_trigger_expect else: assert deadline_orm.callback_state == DeadlineCallbackState.QUEUED + def test_handle_miss_creates_trigger(self, dagrun, deadline_orm, session): + """Test that handle_miss creates a trigger with correct parameters.""" + deadline_orm.handle_miss(session) + session.flush() + + # Check trigger was created + trigger = session.query(Trigger).first() + assert trigger is not None + assert deadline_orm.trigger_id == trigger.id + + # Check trigger has correct kwargs + assert trigger.kwargs["callback_path"] == TEST_CALLBACK_PATH + assert trigger.kwargs["callback_kwargs"] == TEST_CALLBACK_KWARGS + + def test_handle_miss_sets_callback_state(self, dagrun, deadline_orm, session): + """Test that handle_miss sets the callback state to QUEUED.""" + deadline_orm.handle_miss(session) + + assert deadline_orm.callback_state == DeadlineCallbackState.QUEUED + @pytest.mark.db_test class TestCalculatedDeadlineDatabaseCalls: diff --git a/airflow-core/tests/unit/triggers/test_deadline.py b/airflow-core/tests/unit/triggers/test_deadline.py index 137f40e63e7f4..72bea33f1885f 100644 --- a/airflow-core/tests/unit/triggers/test_deadline.py +++ b/airflow-core/tests/unit/triggers/test_deadline.py @@ -22,13 +22,29 @@ import pytest from airflow.models.deadline import DeadlineCallbackState +from airflow.sdk import BaseNotifier from airflow.triggers.deadline import PAYLOAD_BODY_KEY, PAYLOAD_STATUS_KEY, DeadlineCallbackTrigger +TEST_MESSAGE = "test_message" TEST_CALLBACK_PATH = "classpath.test_callback_for_deadline" -TEST_CALLBACK_KWARGS = {"arg1": "value1"} +TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE} TEST_TRIGGER = DeadlineCallbackTrigger(callback_path=TEST_CALLBACK_PATH, callback_kwargs=TEST_CALLBACK_KWARGS) +class ExampleAsyncNotifier(BaseNotifier): + """Example of a properly implemented async notifier.""" + + def __init__(self, message, **kwargs): + super().__init__(**kwargs) + self.message = message + + async def async_notify(self, context): + return f"Async notification: {self.message}, context: {context}" + + def notify(self, context): + return f"Sync notification: {self.message}, context: {context}" + + class TestDeadlineCallbackTrigger: @pytest.fixture def mock_import_string(self): @@ -56,7 +72,8 @@ def test_serialization(self, callback_init_kwargs, expected_serialized_kwargs): } @pytest.mark.asyncio - async def test_run_success(self, mock_import_string): + async def test_run_success_with_async_function(self, mock_import_string): + """Test trigger handles async functions correctly.""" callback_return_value = "some value" mock_callback = mock.AsyncMock(return_value=callback_return_value) mock_import_string.return_value = mock_callback @@ -68,10 +85,25 @@ async def test_run_success(self, mock_import_string): success_event = await anext(trigger_gen) mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH) - mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS) + mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS, context=mock.ANY) assert success_event.payload[PAYLOAD_STATUS_KEY] == DeadlineCallbackState.SUCCESS assert success_event.payload[PAYLOAD_BODY_KEY] == callback_return_value + @pytest.mark.asyncio + async def test_run_success_with_notifier(self, mock_import_string): + """Test trigger handles async notifier classes correctly.""" + mock_import_string.return_value = ExampleAsyncNotifier + + trigger_gen = TEST_TRIGGER.run() + + running_event = await anext(trigger_gen) + assert running_event.payload[PAYLOAD_STATUS_KEY] == DeadlineCallbackState.RUNNING + + success_event = await anext(trigger_gen) + mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH) + assert success_event.payload[PAYLOAD_STATUS_KEY] == DeadlineCallbackState.SUCCESS + assert success_event.payload[PAYLOAD_BODY_KEY] == f"Async notification: {TEST_MESSAGE}, context: {{}}" + @pytest.mark.asyncio async def test_run_failure(self, mock_import_string): exc_msg = "Something went wrong" @@ -85,6 +117,6 @@ async def test_run_failure(self, mock_import_string): failure_event = await anext(trigger_gen) mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH) - mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS) + mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS, context=mock.ANY) assert failure_event.payload[PAYLOAD_STATUS_KEY] == DeadlineCallbackState.FAILED assert all(s in failure_event.payload[PAYLOAD_BODY_KEY] for s in ["raise", "RuntimeError", exc_msg]) diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 848a633403e89..dc98efe31e748 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -2127,7 +2127,10 @@ def test_invalid_location(self, sdk_connection_not_found): from airflow.sdk.exceptions import ErrorType from airflow.sdk.execution_time.comms import ErrorResponse - mock_supervisor_comms.send.return_value = ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND) + error_response = ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND) + mock_supervisor_comms.send.return_value = error_response + if hasattr(mock_supervisor_comms, "asend"): + mock_supervisor_comms.asend.return_value = error_response yield mock_supervisor_comms diff --git a/providers/slack/src/airflow/providers/slack/hooks/slack_webhook.py b/providers/slack/src/airflow/providers/slack/hooks/slack_webhook.py index 2e6a0f74ae197..9c2260b2936ee 100644 --- a/providers/slack/src/airflow/providers/slack/hooks/slack_webhook.py +++ b/providers/slack/src/airflow/providers/slack/hooks/slack_webhook.py @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Any from slack_sdk import WebhookClient +from slack_sdk.webhook.async_client import AsyncWebhookClient from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.providers.slack.utils import ConnectionExtraConfig @@ -35,17 +36,34 @@ LEGACY_INTEGRATION_PARAMS = ("channel", "username", "icon_emoji", "icon_url") +def _validate_response(resp): + """Validate webhook response and raise error if status code != 200.""" + if resp.status_code != 200: + raise AirflowException( + f"Response body: {resp.body!r}, Status Code: {resp.status_code}. " + "See: https://api.slack.com/messaging/webhooks#handling_errors" + ) + + def check_webhook_response(func: Callable) -> Callable: """Check WebhookResponse and raise an error if status code != 200.""" @wraps(func) def wrapper(*args, **kwargs) -> Callable: resp = func(*args, **kwargs) - if resp.status_code != 200: - raise AirflowException( - f"Response body: {resp.body!r}, Status Code: {resp.status_code}. " - "See: https://api.slack.com/messaging/webhooks#handling_errors" - ) + _validate_response(resp) + return resp + + return wrapper + + +def async_check_webhook_response(func: Callable) -> Callable: + """Check WebhookResponse and raise an error if status code != 200 (async).""" + + @wraps(func) + async def wrapper(*args, **kwargs) -> Callable: + resp = await func(*args, **kwargs) + _validate_response(resp) return resp return wrapper @@ -134,13 +152,27 @@ def client(self) -> WebhookClient: """Get the underlying slack_sdk.webhook.WebhookClient (cached).""" return WebhookClient(**self._get_conn_params()) + @cached_property + async def async_client(self) -> AsyncWebhookClient: + """Get the underlying `slack_sdk.webhook.async_client.AsyncWebhookClient` (cached).""" + return AsyncWebhookClient(**await self._async_get_conn_params()) + def get_conn(self) -> WebhookClient: - """Get the underlying slack_sdk.webhook.WebhookClient (cached).""" + """Get the underlying `slack_sdk.webhook.WebhookClient` (cached).""" return self.client def _get_conn_params(self) -> dict[str, Any]: """Fetch connection params as a dict and merge it with hook parameters.""" conn = self.get_connection(self.slack_webhook_conn_id) + return self._build_conn_params(conn) + + async def _async_get_conn_params(self) -> dict[str, Any]: + """Fetch connection params as a dict and merge it with hook parameters (async).""" + conn = await self.aget_connection(self.slack_webhook_conn_id) + return self._build_conn_params(conn) + + def _build_conn_params(self, conn) -> dict[str, Any]: + """Build connection parameters from connection object.""" if not conn.password or not conn.password.strip(): raise AirflowNotFoundException( f"Connection ID {self.slack_webhook_conn_id!r} does not contain password " @@ -173,14 +205,8 @@ def _get_conn_params(self) -> dict[str, Any]: conn_params.update(self.extra_client_args) return {k: v for k, v in conn_params.items() if v is not None} - @check_webhook_response - def send_dict(self, body: dict[str, Any] | str, *, headers: dict[str, str] | None = None): - """ - Perform a Slack Incoming Webhook request with given JSON data block. - - :param body: JSON data structure, expected dict or JSON-string. - :param headers: Request headers for this request. - """ + def _process_body(self, body: dict[str, Any] | str) -> dict[str, Any]: + """Validate and process the request body.""" if isinstance(body, str): try: body = json.loads(body) @@ -203,9 +229,31 @@ def send_dict(self, body: dict[str, Any] | str, *, headers: dict[str, str] | Non UserWarning, stacklevel=2, ) + return body + @check_webhook_response + def send_dict(self, body: dict[str, Any] | str, *, headers: dict[str, str] | None = None): + """ + Perform a Slack Incoming Webhook request with given JSON data block. + + :param body: JSON data structure, expected dict or JSON-string. + :param headers: Request headers for this request. + """ + body = self._process_body(body) return self.client.send_dict(body, headers=headers) + @async_check_webhook_response + async def async_send_dict(self, body: dict[str, Any] | str, *, headers: dict[str, str] | None = None): + """ + Perform a Slack Incoming Webhook request with given JSON data block (async). + + :param body: JSON data structure, expected dict or JSON-string. + :param headers: Request headers for this request. + """ + body = self._process_body(body) + async_client = await self.async_client + return await async_client.send_dict(body, headers=headers) + def send( self, *, @@ -235,20 +283,69 @@ def send( :param attachments: (legacy) A collection of attachments. """ body = { - "text": text, - "attachments": attachments, - "blocks": blocks, - "response_type": response_type, - "replace_original": replace_original, - "delete_original": delete_original, - "unfurl_links": unfurl_links, - "unfurl_media": unfurl_media, - # Legacy Integration Parameters - **kwargs, + k: v + for k, v in { + "text": text, + "attachments": attachments, + "blocks": blocks, + "response_type": response_type, + "replace_original": replace_original, + "delete_original": delete_original, + "unfurl_links": unfurl_links, + "unfurl_media": unfurl_media, + # Legacy Integration Parameters + **kwargs, + }.items() + if v is not None } - body = {k: v for k, v in body.items() if v is not None} return self.send_dict(body=body, headers=headers) + async def async_send( + self, + *, + text: str | None = None, + blocks: list[dict[str, Any]] | None = None, + response_type: str | None = None, + replace_original: bool | None = None, + delete_original: bool | None = None, + unfurl_links: bool | None = None, + unfurl_media: bool | None = None, + headers: dict[str, str] | None = None, + attachments: list[dict[str, Any]] | None = None, + **kwargs, + ): + """ + Perform a Slack Incoming Webhook request with given arguments (async). + + :param text: The text message + (even when having blocks, setting this as well is recommended as it works as fallback). + :param blocks: A collection of Block Kit UI components. + :param response_type: The type of message (either 'in_channel' or 'ephemeral'). + :param replace_original: True if you use this option for response_url requests. + :param delete_original: True if you use this option for response_url requests. + :param unfurl_links: Option to indicate whether text url should unfurl. + :param unfurl_media: Option to indicate whether media url should unfurl. + :param headers: Request headers for this request. + :param attachments: (legacy) A collection of attachments. + """ + body = { + k: v + for k, v in { + "text": text, + "attachments": attachments, + "blocks": blocks, + "response_type": response_type, + "replace_original": replace_original, + "delete_original": delete_original, + "unfurl_links": unfurl_links, + "unfurl_media": unfurl_media, + # Legacy Integration Parameters + **kwargs, + }.items() + if v is not None + } + return await self.async_send_dict(body=body, headers=headers) + def send_text( self, text: str, diff --git a/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py b/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py index 36b7ccd851ddf..d1125224ecf58 100644 --- a/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py +++ b/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py @@ -62,8 +62,9 @@ def __init__( timeout: int | None = None, attachments: list | None = None, retry_handlers: list[RetryHandler] | None = None, + **kwargs, ): - super().__init__() + super().__init__(**kwargs) self.slack_webhook_conn_id = slack_webhook_conn_id self.text = text self.attachments = attachments @@ -86,13 +87,29 @@ def hook(self) -> SlackWebhookHook: def notify(self, context): """Send a message to a Slack Incoming Webhook.""" - self.hook.send( + resp = self.hook.send( text=self.text, blocks=self.blocks, unfurl_links=self.unfurl_links, unfurl_media=self.unfurl_media, attachments=self.attachments, ) + self.log.debug( + "Slack webhook notification sent using notify(): %s %s", resp.status_code, resp.api_url + ) + + async def async_notify(self, context): + """Send a message to a Slack Incoming Webhook (async).""" + resp = await self.hook.async_send( + text=self.text, + blocks=self.blocks, + unfurl_links=self.unfurl_links, + unfurl_media=self.unfurl_media, + attachments=self.attachments, + ) + self.log.debug( + "Slack webhook notification sent using notify_async(): %s %s", resp.status_code, resp.api_url + ) send_slack_webhook_notification = SlackWebhookNotifier diff --git a/providers/slack/tests/unit/slack/hooks/test_slack_webhook.py b/providers/slack/tests/unit/slack/hooks/test_slack_webhook.py index 8dfa021b9113f..43c97045b4043 100644 --- a/providers/slack/tests/unit/slack/hooks/test_slack_webhook.py +++ b/providers/slack/tests/unit/slack/hooks/test_slack_webhook.py @@ -27,11 +27,16 @@ import pytest from slack_sdk.http_retry.builtin_handlers import ConnectionErrorRetryHandler, RateLimitErrorRetryHandler +from slack_sdk.webhook.async_client import AsyncWebhookClient from slack_sdk.webhook.webhook_response import WebhookResponse from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.models.connection import Connection -from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook, check_webhook_response +from airflow.providers.slack.hooks.slack_webhook import ( + SlackWebhookHook, + async_check_webhook_response, + check_webhook_response, +) TEST_TOKEN = "T00000000/B00000000/XXXXXXXXXXXXXXXXXXXXXXXX" TEST_WEBHOOK_URL = f"https://hooks.slack.com/services/{TEST_TOKEN}" @@ -172,6 +177,42 @@ def decorated(): assert decorated() +class TestAsyncCheckWebhookResponseDecorator: + @pytest.mark.asyncio + async def test_ok_response(self): + """Test async decorator with OK response.""" + + @async_check_webhook_response + async def decorated(): + return MOCK_WEBHOOK_RESPONSE + + assert await decorated() is MOCK_WEBHOOK_RESPONSE + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "status_code,body", + [ + (400, "invalid_payload"), + (403, "action_prohibited"), + (404, "channel_not_found"), + (410, "channel_is_archived"), + (500, "rollup_error"), + (418, "i_am_teapot"), + ], + ) + async def test_error_response(self, status_code, body): + """Test async decorator with error response.""" + test_response = WebhookResponse(url="foo://bar", status_code=status_code, body=body, headers={}) + + @async_check_webhook_response + async def decorated(): + return test_response + + error_message = rf"Response body: '{body}', Status Code: {status_code}\." + with pytest.raises(AirflowException, match=error_message): + await decorated() + + class TestSlackWebhookHook: @pytest.mark.parametrize( "conn_id", @@ -432,6 +473,7 @@ def test_hook_send_dict_legacy_slack_integration(self, mock_webhook_client_cls, {"text": "Test Text"}, {"text": "Fallback Text", "blocks": ["Dummy Block"]}, {"text": "Fallback Text", "blocks": ["Dummy Block"], "unfurl_media": True, "unfurl_links": True}, + {"legacy": "value"}, ], ) @mock.patch("airflow.providers.slack.hooks.slack_webhook.SlackWebhookHook.send_dict") @@ -503,3 +545,92 @@ def test_empty_string_ignored_non_prefixed(self): hook = SlackWebhookHook(slack_webhook_conn_id="my_conn") params = hook._get_conn_params() assert "proxy" not in params + + +class TestSlackWebhookHookAsync: + @pytest.mark.asyncio + @mock.patch("airflow.providers.slack.hooks.slack_webhook.SlackWebhookHook._async_get_conn_params") + async def test_async_client(self, mock_async_get_conn_params): + """Test async_client property creates AsyncWebhookClient with correct params.""" + mock_async_get_conn_params.return_value = {"url": TEST_WEBHOOK_URL} + + hook = SlackWebhookHook(slack_webhook_conn_id=TEST_CONN_ID) + client = await hook.async_client + + assert isinstance(client, AsyncWebhookClient) + assert client.url == TEST_WEBHOOK_URL + mock_async_get_conn_params.assert_called_once() + + @pytest.mark.asyncio + @pytest.mark.parametrize("headers", [None, {"User-Agent": "Airflow"}]) + @pytest.mark.parametrize( + "send_body", + [ + {"text": "Test Text"}, + {"text": "Fallback Text", "blocks": ["Dummy Block"]}, + {"text": "Fallback Text", "blocks": ["Dummy Block"], "unfurl_media": True, "unfurl_links": True}, + ], + ) + @mock.patch("airflow.providers.slack.hooks.slack_webhook.AsyncWebhookClient") + @mock.patch("airflow.providers.slack.hooks.slack_webhook.SlackWebhookHook._async_get_conn_params") + async def test_async_send_dict( + self, mock_async_get_conn_params, mock_async_webhook_client_cls, send_body, headers + ): + """Test async_send_dict method with dict input.""" + mock_async_get_conn_params.return_value = {"url": TEST_WEBHOOK_URL} + mock_async_client = mock_async_webhook_client_cls.return_value + mock_async_client.send_dict = mock.AsyncMock(return_value=MOCK_WEBHOOK_RESPONSE) + + hook = SlackWebhookHook(slack_webhook_conn_id=TEST_CONN_ID) + resp = await hook.async_send_dict(body=send_body, headers=headers) + + assert resp == MOCK_WEBHOOK_RESPONSE + mock_async_client.send_dict.assert_called_once_with(send_body, headers=headers) + + @pytest.mark.asyncio + @pytest.mark.parametrize("headers", [None, {"User-Agent": "Airflow"}]) + @pytest.mark.parametrize( + "send_body", + [ + {"text": "Test Text"}, + {"text": "Fallback Text", "blocks": ["Dummy Block"]}, + {"text": "Fallback Text", "blocks": ["Dummy Block"], "unfurl_media": True, "unfurl_links": True}, + ], + ) + @mock.patch("airflow.providers.slack.hooks.slack_webhook.AsyncWebhookClient") + @mock.patch("airflow.providers.slack.hooks.slack_webhook.SlackWebhookHook._async_get_conn_params") + async def test_async_send_dict_json_string( + self, mock_async_get_conn_params, mock_async_webhook_client_cls, send_body, headers + ): + """Test async_send_dict method with JSON string input.""" + mock_async_get_conn_params.return_value = {"url": TEST_WEBHOOK_URL} + mock_async_client = mock_async_webhook_client_cls.return_value + mock_async_client.send_dict = mock.AsyncMock(return_value=MOCK_WEBHOOK_RESPONSE) + + hook = SlackWebhookHook(slack_webhook_conn_id=TEST_CONN_ID) + resp = await hook.async_send_dict(body=json.dumps(send_body), headers=headers) + + assert resp == MOCK_WEBHOOK_RESPONSE + mock_async_client.send_dict.assert_called_once_with(send_body, headers=headers) + + @pytest.mark.asyncio + @pytest.mark.parametrize("headers", [None, {"User-Agent": "Airflow"}]) + @pytest.mark.parametrize( + "send_params", + [ + {"text": "Test Text"}, + {"text": "Fallback Text", "blocks": ["Dummy Block"]}, + {"text": "Fallback Text", "blocks": ["Dummy Block"], "unfurl_media": True, "unfurl_links": True}, + {"legacy": "value"}, + ], + ) + @mock.patch("airflow.providers.slack.hooks.slack_webhook.SlackWebhookHook.async_send_dict") + async def test_async_send(self, mock_async_send_dict, send_params, headers): + """Test at async_send method.""" + mock_async_send_dict.return_value = MOCK_WEBHOOK_RESPONSE + + hook = SlackWebhookHook(slack_webhook_conn_id=TEST_CONN_ID) + resp = await hook.async_send(**send_params, headers=headers) + + assert resp == MOCK_WEBHOOK_RESPONSE + mock_async_send_dict.assert_called_once_with(body=send_params, headers=headers) diff --git a/providers/slack/tests/unit/slack/notifications/test_slack_webhook.py b/providers/slack/tests/unit/slack/notifications/test_slack_webhook.py index 7897ef2674bb2..ed12eb5794bef 100644 --- a/providers/slack/tests/unit/slack/notifications/test_slack_webhook.py +++ b/providers/slack/tests/unit/slack/notifications/test_slack_webhook.py @@ -65,6 +65,31 @@ def test_slack_webhook_notifier(self, mock_slack_hook, slack_op_kwargs, hook_ext ) mock_slack_hook.assert_called_once_with(slack_webhook_conn_id="test_conn_id", **hook_extra_kwargs) + @pytest.mark.asyncio + @mock.patch("airflow.providers.slack.notifications.slack_webhook.SlackWebhookHook") + async def test_async_slack_webhook_notifier(self, mock_slack_hook): + mock_hook = mock_slack_hook.return_value + mock_hook.async_send = mock.AsyncMock() + + notifier = send_slack_webhook_notification( + slack_webhook_conn_id="test_conn_id", + text="foo-bar", + blocks="spam-egg", + attachments="baz-qux", + unfurl_links=True, + unfurl_media=False, + ) + + await notifier.async_notify({}) + + mock_hook.async_send.assert_called_once_with( + text="foo-bar", + blocks="spam-egg", + unfurl_links=True, + unfurl_media=False, + attachments="baz-qux", + ) + @mock.patch("airflow.providers.slack.notifications.slack_webhook.SlackWebhookHook") def test_slack_webhook_templated(self, mock_slack_hook, create_dag_without_db): notifier = send_slack_webhook_notification( @@ -90,3 +115,48 @@ def test_slack_webhook_templated(self, mock_slack_hook, create_dag_without_db): unfurl_links=None, unfurl_media=None, ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.slack.notifications.slack_webhook.SlackWebhookHook") + async def test_async_slack_webhook_templated(self, mock_slack_hook, create_dag_without_db): + """Test async notification with template rendering.""" + mock_hook = mock_slack_hook.return_value + mock_hook.async_send = mock.AsyncMock() + + notifier = send_slack_webhook_notification( + text="Who am I? {{ username }}", + blocks=[{"type": "header", "text": {"type": "plain_text", "text": "{{ dag.dag_id }}"}}], + attachments=[{"image_url": "{{ dag.dag_id }}.png"}], + ) + + # Call notifier first to handle template rendering + notifier( + { + "dag": create_dag_without_db("test_async_send_slack_webhook_notification_templated"), + "username": "not-a-root", + } + ) + + # Then call async_notify with rendered templates + await notifier.async_notify( + { + "dag": create_dag_without_db("test_async_send_slack_webhook_notification_templated"), + "username": "not-a-root", + } + ) + + mock_hook.async_send.assert_called_once_with( + text="Who am I? not-a-root", + blocks=[ + { + "type": "header", + "text": { + "type": "plain_text", + "text": "test_async_send_slack_webhook_notification_templated", + }, + } + ], + attachments=[{"image_url": "test_async_send_slack_webhook_notification_templated.png"}], + unfurl_links=None, + unfurl_media=None, + ) diff --git a/task-sdk/src/airflow/sdk/bases/hook.py b/task-sdk/src/airflow/sdk/bases/hook.py index 2de50c443c5a9..8aaba17ca7814 100644 --- a/task-sdk/src/airflow/sdk/bases/hook.py +++ b/task-sdk/src/airflow/sdk/bases/hook.py @@ -62,6 +62,20 @@ def get_connection(cls, conn_id: str) -> Connection: log.debug("Connection Retrieved '%s' (via task-sdk)", conn.conn_id) return conn + @classmethod + async def aget_connection(cls, conn_id: str) -> Connection: + """ + Get connection (async), given connection id. + + :param conn_id: connection id + :return: connection + """ + from airflow.sdk.definitions.connection import Connection + + conn = await Connection.async_get(conn_id) + log.debug("Connection Retrieved '%s' (via task-sdk)", conn.conn_id) + return conn + @classmethod def get_hook(cls, conn_id: str, hook_params: dict | None = None): """ diff --git a/task-sdk/src/airflow/sdk/bases/notifier.py b/task-sdk/src/airflow/sdk/bases/notifier.py index df4023d043a23..6772e406f0bd4 100644 --- a/task-sdk/src/airflow/sdk/bases/notifier.py +++ b/task-sdk/src/airflow/sdk/bases/notifier.py @@ -17,8 +17,7 @@ from __future__ import annotations -from abc import abstractmethod -from collections.abc import Sequence +from collections.abc import Generator, Sequence from typing import TYPE_CHECKING from airflow.sdk.definitions._internal.templater import Templater @@ -33,13 +32,32 @@ class BaseNotifier(LoggingMixin, Templater): - """BaseNotifier class for sending notifications.""" + """ + BaseNotifier class for sending notifications. + + It can be used asynchronously (preferred) if `async_notify`is implemented and/or + synchronously if `notify` is implemented. + + Currently, the DAG/Task state change callbacks run on the DAG Processor and only support sync usage. + + Usage:: + # Asynchronous usage + await Notifier(context) + + # Synchronous usage + notifier = Notifier() + notifier(context) + """ template_fields: Sequence[str] = () template_ext: Sequence[str] = () - def __init__(self): + # Context stored as attribute here because parameters can't be passed to __await__ + context: Context + + def __init__(self, context: Context | None = None): super().__init__() + self.context = context or {} self.resolve_template_files() def _update_context(self, context: Context) -> Context: @@ -53,7 +71,7 @@ def _update_context(self, context: Context) -> Context: return context def _render(self, template, context, dag: DAG | None = None): - dag = dag or context["dag"] + dag = dag or context.get("dag") return super()._render(template, context, dag) def render_template_fields( @@ -69,19 +87,34 @@ def render_template_fields( :param context: Context dict with values to apply on content. :param jinja_env: Jinja environment to use for rendering. """ - dag = context["dag"] + dag = context.get("dag") if not jinja_env: jinja_env = self.get_template_env(dag=dag) self._do_render_template_fields(self, self.template_fields, context, jinja_env, set()) - @abstractmethod + async def async_notify(self, context: Context) -> None: + """ + Send a notification (async). + + Implementing this is a requirement for running this notifier in the triggerer, which is the + recommended approach for using Deadline Alerts. + + :param context: The airflow context + + Note: the context is not available in the current version. + """ + raise NotImplementedError + def notify(self, context: Context) -> None: """ - Send a notification. + Send a notification (sync). + + Implementing this is a requirement for running this notifier in the DAG processor, which is where the + `on_success_callback` and `on_failure_callback` run. :param context: The airflow context """ - ... + raise NotImplementedError def __call__(self, *args) -> None: """ @@ -104,4 +137,19 @@ def __call__(self, *args) -> None: try: self.notify(context) except Exception as e: - self.log.exception("Failed to send notification: %s", e) + self.log.error("Failed to send notification (sync): %s", e) + raise + + def __await__(self) -> Generator: + """ + Make the notifier awaitable. + + Context must be provided as an attribute. + """ + self._update_context(self.context) + self.render_template_fields(self.context) + try: + return self.async_notify(self.context).__await__() + except Exception as e: + self.log.error("Failed to send notification (async): %s", e) + raise diff --git a/task-sdk/src/airflow/sdk/definitions/connection.py b/task-sdk/src/airflow/sdk/definitions/connection.py index 39ea645395a2e..cc2e92a41aa1a 100644 --- a/task-sdk/src/airflow/sdk/definitions/connection.py +++ b/task-sdk/src/airflow/sdk/definitions/connection.py @@ -188,6 +188,13 @@ def get_hook(self, *, hook_params=None): hook_params = {} return hook_class(**{hook.connection_id_attribute_name: self.conn_id}, **hook_params) + @classmethod + def _handle_connection_error(cls, e: AirflowRuntimeError, conn_id: str) -> None: + """Handle connection retrieval errors.""" + if e.error.error == ErrorType.CONNECTION_NOT_FOUND: + raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined") from None + raise + @classmethod def get(cls, conn_id: str) -> Any: from airflow.sdk.execution_time.context import _get_connection @@ -195,9 +202,16 @@ def get(cls, conn_id: str) -> Any: try: return _get_connection(conn_id) except AirflowRuntimeError as e: - if e.error.error == ErrorType.CONNECTION_NOT_FOUND: - raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined") from None - raise + cls._handle_connection_error(e, conn_id) + + @classmethod + async def async_get(cls, conn_id: str) -> Any: + from airflow.sdk.execution_time.context import _async_get_connection + + try: + return await _async_get_connection(conn_id) + except AirflowRuntimeError as e: + cls._handle_connection_error(e, conn_id) @property def extra_dejson(self) -> dict: diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index f2ee0672299a2..27e4b0a754e65 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -206,6 +206,10 @@ def send(self, msg: SendMsgType) -> ReceiveMsgType | None: return self._get_response() + async def asend(self, msg: SendMsgType) -> ReceiveMsgType | None: + """Send a request to the parent without blocking.""" + raise NotImplementedError + @overload def _read_frame(self, maxfds: None = None) -> _ResponseFrame: ... diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index caf586cb0182b..570cd25d9a3ef 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -56,6 +56,7 @@ ConnectionResult, OKResponse, PrevSuccessfulDagRunResponse, + ReceiveMsgType, VariableResult, ) from airflow.sdk.types import OutletEventAccessorsProtocol @@ -101,8 +102,15 @@ T = TypeVar("T") -def _convert_connection_result_conn(conn_result: ConnectionResult) -> Connection: +def _process_connection_result_conn(conn_result: ReceiveMsgType | None) -> Connection: from airflow.sdk.definitions.connection import Connection + from airflow.sdk.execution_time.comms import ErrorResponse + + if isinstance(conn_result, ErrorResponse): + raise AirflowRuntimeError(conn_result) + + if TYPE_CHECKING: + assert isinstance(conn_result, ConnectionResult) # `by_alias=True` is used to convert the `schema` field to `schema_` in the Connection model return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True)) @@ -121,7 +129,7 @@ def _convert_variable_result_to_variable(var_result: VariableResult, deserialize def _get_connection(conn_id: str) -> Connection: from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded - # TODO: check cache first + # TODO: check cache first (also in _async_get_connection) # enabled only if SecretCache.init() has been called first # iterate over configured backends if not in cache (or expired) @@ -154,17 +162,24 @@ def _get_connection(conn_id: str) -> Connection: # A reason to not move it to `airflow.sdk.execution_time.comms` is that it # will make that module depend on Task SDK, which is not ideal because we intend to # keep Task SDK as a separate package than execution time mods. - from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection + # Also applies to _async_get_connection. + from airflow.sdk.execution_time.comms import GetConnection from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS msg = SUPERVISOR_COMMS.send(GetConnection(conn_id=conn_id)) - if isinstance(msg, ErrorResponse): - raise AirflowRuntimeError(msg) + return _process_connection_result_conn(msg) - if TYPE_CHECKING: - assert isinstance(msg, ConnectionResult) - return _convert_connection_result_conn(msg) + +async def _async_get_connection(conn_id: str) -> Connection: + # TODO: add async support for secrets backends + + from airflow.sdk.execution_time.comms import GetConnection + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + msg = await SUPERVISOR_COMMS.asend(GetConnection(conn_id=conn_id)) + + return _process_connection_result_conn(msg) def _get_variable(key: str, deserialize_json: bool) -> Any: diff --git a/task-sdk/tests/task_sdk/bases/test_hook.py b/task-sdk/tests/task_sdk/bases/test_hook.py index 4b15ab5e0139a..4c691266f8205 100644 --- a/task-sdk/tests/task_sdk/bases/test_hook.py +++ b/task-sdk/tests/task_sdk/bases/test_hook.py @@ -59,6 +59,28 @@ def test_get_connection(self, mock_supervisor_comms): msg=GetConnection(conn_id="test_conn"), ) + @pytest.mark.asyncio + async def test_aget_connection(self, mock_supervisor_comms): + """Test async connection retrieval in task sdk context.""" + conn = ConnectionResult( + conn_id="test_conn", + conn_type="mysql", + host="mysql", + schema="airflow", + login="login", + password="password", + port=1234, + extra='{"extra_key": "extra_value"}', + ) + + mock_supervisor_comms.asend.return_value = conn + + hook = BaseHook(logger_name="") + await hook.aget_connection(conn_id="test_conn") + mock_supervisor_comms.asend.assert_called_once_with( + msg=GetConnection(conn_id="test_conn"), + ) + def test_get_connection_not_found(self, sdk_connection_not_found): conn_id = "test_conn" hook = BaseHook() @@ -67,6 +89,16 @@ def test_get_connection_not_found(self, sdk_connection_not_found): with pytest.raises(AirflowNotFoundException, match="The conn_id `test_conn` isn't defined"): hook.get_connection(conn_id=conn_id) + @pytest.mark.asyncio + async def test_aget_connection_not_found(self, sdk_connection_not_found): + """Test async connection not found error.""" + conn_id = "test_conn" + hook = BaseHook() + sdk_connection_not_found + + with pytest.raises(AirflowNotFoundException, match="The conn_id `test_conn` isn't defined"): + await hook.aget_connection(conn_id=conn_id) + def test_get_connection_secrets_backend_configured(self, mock_supervisor_comms, tmp_path): path = tmp_path / "conn.env" path.write_text("CONN_A=mysql://host_a") diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index 1def08086cce3..54e2c66bee82f 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -58,8 +58,8 @@ TriggeringAssetEventsAccessor, VariableAccessor, _AssetRefResolutionMixin, - _convert_connection_result_conn, _convert_variable_result_to_variable, + _process_connection_result_conn, context_to_airflow_vars, set_current_context, ) @@ -77,7 +77,7 @@ def test_convert_connection_result_conn(): port=1234, extra='{"extra_key": "extra_value"}', ) - conn = _convert_connection_result_conn(conn) + conn = _process_connection_result_conn(conn) assert conn == Connection( conn_id="test_conn", conn_type="mysql",