diff --git a/providers/apprise/src/airflow/providers/apprise/hooks/apprise.py b/providers/apprise/src/airflow/providers/apprise/hooks/apprise.py index 84047c8ec05cb..6e827bf6e1d4a 100644 --- a/providers/apprise/src/airflow/providers/apprise/hooks/apprise.py +++ b/providers/apprise/src/airflow/providers/apprise/hooks/apprise.py @@ -24,11 +24,14 @@ import apprise from apprise import AppriseConfig, NotifyFormat, NotifyType +from airflow.providers.common.compat.connection import get_async_connection from airflow.providers.common.compat.sdk import BaseHook if TYPE_CHECKING: from apprise import AppriseAttachment + from airflow.providers.common.compat.sdk import Connection + class AppriseHook(BaseHook): """ @@ -50,14 +53,13 @@ def __init__(self, apprise_conn_id: str = default_conn_name) -> None: super().__init__() self.apprise_conn_id = apprise_conn_id - def get_config_from_conn(self): - conn = self.get_connection(self.apprise_conn_id) + def get_config_from_conn(self, conn: Connection): config = conn.extra_dejson["config"] return json.loads(config) if isinstance(config, str) else config - def set_config_from_conn(self, apprise_obj: apprise.Apprise): + def set_config_from_conn(self, conn: Connection, apprise_obj: apprise.Apprise): """Set config from connection to apprise object.""" - config_object = self.get_config_from_conn() + config_object = self.get_config_from_conn(conn=conn) if isinstance(config_object, list): for config in config_object: apprise_obj.add(config["path"], tag=config.get("tag", None)) @@ -101,7 +103,8 @@ def notify( if config: apprise_obj.add(config) else: - self.set_config_from_conn(apprise_obj) + conn = self.get_connection(self.apprise_conn_id) + self.set_config_from_conn(conn=conn, apprise_obj=apprise_obj) apprise_obj.notify( body=body, title=title, @@ -112,6 +115,50 @@ def notify( interpret_escapes=interpret_escapes, ) + async def async_notify( + self, + body: str, + title: str | None = None, + notify_type: NotifyType = NotifyType.INFO, + body_format: NotifyFormat = NotifyFormat.TEXT, + tag: str | Iterable[str] = "all", + attach: AppriseAttachment | None = None, + interpret_escapes: bool | None = None, + config: AppriseConfig | None = None, + ): + r""" + Send message to plugged-in services asynchronously. + + :param body: Specify the message body + :param title: Specify the message title. (optional) + :param notify_type: Specify the message type (default=info). Possible values are "info", + "success", "failure", and "warning" + :param body_format: Specify the input message format (default=text). Possible values are "text", + "html", and "markdown". + :param tag: Specify one or more tags to filter which services to notify + :param attach: Specify one or more file attachment locations + :param interpret_escapes: Enable interpretation of backslash escapes. For example, this would convert + sequences such as \n and \r to their respective ascii new-line and carriage return characters + :param config: Specify one or more configuration + """ + title = title or "" + + apprise_obj = apprise.Apprise() + if config: + apprise_obj.add(config) + else: + conn = await get_async_connection(self.apprise_conn_id) + self.set_config_from_conn(conn=conn, apprise_obj=apprise_obj) + await apprise_obj.async_notify( + body=body, + title=title, + notify_type=notify_type, + body_format=body_format, + tag=tag, + attach=attach, + interpret_escapes=interpret_escapes, + ) + def get_conn(self) -> None: raise NotImplementedError() diff --git a/providers/apprise/src/airflow/providers/apprise/notifications/apprise.py b/providers/apprise/src/airflow/providers/apprise/notifications/apprise.py index e4774df50077b..dd34bce17a84d 100644 --- a/providers/apprise/src/airflow/providers/apprise/notifications/apprise.py +++ b/providers/apprise/src/airflow/providers/apprise/notifications/apprise.py @@ -23,6 +23,7 @@ from apprise import AppriseConfig, NotifyFormat, NotifyType from airflow.providers.apprise.hooks.apprise import AppriseHook +from airflow.providers.apprise.version_compat import AIRFLOW_V_3_1_PLUS from airflow.providers.common.compat.notifier import BaseNotifier @@ -58,8 +59,13 @@ def __init__( interpret_escapes: bool | None = None, config: AppriseConfig | None = None, apprise_conn_id: str = AppriseHook.default_conn_name, + **kwargs, ): - super().__init__() + if AIRFLOW_V_3_1_PLUS: + # Support for passing context was added in 3.1.0 + super().__init__(**kwargs) + else: + super().__init__() self.apprise_conn_id = apprise_conn_id self.body = body self.title = title @@ -88,5 +94,18 @@ def notify(self, context): config=self.config, ) + async def async_notify(self, context): + """Send a alert to a apprise configured service.""" + await self.hook.async_notify( + body=self.body, + title=self.title, + notify_type=self.notify_type, + body_format=self.body_format, + tag=self.tag, + attach=self.attach, + interpret_escapes=self.interpret_escapes, + config=self.config, + ) + send_apprise_notification = AppriseNotifier diff --git a/providers/apprise/tests/unit/apprise/hooks/test_apprise.py b/providers/apprise/tests/unit/apprise/hooks/test_apprise.py index 80b2256e3773d..ef01444ee030d 100644 --- a/providers/apprise/tests/unit/apprise/hooks/test_apprise.py +++ b/providers/apprise/tests/unit/apprise/hooks/test_apprise.py @@ -18,7 +18,7 @@ import json from unittest import mock -from unittest.mock import MagicMock, call, patch +from unittest.mock import AsyncMock, MagicMock, call, patch import apprise import pytest @@ -42,13 +42,9 @@ class TestAppriseHook: ) def test_get_config_from_conn(self, config): extra = {"config": config} - with patch.object( - AppriseHook, - "get_connection", - return_value=Connection(conn_type="apprise", extra=extra), - ): - hook = AppriseHook() - assert hook.get_config_from_conn() == (json.loads(config) if isinstance(config, str) else config) + conn = Connection(conn_type="apprise", extra=extra) + hook = AppriseHook() + assert hook.get_config_from_conn(conn) == (json.loads(config) if isinstance(config, str) else config) def test_set_config_from_conn_with_dict(self): """ @@ -57,13 +53,9 @@ def test_set_config_from_conn_with_dict(self): extra = {"config": {"path": "http://some_path_that_dont_exist/", "tag": "alert"}} apprise_obj = apprise.Apprise() apprise_obj.add = MagicMock() - with patch.object( - AppriseHook, - "get_connection", - return_value=Connection(conn_type="apprise", extra=extra), - ): - hook = AppriseHook() - hook.set_config_from_conn(apprise_obj) + conn = Connection(conn_type="apprise", extra=extra) + hook = AppriseHook() + hook.set_config_from_conn(conn=conn, apprise_obj=apprise_obj) apprise_obj.add.assert_called_once_with("http://some_path_that_dont_exist/", tag="alert") @@ -80,13 +72,9 @@ def test_set_config_from_conn_with_list(self): apprise_obj = apprise.Apprise() apprise_obj.add = MagicMock() - with patch.object( - AppriseHook, - "get_connection", - return_value=Connection(conn_type="apprise", extra=extra), - ): - hook = AppriseHook() - hook.set_config_from_conn(apprise_obj) + conn = Connection(conn_type="apprise", extra=extra) + hook = AppriseHook() + hook.set_config_from_conn(conn=conn, apprise_obj=apprise_obj) apprise_obj.add.assert_has_calls( [ @@ -97,7 +85,9 @@ def test_set_config_from_conn_with_list(self): @mock.patch( "airflow.providers.apprise.hooks.apprise.AppriseHook.get_connection", - return_value=Connection( + ) + def test_notify(self, mock_conn): + mock_conn.return_value = Connection( conn_id="apprise", extra={ "config": [ @@ -105,9 +95,7 @@ def test_set_config_from_conn_with_list(self): {"path": "http://some_other_path_that_dont_exist/", "tag": "p1"}, ] }, - ), - ) - def test_notify(self, connection): + ) apprise_obj = apprise.Apprise() apprise_obj.notify = MagicMock() apprise_obj.add = MagicMock() @@ -124,3 +112,35 @@ def test_notify(self, connection): attach=None, interpret_escapes=None, ) + + @pytest.mark.asyncio + @mock.patch( + "airflow.providers.apprise.hooks.apprise.get_async_connection", + ) + async def test_async_notify(self, mock_conn): + mock_conn.return_value = Connection( + conn_id="apprise", + extra={ + "config": [ + {"path": "http://some_path_that_dont_exist/", "tag": "p0"}, + {"path": "http://some_other_path_that_dont_exist/", "tag": "p1"}, + ] + }, + ) + apprise_obj = apprise.Apprise() + apprise_obj.async_notify = AsyncMock() + apprise_obj.add = MagicMock() + with patch.object(apprise, "Apprise", return_value=apprise_obj): + hook = AppriseHook() + await hook.async_notify(body="test") + + mock_conn.assert_called() + apprise_obj.async_notify.assert_called_once_with( + body="test", + title="", + notify_type=NotifyType.INFO, + body_format=NotifyFormat.TEXT, + tag="all", + attach=None, + interpret_escapes=None, + ) diff --git a/providers/apprise/tests/unit/apprise/notifications/test_apprise.py b/providers/apprise/tests/unit/apprise/notifications/test_apprise.py index e9aa03f1669b7..1b1bbbffb196a 100644 --- a/providers/apprise/tests/unit/apprise/notifications/test_apprise.py +++ b/providers/apprise/tests/unit/apprise/notifications/test_apprise.py @@ -98,3 +98,26 @@ def test_notifier_templated(self, mock_apprise_hook, create_dag_without_db): "config": None, } mock_apprise_hook.return_value.notify.assert_called_once() + + @pytest.mark.asyncio + @mock.patch("airflow.providers.apprise.notifications.apprise.AppriseHook") + async def test_async_apprise_notifier(self, mock_apprise_hook, create_dag_without_db): + mock_apprise_hook.return_value.async_notify = mock.AsyncMock() + + notifier = send_apprise_notification(body="DISK at 99%", notify_type=NotifyType.FAILURE) + + await notifier.async_notify({"dag": create_dag_without_db("test_notifier")}) + + call_args = mock_apprise_hook.return_value.async_notify.call_args.kwargs + + assert call_args == { + "body": "DISK at 99%", + "notify_type": NotifyType.FAILURE, + "title": None, + "body_format": NotifyFormat.TEXT, + "tag": "all", + "attach": None, + "interpret_escapes": None, + "config": None, + } + mock_apprise_hook.return_value.async_notify.assert_called_once()