Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions providers/slack/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ dependencies = [
"apache-airflow-providers-common-compat>=1.6.1",
"apache-airflow-providers-common-sql>=1.27.0",
"slack-sdk>=3.36.0",
"asgiref>=2.3.0",
]

[dependency-groups]
Expand Down
43 changes: 39 additions & 4 deletions providers/slack/src/airflow/providers/slack/hooks/slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Hook for Slack.

.. spelling:word-list::

AsyncSlackResponse
"""

from __future__ import annotations

import json
Expand All @@ -27,17 +35,21 @@

from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.web.async_client import AsyncWebClient
from typing_extensions import NotRequired

from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.slack.utils import ConnectionExtraConfig
from airflow.providers.slack.utils import ConnectionExtraConfig, get_async_connection
from airflow.providers.slack.version_compat import BaseHook
from airflow.utils.helpers import exactly_one

if TYPE_CHECKING:
from slack_sdk.http_retry import RetryHandler
from slack_sdk.web.async_client import AsyncSlackResponse
from slack_sdk.web.slack_response import SlackResponse

from airflow.providers.slack.version_compat import Connection


class FileUploadTypeDef(TypedDict):
"""
Expand Down Expand Up @@ -140,15 +152,20 @@ def __init__(
@cached_property
def client(self) -> WebClient:
"""Get the underlying slack_sdk.WebClient (cached)."""
return WebClient(**self._get_conn_params())
conn = self.get_connection(self.slack_conn_id)
return WebClient(**self._get_conn_params(conn=conn))

async def get_async_client(self) -> AsyncWebClient:
"""Get the underlying `slack_sdk.web.async_client.AsyncWebClient`."""
conn = await get_async_connection(self.slack_conn_id)
return AsyncWebClient(**self._get_conn_params(conn))

def get_conn(self) -> WebClient:
"""Get the underlying slack_sdk.WebClient (cached)."""
return self.client

def _get_conn_params(self) -> dict[str, Any]:
def _get_conn_params(self, conn: Connection) -> dict[str, Any]:
"""Fetch connection params as a dict and merge it with hook parameters."""
conn = self.get_connection(self.slack_conn_id)
if not conn.password:
raise AirflowNotFoundException(
f"Connection ID {self.slack_conn_id!r} does not contain password (Slack API Token)."
Expand Down Expand Up @@ -186,6 +203,24 @@ def call(self, api_method: str, **kwargs) -> SlackResponse:
"""
return self.client.api_call(api_method, **kwargs)

async def async_call(self, api_method: str, **kwargs) -> AsyncSlackResponse:
"""
Call Slack WebClient `AsyncWebClient.api_call` with given arguments.

:param api_method: The target Slack API method. e.g. 'chat.postMessage'. Required.
:param http_verb: HTTP Verb. Optional (defaults to 'POST')
:param files: Files to multipart upload. e.g. {imageORfile: file_objectORfile_path}
:param data: The body to attach to the request. If a dictionary is provided,
form-encoding will take place. Optional.
:param params: The URL parameters to append to the URL. Optional.
:param json: JSON for the body to attach to the request. Optional.
:return: The server's response to an HTTP request. Data from the response can be
accessed like a dict. If the response included 'next_cursor' it can be
iterated on to execute subsequent requests.
"""
client = await self.get_async_client()
return await client.api_call(api_method, **kwargs)

def send_file_v2(
self,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from slack_sdk.webhook.async_client import AsyncWebhookClient

from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.slack.utils import ConnectionExtraConfig
from airflow.providers.slack.utils import ConnectionExtraConfig, get_async_connection
from airflow.providers.slack.version_compat import BaseHook

if TYPE_CHECKING:
Expand Down Expand Up @@ -152,9 +152,8 @@ def client(self) -> WebhookClient:
"""Get the underlying slack_sdk.webhook.WebhookClient (cached)."""
return WebhookClient(**self._get_conn_params())

@cached_property
async def async_client(self) -> AsyncWebhookClient:
"""Get the underlying `slack_sdk.webhook.async_client.AsyncWebhookClient` (cached)."""
async def get_async_client(self) -> AsyncWebhookClient:
"""Get the underlying `slack_sdk.webhook.async_client.AsyncWebhookClient`."""
return AsyncWebhookClient(**await self._async_get_conn_params())

def get_conn(self) -> WebhookClient:
Expand All @@ -168,7 +167,7 @@ def _get_conn_params(self) -> dict[str, Any]:

async def _async_get_conn_params(self) -> dict[str, Any]:
"""Fetch connection params as a dict and merge it with hook parameters (async)."""
conn = await self.aget_connection(self.slack_webhook_conn_id)
conn = await get_async_connection(self.slack_webhook_conn_id)
return self._build_conn_params(conn)

def _build_conn_params(self, conn) -> dict[str, Any]:
Expand Down Expand Up @@ -251,7 +250,7 @@ async def async_send_dict(self, body: dict[str, Any] | str, *, headers: dict[str
:param headers: Request headers for this request.
"""
body = self._process_body(body)
async_client = await self.async_client
async_client = await self.get_async_client()
return await async_client.send_dict(body, headers=headers)

def send(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from airflow.providers.common.compat.notifier import BaseNotifier
from airflow.providers.slack.hooks.slack import SlackHook
from airflow.providers.slack.version_compat import AIRFLOW_V_3_1_PLUS

if TYPE_CHECKING:
from slack_sdk.http_retry import RetryHandler
Expand Down Expand Up @@ -71,8 +72,13 @@ def __init__(
retry_handlers: list[RetryHandler] | None = None,
unfurl_links: bool = True,
unfurl_media: bool = True,
**kwargs,
):
super().__init__()
if AIRFLOW_V_3_1_PLUS:
# Support for passing context was added in 3.1.0
super().__init__(**kwargs)
else:
super().__init__()
self.slack_conn_id = slack_conn_id
self.text = text
self.channel = channel
Expand Down Expand Up @@ -112,5 +118,19 @@ def notify(self, context):
}
self.hook.call("chat.postMessage", json=api_call_params)

async def async_notify(self, context):
"""Send a message to a Slack Channel (async)."""
api_call_params = {
"channel": self.channel,
"username": self.username,
"text": self.text,
"icon_url": self.icon_url,
"attachments": json.dumps(self.attachments),
"blocks": json.dumps(self.blocks),
"unfurl_links": self.unfurl_links,
"unfurl_media": self.unfurl_media,
}
await self.hook.async_call("chat.postMessage", json=api_call_params)


send_slack_notification = SlackNotifier
15 changes: 15 additions & 0 deletions providers/slack/src/airflow/providers/slack/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from collections.abc import Sequence
from typing import Any

from asgiref.sync import sync_to_async

from airflow.providers.slack.version_compat import BaseHook, Connection
from airflow.utils.types import NOTSET


Expand Down Expand Up @@ -120,3 +123,15 @@ def parse_filename(
if fallback:
return fallback, None
raise ex from None


async def get_async_connection(conn_id: str) -> Connection:
"""
Get an asynchronous Airflow connection that is backwards compatible.

:param conn_id: The provided connection ID.
:returns: Connection
"""
if hasattr(BaseHook, "aget_connection"):
return await BaseHook.aget_connection(conn_id=conn_id)
return await sync_to_async(BaseHook.get_connection)(conn_id=conn_id)
5 changes: 3 additions & 2 deletions providers/slack/src/airflow/providers/slack/version_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import BaseOperator
from airflow.sdk import BaseOperator, Connection
else:
from airflow.models import BaseOperator
from airflow.models import BaseOperator, Connection # type: ignore[assignment]

if AIRFLOW_V_3_1_PLUS:
from airflow.sdk import BaseHook
Expand All @@ -50,4 +50,5 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
"AIRFLOW_V_3_1_PLUS",
"BaseHook",
"BaseOperator",
"Connection",
]
57 changes: 49 additions & 8 deletions providers/slack/tests/unit/slack/hooks/test_slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,21 +105,24 @@ def make_429():
def test_get_token_from_connection(self, conn_id):
"""Test retrieve token from Slack API Connection ID."""
hook = SlackHook(slack_conn_id=SLACK_API_DEFAULT_CONN_ID)
assert hook._get_conn_params()["token"] == MOCK_SLACK_API_TOKEN
conn = hook.get_connection(hook.slack_conn_id)
assert hook._get_conn_params(conn)["token"] == MOCK_SLACK_API_TOKEN

def test_resolve_token(self):
"""Test that we only use token from Slack API Connection ID."""
with pytest.warns(UserWarning, match="Provide `token` as part of .* parameters is disallowed"):
hook = SlackHook(slack_conn_id=SLACK_API_DEFAULT_CONN_ID, token="foo-bar")
conn = hook.get_connection(hook.slack_conn_id)
assert "token" not in hook.extra_client_args
assert hook._get_conn_params()["token"] == MOCK_SLACK_API_TOKEN
assert hook._get_conn_params(conn)["token"] == MOCK_SLACK_API_TOKEN

def test_empty_password(self):
"""Test password field defined in the connection."""
hook = SlackHook(slack_conn_id="empty_slack_connection")
conn = hook.get_connection(hook.slack_conn_id)
error_message = r"Connection ID '.*' does not contain password \(Slack API Token\)\."
with pytest.raises(AirflowNotFoundException, match=error_message):
hook._get_conn_params()
hook._get_conn_params(conn)

@pytest.mark.parametrize(
"hook_config,conn_extra,expected",
Expand Down Expand Up @@ -228,8 +231,9 @@ def test_client_configuration(

with mock.patch.dict("os.environ", values={test_conn_env: test_conn.get_uri()}):
hook = SlackHook(slack_conn_id=test_conn.conn_id, **hook_config)
conn = hook.get_connection(hook.slack_conn_id)
expected["logger"] = hook.log
conn_params = hook._get_conn_params()
conn_params = hook._get_conn_params(conn)
assert conn_params == expected

client = hook.client
Expand Down Expand Up @@ -319,7 +323,8 @@ def test_hook_connection_failed(self, mocked_client, response_data):
def test_backcompat_prefix_works(self, uri, monkeypatch):
monkeypatch.setenv("AIRFLOW_CONN_MY_CONN", uri)
hook = SlackHook(slack_conn_id="my_conn")
params = hook._get_conn_params()
conn = hook.get_connection(hook.slack_conn_id)
params = hook._get_conn_params(conn)
assert params["token"] == "abc"
assert params["timeout"] == 123
assert params["base_url"] == "base_url"
Expand All @@ -328,8 +333,9 @@ def test_backcompat_prefix_works(self, uri, monkeypatch):
def test_backcompat_prefix_both_causes_warning(self, monkeypatch):
monkeypatch.setenv("AIRFLOW_CONN_MY_CONN", "a://:abc@?extra__slack__timeout=111&timeout=222")
hook = SlackHook(slack_conn_id="my_conn")
conn = hook.get_connection(hook.slack_conn_id)
with pytest.warns(Warning, match="Using value for `timeout`"):
params = hook._get_conn_params()
params = hook._get_conn_params(conn)
assert params["timeout"] == 222

def test_empty_string_ignored_prefixed(self, monkeypatch):
Expand All @@ -340,7 +346,8 @@ def test_empty_string_ignored_prefixed(self, monkeypatch):
),
)
hook = SlackHook(slack_conn_id="my_conn")
params = hook._get_conn_params()
conn = hook.get_connection(hook.slack_conn_id)
params = hook._get_conn_params(conn)
assert "proxy" not in params
assert "base_url" not in params

Expand All @@ -350,7 +357,8 @@ def test_empty_string_ignored_non_prefixed(self, monkeypatch):
json.dumps({"password": "hi", "extra": {"base_url": "", "proxy": ""}}),
)
hook = SlackHook(slack_conn_id="my_conn")
params = hook._get_conn_params()
conn = hook.get_connection(hook.slack_conn_id)
params = hook._get_conn_params(conn)
assert "proxy" not in params
assert "base_url" not in params

Expand Down Expand Up @@ -539,3 +547,36 @@ def test_send_file_v1_to_v2_multiple_channels(self, channels, expected_calls):
with mock.patch.object(SlackHook, "send_file_v2") as mocked_send_file_v2:
hook.send_file_v1_to_v2(channels=channels, content="Fake")
assert mocked_send_file_v2.call_count == expected_calls


class TestSlackHookAsync:
@pytest.fixture
def mock_get_conn(self):
with mock.patch(
"airflow.providers.slack.hooks.slack.get_async_connection", new_callable=mock.AsyncMock
) as m:
m.return_value = Connection(
conn_id=SLACK_API_DEFAULT_CONN_ID,
conn_type=CONN_TYPE,
password=MOCK_SLACK_API_TOKEN,
)
yield m

@pytest.mark.asyncio
@mock.patch("airflow.providers.slack.hooks.slack.AsyncWebClient")
async def test_get_async_client(self, mock_client, mock_get_conn):
"""Test get_async_client creates AsyncWebClient with correct params."""
hook = SlackHook(slack_conn_id=SLACK_API_DEFAULT_CONN_ID)
await hook.get_async_client()
mock_get_conn.assert_called()
mock_client.assert_called_once_with(token=MOCK_SLACK_API_TOKEN, logger=mock.ANY)

@pytest.mark.asyncio
@mock.patch("airflow.providers.slack.hooks.slack.AsyncWebClient.api_call", new_callable=mock.AsyncMock)
async def test_async_call(self, mock_api_call, mock_get_conn):
"""Test async_call is called correctly."""
hook = SlackHook(slack_conn_id=SLACK_API_DEFAULT_CONN_ID)
test_api_json = {"channel": "test_channel"}
await hook.async_call("chat.postMessage", json=test_api_json)
mock_get_conn.assert_called()
mock_api_call.assert_called_with("chat.postMessage", json=test_api_json)
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ async def test_async_client(self, mock_async_get_conn_params):
mock_async_get_conn_params.return_value = {"url": TEST_WEBHOOK_URL}

hook = SlackWebhookHook(slack_webhook_conn_id=TEST_CONN_ID)
client = await hook.async_client
client = await hook.get_async_client()

assert isinstance(client, AsyncWebhookClient)
assert client.url == TEST_WEBHOOK_URL
Expand Down
28 changes: 28 additions & 0 deletions providers/slack/tests/unit/slack/notifications/test_slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,31 @@ def test_slack_notifier_unfurl_options(self, mock_slack_hook, create_dag_without
"unfurl_media": False,
},
)

@pytest.mark.asyncio
@mock.patch("airflow.providers.slack.notifications.slack.SlackHook")
async def test_async_slack_notifier(self, mock_slack_hook):
mock_slack_hook.return_value.async_call = mock.AsyncMock()

notifier = send_slack_notification(
text="test",
unfurl_links=False,
unfurl_media=False,
)

await notifier.async_notify({})

mock_slack_hook.return_value.async_call.assert_called_once_with(
"chat.postMessage",
json={
"channel": "#general",
"username": "Airflow",
"text": "test",
"icon_url": "https://raw.githubusercontent.com/apache/airflow/main/airflow-core"
"/src/airflow/ui/public/pin_100.png",
"attachments": "[]",
"blocks": "[]",
"unfurl_links": False,
"unfurl_media": False,
},
)
Loading
Loading