diff --git a/providers/slack/pyproject.toml b/providers/slack/pyproject.toml index 01059adc336e3..97c707335bf7f 100644 --- a/providers/slack/pyproject.toml +++ b/providers/slack/pyproject.toml @@ -61,6 +61,7 @@ dependencies = [ "apache-airflow-providers-common-compat>=1.6.1", "apache-airflow-providers-common-sql>=1.27.0", "slack-sdk>=3.36.0", + "asgiref>=2.3.0", ] [dependency-groups] diff --git a/providers/slack/src/airflow/providers/slack/hooks/slack.py b/providers/slack/src/airflow/providers/slack/hooks/slack.py index e09ff44c1434c..13e052a35a66e 100644 --- a/providers/slack/src/airflow/providers/slack/hooks/slack.py +++ b/providers/slack/src/airflow/providers/slack/hooks/slack.py @@ -15,6 +15,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +""" +Hook for Slack. + +.. spelling:word-list:: + + AsyncSlackResponse +""" + from __future__ import annotations import json @@ -27,17 +35,21 @@ from slack_sdk import WebClient from slack_sdk.errors import SlackApiError +from slack_sdk.web.async_client import AsyncWebClient from typing_extensions import NotRequired from airflow.exceptions import AirflowException, AirflowNotFoundException -from airflow.providers.slack.utils import ConnectionExtraConfig +from airflow.providers.slack.utils import ConnectionExtraConfig, get_async_connection from airflow.providers.slack.version_compat import BaseHook from airflow.utils.helpers import exactly_one if TYPE_CHECKING: from slack_sdk.http_retry import RetryHandler + from slack_sdk.web.async_client import AsyncSlackResponse from slack_sdk.web.slack_response import SlackResponse + from airflow.providers.slack.version_compat import Connection + class FileUploadTypeDef(TypedDict): """ @@ -140,15 +152,20 @@ def __init__( @cached_property def client(self) -> WebClient: """Get the underlying slack_sdk.WebClient (cached).""" - return WebClient(**self._get_conn_params()) + conn = self.get_connection(self.slack_conn_id) + return WebClient(**self._get_conn_params(conn=conn)) + + async def get_async_client(self) -> AsyncWebClient: + """Get the underlying `slack_sdk.web.async_client.AsyncWebClient`.""" + conn = await get_async_connection(self.slack_conn_id) + return AsyncWebClient(**self._get_conn_params(conn)) def get_conn(self) -> WebClient: """Get the underlying slack_sdk.WebClient (cached).""" return self.client - def _get_conn_params(self) -> dict[str, Any]: + def _get_conn_params(self, conn: Connection) -> dict[str, Any]: """Fetch connection params as a dict and merge it with hook parameters.""" - conn = self.get_connection(self.slack_conn_id) if not conn.password: raise AirflowNotFoundException( f"Connection ID {self.slack_conn_id!r} does not contain password (Slack API Token)." @@ -186,6 +203,24 @@ def call(self, api_method: str, **kwargs) -> SlackResponse: """ return self.client.api_call(api_method, **kwargs) + async def async_call(self, api_method: str, **kwargs) -> AsyncSlackResponse: + """ + Call Slack WebClient `AsyncWebClient.api_call` with given arguments. + + :param api_method: The target Slack API method. e.g. 'chat.postMessage'. Required. + :param http_verb: HTTP Verb. Optional (defaults to 'POST') + :param files: Files to multipart upload. e.g. {imageORfile: file_objectORfile_path} + :param data: The body to attach to the request. If a dictionary is provided, + form-encoding will take place. Optional. + :param params: The URL parameters to append to the URL. Optional. + :param json: JSON for the body to attach to the request. Optional. + :return: The server's response to an HTTP request. Data from the response can be + accessed like a dict. If the response included 'next_cursor' it can be + iterated on to execute subsequent requests. + """ + client = await self.get_async_client() + return await client.api_call(api_method, **kwargs) + def send_file_v2( self, *, diff --git a/providers/slack/src/airflow/providers/slack/hooks/slack_webhook.py b/providers/slack/src/airflow/providers/slack/hooks/slack_webhook.py index 9c2260b2936ee..e707c1e49d384 100644 --- a/providers/slack/src/airflow/providers/slack/hooks/slack_webhook.py +++ b/providers/slack/src/airflow/providers/slack/hooks/slack_webhook.py @@ -27,7 +27,7 @@ from slack_sdk.webhook.async_client import AsyncWebhookClient from airflow.exceptions import AirflowException, AirflowNotFoundException -from airflow.providers.slack.utils import ConnectionExtraConfig +from airflow.providers.slack.utils import ConnectionExtraConfig, get_async_connection from airflow.providers.slack.version_compat import BaseHook if TYPE_CHECKING: @@ -152,9 +152,8 @@ def client(self) -> WebhookClient: """Get the underlying slack_sdk.webhook.WebhookClient (cached).""" return WebhookClient(**self._get_conn_params()) - @cached_property - async def async_client(self) -> AsyncWebhookClient: - """Get the underlying `slack_sdk.webhook.async_client.AsyncWebhookClient` (cached).""" + async def get_async_client(self) -> AsyncWebhookClient: + """Get the underlying `slack_sdk.webhook.async_client.AsyncWebhookClient`.""" return AsyncWebhookClient(**await self._async_get_conn_params()) def get_conn(self) -> WebhookClient: @@ -168,7 +167,7 @@ def _get_conn_params(self) -> dict[str, Any]: async def _async_get_conn_params(self) -> dict[str, Any]: """Fetch connection params as a dict and merge it with hook parameters (async).""" - conn = await self.aget_connection(self.slack_webhook_conn_id) + conn = await get_async_connection(self.slack_webhook_conn_id) return self._build_conn_params(conn) def _build_conn_params(self, conn) -> dict[str, Any]: @@ -251,7 +250,7 @@ async def async_send_dict(self, body: dict[str, Any] | str, *, headers: dict[str :param headers: Request headers for this request. """ body = self._process_body(body) - async_client = await self.async_client + async_client = await self.get_async_client() return await async_client.send_dict(body, headers=headers) def send( diff --git a/providers/slack/src/airflow/providers/slack/notifications/slack.py b/providers/slack/src/airflow/providers/slack/notifications/slack.py index 8043a8d40a8b1..82fe6af649d0d 100644 --- a/providers/slack/src/airflow/providers/slack/notifications/slack.py +++ b/providers/slack/src/airflow/providers/slack/notifications/slack.py @@ -24,6 +24,7 @@ from airflow.providers.common.compat.notifier import BaseNotifier from airflow.providers.slack.hooks.slack import SlackHook +from airflow.providers.slack.version_compat import AIRFLOW_V_3_1_PLUS if TYPE_CHECKING: from slack_sdk.http_retry import RetryHandler @@ -71,8 +72,13 @@ def __init__( retry_handlers: list[RetryHandler] | None = None, unfurl_links: bool = True, unfurl_media: bool = True, + **kwargs, ): - super().__init__() + if AIRFLOW_V_3_1_PLUS: + # Support for passing context was added in 3.1.0 + super().__init__(**kwargs) + else: + super().__init__() self.slack_conn_id = slack_conn_id self.text = text self.channel = channel @@ -112,5 +118,19 @@ def notify(self, context): } self.hook.call("chat.postMessage", json=api_call_params) + async def async_notify(self, context): + """Send a message to a Slack Channel (async).""" + api_call_params = { + "channel": self.channel, + "username": self.username, + "text": self.text, + "icon_url": self.icon_url, + "attachments": json.dumps(self.attachments), + "blocks": json.dumps(self.blocks), + "unfurl_links": self.unfurl_links, + "unfurl_media": self.unfurl_media, + } + await self.hook.async_call("chat.postMessage", json=api_call_params) + send_slack_notification = SlackNotifier diff --git a/providers/slack/src/airflow/providers/slack/utils/__init__.py b/providers/slack/src/airflow/providers/slack/utils/__init__.py index 6c59b85c0531b..a50bee4c66adb 100644 --- a/providers/slack/src/airflow/providers/slack/utils/__init__.py +++ b/providers/slack/src/airflow/providers/slack/utils/__init__.py @@ -20,6 +20,9 @@ from collections.abc import Sequence from typing import Any +from asgiref.sync import sync_to_async + +from airflow.providers.slack.version_compat import BaseHook, Connection from airflow.utils.types import NOTSET @@ -120,3 +123,15 @@ def parse_filename( if fallback: return fallback, None raise ex from None + + +async def get_async_connection(conn_id: str) -> Connection: + """ + Get an asynchronous Airflow connection that is backwards compatible. + + :param conn_id: The provided connection ID. + :returns: Connection + """ + if hasattr(BaseHook, "aget_connection"): + return await BaseHook.aget_connection(conn_id=conn_id) + return await sync_to_async(BaseHook.get_connection)(conn_id=conn_id) diff --git a/providers/slack/src/airflow/providers/slack/version_compat.py b/providers/slack/src/airflow/providers/slack/version_compat.py index 6aeb90cc89dc8..7ee3d88862700 100644 --- a/providers/slack/src/airflow/providers/slack/version_compat.py +++ b/providers/slack/src/airflow/providers/slack/version_compat.py @@ -36,9 +36,9 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0) if AIRFLOW_V_3_0_PLUS: - from airflow.sdk import BaseOperator + from airflow.sdk import BaseOperator, Connection else: - from airflow.models import BaseOperator + from airflow.models import BaseOperator, Connection # type: ignore[assignment] if AIRFLOW_V_3_1_PLUS: from airflow.sdk import BaseHook @@ -50,4 +50,5 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]: "AIRFLOW_V_3_1_PLUS", "BaseHook", "BaseOperator", + "Connection", ] diff --git a/providers/slack/tests/unit/slack/hooks/test_slack.py b/providers/slack/tests/unit/slack/hooks/test_slack.py index 3ce39f1218552..85e0025c4d1fe 100644 --- a/providers/slack/tests/unit/slack/hooks/test_slack.py +++ b/providers/slack/tests/unit/slack/hooks/test_slack.py @@ -105,21 +105,24 @@ def make_429(): def test_get_token_from_connection(self, conn_id): """Test retrieve token from Slack API Connection ID.""" hook = SlackHook(slack_conn_id=SLACK_API_DEFAULT_CONN_ID) - assert hook._get_conn_params()["token"] == MOCK_SLACK_API_TOKEN + conn = hook.get_connection(hook.slack_conn_id) + assert hook._get_conn_params(conn)["token"] == MOCK_SLACK_API_TOKEN def test_resolve_token(self): """Test that we only use token from Slack API Connection ID.""" with pytest.warns(UserWarning, match="Provide `token` as part of .* parameters is disallowed"): hook = SlackHook(slack_conn_id=SLACK_API_DEFAULT_CONN_ID, token="foo-bar") + conn = hook.get_connection(hook.slack_conn_id) assert "token" not in hook.extra_client_args - assert hook._get_conn_params()["token"] == MOCK_SLACK_API_TOKEN + assert hook._get_conn_params(conn)["token"] == MOCK_SLACK_API_TOKEN def test_empty_password(self): """Test password field defined in the connection.""" hook = SlackHook(slack_conn_id="empty_slack_connection") + conn = hook.get_connection(hook.slack_conn_id) error_message = r"Connection ID '.*' does not contain password \(Slack API Token\)\." with pytest.raises(AirflowNotFoundException, match=error_message): - hook._get_conn_params() + hook._get_conn_params(conn) @pytest.mark.parametrize( "hook_config,conn_extra,expected", @@ -228,8 +231,9 @@ def test_client_configuration( with mock.patch.dict("os.environ", values={test_conn_env: test_conn.get_uri()}): hook = SlackHook(slack_conn_id=test_conn.conn_id, **hook_config) + conn = hook.get_connection(hook.slack_conn_id) expected["logger"] = hook.log - conn_params = hook._get_conn_params() + conn_params = hook._get_conn_params(conn) assert conn_params == expected client = hook.client @@ -319,7 +323,8 @@ def test_hook_connection_failed(self, mocked_client, response_data): def test_backcompat_prefix_works(self, uri, monkeypatch): monkeypatch.setenv("AIRFLOW_CONN_MY_CONN", uri) hook = SlackHook(slack_conn_id="my_conn") - params = hook._get_conn_params() + conn = hook.get_connection(hook.slack_conn_id) + params = hook._get_conn_params(conn) assert params["token"] == "abc" assert params["timeout"] == 123 assert params["base_url"] == "base_url" @@ -328,8 +333,9 @@ def test_backcompat_prefix_works(self, uri, monkeypatch): def test_backcompat_prefix_both_causes_warning(self, monkeypatch): monkeypatch.setenv("AIRFLOW_CONN_MY_CONN", "a://:abc@?extra__slack__timeout=111&timeout=222") hook = SlackHook(slack_conn_id="my_conn") + conn = hook.get_connection(hook.slack_conn_id) with pytest.warns(Warning, match="Using value for `timeout`"): - params = hook._get_conn_params() + params = hook._get_conn_params(conn) assert params["timeout"] == 222 def test_empty_string_ignored_prefixed(self, monkeypatch): @@ -340,7 +346,8 @@ def test_empty_string_ignored_prefixed(self, monkeypatch): ), ) hook = SlackHook(slack_conn_id="my_conn") - params = hook._get_conn_params() + conn = hook.get_connection(hook.slack_conn_id) + params = hook._get_conn_params(conn) assert "proxy" not in params assert "base_url" not in params @@ -350,7 +357,8 @@ def test_empty_string_ignored_non_prefixed(self, monkeypatch): json.dumps({"password": "hi", "extra": {"base_url": "", "proxy": ""}}), ) hook = SlackHook(slack_conn_id="my_conn") - params = hook._get_conn_params() + conn = hook.get_connection(hook.slack_conn_id) + params = hook._get_conn_params(conn) assert "proxy" not in params assert "base_url" not in params @@ -539,3 +547,36 @@ def test_send_file_v1_to_v2_multiple_channels(self, channels, expected_calls): with mock.patch.object(SlackHook, "send_file_v2") as mocked_send_file_v2: hook.send_file_v1_to_v2(channels=channels, content="Fake") assert mocked_send_file_v2.call_count == expected_calls + + +class TestSlackHookAsync: + @pytest.fixture + def mock_get_conn(self): + with mock.patch( + "airflow.providers.slack.hooks.slack.get_async_connection", new_callable=mock.AsyncMock + ) as m: + m.return_value = Connection( + conn_id=SLACK_API_DEFAULT_CONN_ID, + conn_type=CONN_TYPE, + password=MOCK_SLACK_API_TOKEN, + ) + yield m + + @pytest.mark.asyncio + @mock.patch("airflow.providers.slack.hooks.slack.AsyncWebClient") + async def test_get_async_client(self, mock_client, mock_get_conn): + """Test get_async_client creates AsyncWebClient with correct params.""" + hook = SlackHook(slack_conn_id=SLACK_API_DEFAULT_CONN_ID) + await hook.get_async_client() + mock_get_conn.assert_called() + mock_client.assert_called_once_with(token=MOCK_SLACK_API_TOKEN, logger=mock.ANY) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.slack.hooks.slack.AsyncWebClient.api_call", new_callable=mock.AsyncMock) + async def test_async_call(self, mock_api_call, mock_get_conn): + """Test async_call is called correctly.""" + hook = SlackHook(slack_conn_id=SLACK_API_DEFAULT_CONN_ID) + test_api_json = {"channel": "test_channel"} + await hook.async_call("chat.postMessage", json=test_api_json) + mock_get_conn.assert_called() + mock_api_call.assert_called_with("chat.postMessage", json=test_api_json) diff --git a/providers/slack/tests/unit/slack/hooks/test_slack_webhook.py b/providers/slack/tests/unit/slack/hooks/test_slack_webhook.py index 43c97045b4043..81dd1441d91c2 100644 --- a/providers/slack/tests/unit/slack/hooks/test_slack_webhook.py +++ b/providers/slack/tests/unit/slack/hooks/test_slack_webhook.py @@ -555,7 +555,7 @@ async def test_async_client(self, mock_async_get_conn_params): mock_async_get_conn_params.return_value = {"url": TEST_WEBHOOK_URL} hook = SlackWebhookHook(slack_webhook_conn_id=TEST_CONN_ID) - client = await hook.async_client + client = await hook.get_async_client() assert isinstance(client, AsyncWebhookClient) assert client.url == TEST_WEBHOOK_URL diff --git a/providers/slack/tests/unit/slack/notifications/test_slack.py b/providers/slack/tests/unit/slack/notifications/test_slack.py index 134327f44a3b9..1985d3abef203 100644 --- a/providers/slack/tests/unit/slack/notifications/test_slack.py +++ b/providers/slack/tests/unit/slack/notifications/test_slack.py @@ -133,3 +133,31 @@ def test_slack_notifier_unfurl_options(self, mock_slack_hook, create_dag_without "unfurl_media": False, }, ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.slack.notifications.slack.SlackHook") + async def test_async_slack_notifier(self, mock_slack_hook): + mock_slack_hook.return_value.async_call = mock.AsyncMock() + + notifier = send_slack_notification( + text="test", + unfurl_links=False, + unfurl_media=False, + ) + + await notifier.async_notify({}) + + mock_slack_hook.return_value.async_call.assert_called_once_with( + "chat.postMessage", + json={ + "channel": "#general", + "username": "Airflow", + "text": "test", + "icon_url": "https://raw.githubusercontent.com/apache/airflow/main/airflow-core" + "/src/airflow/ui/public/pin_100.png", + "attachments": "[]", + "blocks": "[]", + "unfurl_links": False, + "unfurl_media": False, + }, + ) diff --git a/providers/slack/tests/unit/slack/utils/test_utils.py b/providers/slack/tests/unit/slack/utils/test_utils.py index bff3dbc658e80..9d075952f4efa 100644 --- a/providers/slack/tests/unit/slack/utils/test_utils.py +++ b/providers/slack/tests/unit/slack/utils/test_utils.py @@ -16,9 +16,12 @@ # under the License. from __future__ import annotations +from unittest import mock + import pytest -from airflow.providers.slack.utils import ConnectionExtraConfig, parse_filename +from airflow.models.connection import Connection +from airflow.providers.slack.utils import ConnectionExtraConfig, get_async_connection, parse_filename class TestConnectionExtra: @@ -144,3 +147,41 @@ def test_fallback(self, filename, fallback): def test_wrong_fallback(self, filename): with pytest.raises(ValueError, match="Invalid fallback value"): assert parse_filename(filename, self.SUPPORTED_FORMAT, "mp4") + + +class MockAgetBaseHook: + def __init__(*args, **kargs): + pass + + async def aget_connection(self, conn_id: str): + return Connection( + conn_id="test_conn", + conn_type="slack", + password="secret_token_aget", + ) + + +class MockBaseHook: + def __init__(*args, **kargs): + pass + + def get_connection(self, conn_id: str): + return Connection( + conn_id="test_conn_sync", + conn_type="slack", + password="secret_token", + ) + + +class TestGetAsyncConnection: + @mock.patch("airflow.providers.slack.utils.BaseHook", new_callable=MockAgetBaseHook) + @pytest.mark.asyncio + async def test_get_async_connection_with_aget(self, mock_hook): + conn = await get_async_connection("test_conn") + assert conn.password == "secret_token_aget" + + @mock.patch("airflow.providers.slack.utils.BaseHook", new_callable=MockBaseHook) + @pytest.mark.asyncio + async def test_get_async_connection_with_get_connection(self, mock_hook): + conn = await get_async_connection("test_conn") + assert conn.password == "secret_token"