Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion providers/discord/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ requires-python = ">=3.10"
# After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
dependencies = [
"apache-airflow>=2.10.0",
"apache-airflow-providers-common-compat>=1.8.0",
"apache-airflow-providers-common-compat>=1.8.0", # use next version
"apache-airflow-providers-http",
]

Expand Down
196 changes: 153 additions & 43 deletions providers/discord/src/airflow/providers/discord/hooks/discord_webhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,70 @@

import json
import re
from typing import Any
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowException
from airflow.providers.http.hooks.http import HttpHook
import aiohttp

from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.http.hooks.http import HttpAsyncHook, HttpHook

if TYPE_CHECKING:
from airflow.providers.common.compat.sdk import Connection


class DiscordCommonHandler:
"""Contains the common functionality."""

def get_webhook_endpoint(self, conn: Connection | None, webhook_endpoint: str | None) -> str:
"""
Return the default webhook endpoint or override if a webhook_endpoint is manually supplied.

:param conn: Airflow Discord connection
:param webhook_endpoint: The manually provided webhook endpoint
:return: Webhook endpoint (str) to use
"""
if webhook_endpoint:
endpoint = webhook_endpoint
elif conn:
extra = conn.extra_dejson
endpoint = extra.get("webhook_endpoint", "")
else:
raise ValueError(
"Cannot get webhook endpoint: No valid Discord webhook endpoint or http_conn_id supplied."
)

# make sure endpoint matches the expected Discord webhook format
if not re.fullmatch("webhooks/[0-9]+/[a-zA-Z0-9_-]+", endpoint):
raise ValueError(
'Expected Discord webhook endpoint in the form of "webhooks/{webhook.id}/{webhook.token}".'
)

return endpoint

def build_discord_payload(
self, *, tts: bool, message: str, username: str | None, avatar_url: str | None
) -> str:
"""
Build a valid Discord JSON payload.

:param tts: Is a text-to-speech message
:param message: The message you want to send to your Discord channel
(max 2000 characters)
:param username: Override the default username of the webhook
:param avatar_url: Override the default avatar of the webhook
:return: Discord payload (str) to send
"""
if len(message) > 2000:
raise ValueError("Discord message length must be 2000 or fewer characters.")
payload: dict[str, Any] = {
"content": message,
"tts": tts,
}
if username:
payload["username"] = username
if avatar_url:
payload["avatar_url"] = avatar_url
return json.dumps(payload)


class DiscordWebhookHook(HttpHook):
Expand Down Expand Up @@ -84,6 +144,7 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.handler = DiscordCommonHandler()
self.http_conn_id: Any = http_conn_id
self.webhook_endpoint = self._get_webhook_endpoint(http_conn_id, webhook_endpoint)
self.message = message
Expand All @@ -100,46 +161,10 @@ def _get_webhook_endpoint(self, http_conn_id: str | None, webhook_endpoint: str
:param webhook_endpoint: The manually provided webhook endpoint
:return: Webhook endpoint (str) to use
"""
if webhook_endpoint:
endpoint = webhook_endpoint
elif http_conn_id:
conn = None
if not webhook_endpoint and http_conn_id:
conn = self.get_connection(http_conn_id)
extra = conn.extra_dejson
endpoint = extra.get("webhook_endpoint", "")
else:
raise AirflowException(
"Cannot get webhook endpoint: No valid Discord webhook endpoint or http_conn_id supplied."
)

# make sure endpoint matches the expected Discord webhook format
if not re.fullmatch("webhooks/[0-9]+/[a-zA-Z0-9_-]+", endpoint):
raise AirflowException(
'Expected Discord webhook endpoint in the form of "webhooks/{webhook.id}/{webhook.token}".'
)

return endpoint

def _build_discord_payload(self) -> str:
"""
Combine all relevant parameters into a valid Discord JSON payload.

:return: Discord payload (str) to send
"""
payload: dict[str, Any] = {}

if self.username:
payload["username"] = self.username
if self.avatar_url:
payload["avatar_url"] = self.avatar_url

payload["tts"] = self.tts

if len(self.message) <= 2000:
payload["content"] = self.message
else:
raise AirflowException("Discord message length must be 2000 or fewer characters.")

return json.dumps(payload)
return self.handler.get_webhook_endpoint(conn, webhook_endpoint)

def execute(self) -> None:
"""Execute the Discord webhook call."""
Expand All @@ -148,11 +173,96 @@ def execute(self) -> None:
# we only need https proxy for Discord
proxies = {"https": self.proxy}

discord_payload = self._build_discord_payload()
discord_payload = self.handler.build_discord_payload(
tts=self.tts, message=self.message, username=self.username, avatar_url=self.avatar_url
)

self.run(
endpoint=self.webhook_endpoint,
data=discord_payload,
headers={"Content-type": "application/json"},
extra_options={"proxies": proxies},
)


class DiscordWebhookAsyncHook(HttpAsyncHook):
"""
This hook allows you to post messages to Discord using incoming webhooks using async HTTP.

Takes a Discord connection ID with a default relative webhook endpoint. The
default endpoint can be overridden using the webhook_endpoint parameter
(https://discordapp.com/developers/docs/resources/webhook).

Each Discord webhook can be pre-configured to use a specific username and
avatar_url. You can override these defaults in this hook.

:param http_conn_id: Http connection ID with host as "https://discord.com/api/" and
default webhook endpoint in the extra field in the form of
{"webhook_endpoint": "webhooks/{webhook.id}/{webhook.token}"}
:param webhook_endpoint: Discord webhook endpoint in the form of
"webhooks/{webhook.id}/{webhook.token}"
:param message: The message you want to send to your Discord channel
(max 2000 characters)
:param username: Override the default username of the webhook
:param avatar_url: Override the default avatar of the webhook
:param tts: Is a text-to-speech message
:param proxy: Proxy to use to make the Discord webhook call
"""

default_headers = {
"Content-Type": "application/json",
}
conn_name_attr = "http_conn_id"
default_conn_name = "discord_default"
conn_type = "discord"
hook_name = "Async Discord"

def __init__(
self,
*,
http_conn_id: str = "",
webhook_endpoint: str | None = None,
message: str = "",
username: str | None = None,
avatar_url: str | None = None,
tts: bool = False,
proxy: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.http_conn_id = http_conn_id
self.webhook_endpoint = webhook_endpoint
self.message = message
self.username = username
self.avatar_url = avatar_url
self.tts = tts
self.proxy = proxy
self.handler = DiscordCommonHandler()

async def _get_webhook_endpoint(self) -> str:
"""
Return the default webhook endpoint or override if a webhook_endpoint is manually supplied.

:param http_conn_id: The provided connection ID
:param webhook_endpoint: The manually provided webhook endpoint
:return: Webhook endpoint (str) to use
"""
conn = None
if not self.webhook_endpoint and self.http_conn_id:
conn = await get_async_connection(self.http_conn_id)
return self.handler.get_webhook_endpoint(conn, self.webhook_endpoint)

async def execute(self) -> None:
"""Execute the Discord webhook call."""
webhook_endpoint = await self._get_webhook_endpoint()
discord_payload = self.handler.build_discord_payload(
tts=self.tts, message=self.message, username=self.username, avatar_url=self.avatar_url
)

async with aiohttp.ClientSession(proxy=self.proxy) as session:
await super().run(
session=session,
endpoint=webhook_endpoint,
data=discord_payload,
headers=self.default_headers,
)
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from functools import cached_property

from airflow.providers.common.compat.notifier import BaseNotifier
from airflow.providers.discord.hooks.discord_webhook import DiscordWebhookHook
from airflow.providers.discord.hooks.discord_webhook import DiscordWebhookAsyncHook, DiscordWebhookHook
from airflow.providers.discord.version_compat import AIRFLOW_V_3_1_PLUS

ICON_URL: str = (
"https://raw.githubusercontent.com/apache/airflow/main/airflow-core/src/airflow/ui/public/pin_100.png"
Expand Down Expand Up @@ -50,8 +51,13 @@ def __init__(
username: str = "Airflow",
avatar_url: str = ICON_URL,
tts: bool = False,
**kwargs,
):
super().__init__()
if AIRFLOW_V_3_1_PLUS:
# Support for passing context was added in 3.1.0
super().__init__(**kwargs)
else:
super().__init__()
self.discord_conn_id = discord_conn_id
self.text = text
self.username = username
Expand All @@ -66,11 +72,36 @@ def hook(self) -> DiscordWebhookHook:
"""Discord Webhook Hook."""
return DiscordWebhookHook(http_conn_id=self.discord_conn_id)

@cached_property
def hook_async(self) -> DiscordWebhookAsyncHook:
"""Discord Webhook Async Hook."""
return DiscordWebhookAsyncHook(
http_conn_id=self.discord_conn_id,
message=self.text,
username=self.username,
avatar_url=self.avatar_url,
tts=self.tts,
)

def notify(self, context):
"""Send a message to a Discord channel."""
"""
Send a message to a Discord channel.

:param context: the context object
:return: None
"""
self.hook.username = self.username
self.hook.message = self.text
self.hook.avatar_url = self.avatar_url
self.hook.tts = self.tts

self.hook.execute()

async def async_notify(self, context) -> None:
"""
Send a message to a Discord channel using async HTTP.

:param context: the context object
:return: None
"""
await self.hook_async.execute()
Original file line number Diff line number Diff line change
Expand Up @@ -93,5 +93,10 @@ def hook(self) -> DiscordWebhookHook:
return hook

def execute(self, context: Context) -> None:
"""Call the DiscordWebhookHook to post a message."""
"""
Call the DiscordWebhookHook to post a message.

:param context: the context object
:return: None
"""
self.hook.execute()
42 changes: 42 additions & 0 deletions providers/discord/src/airflow/providers/discord/version_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# NOTE! THIS FILE IS COPIED MANUALLY IN OTHER PROVIDERS DELIBERATELY TO AVOID ADDING UNNECESSARY
# DEPENDENCIES BETWEEN PROVIDERS. IF YOU WANT TO ADD CONDITIONAL CODE IN YOUR PROVIDER THAT DEPENDS
# ON AIRFLOW VERSION, PLEASE COPY THIS FILE TO THE ROOT PACKAGE OF YOUR PROVIDER AND IMPORT
# THOSE CONSTANTS FROM IT RATHER THAN IMPORTING THEM FROM ANOTHER PROVIDER OR TEST CODE
#
from __future__ import annotations


def get_base_airflow_version_tuple() -> tuple[int, int, int]:
from packaging.version import Version

from airflow import __version__

airflow_version = Version(__version__)
return airflow_version.major, airflow_version.minor, airflow_version.micro


AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0)
AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)


__all__ = [
"AIRFLOW_V_3_0_PLUS",
"AIRFLOW_V_3_1_PLUS",
]
Loading