Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 92 additions & 4 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,19 +1261,107 @@ def render(key: str, content: str) -> str:
return subject, html_content, html_content_err


def _get_email_address_list(addresses: str | Iterable[str]) -> list[str]:
"""
Return a list of email addresses from the provided input.

:param addresses: A string or iterable of strings containing email addresses.
:return: A list of email addresses.
:raises TypeError: If the input is not a string or iterable of strings.
"""
import collections.abc
import re

def _get_email_list_from_str(addresses: str) -> list[str]:
"""
Extract a list of email addresses from a string.

The string can contain multiple email addresses separated
by any of the following delimiters: ',' or ';'.

:param addresses: A string containing one or more email addresses.
:return: A list of email addresses.
"""
pattern = r"\s*[,;]\s*"
return re.split(pattern, addresses)

if isinstance(addresses, str):
return _get_email_list_from_str(addresses)
if isinstance(addresses, collections.abc.Iterable):
if not all(isinstance(item, str) for item in addresses):
raise TypeError("The items in your iterable must be strings.")
return list(addresses)
raise TypeError(f"Unexpected argument type: Received '{type(addresses).__name__}'.")


def _send_email(
to: list[str] | Iterable[str],
subject: str,
html_content: str,
files: list[str] | None = None,
dryrun: bool = False,
cc: str | Iterable[str] | None = None,
bcc: str | Iterable[str] | None = None,
mime_subtype: str = "mixed",
mime_charset: str = "utf-8",
conn_id: str | None = None,
custom_headers: dict[str, Any] | None = None,
**kwargs,
) -> None:
"""
Send an email using the backend specified in the *EMAIL_BACKEND* configuration option.

:param to: A list or iterable of email addresses to send the email to.
:param subject: The subject of the email.
:param html_content: The content of the email in HTML format.
:param files: A list of paths to files to attach to the email.
:param dryrun: If *True*, the email will not actually be sent. Default: *False*.
:param cc: A string or iterable of strings containing email addresses to send a copy of the email to.
:param bcc: A string or iterable of strings containing email addresses to send a
blind carbon copy of the email to.
:param mime_subtype: The subtype of the MIME message. Default: "mixed".
:param mime_charset: The charset of the email. Default: "utf-8".
:param conn_id: The connection ID to use for the backend. If not provided, the default connection
specified in the *EMAIL_CONN_ID* configuration option will be used.
:param custom_headers: A dictionary of additional headers to add to the MIME message.
No validations are run on these values, and they should be able to be encoded.
:param kwargs: Additional keyword arguments to pass to the backend.
"""
backend = conf.getimport("email", "EMAIL_BACKEND")
backend_conn_id = conn_id or conf.get("email", "EMAIL_CONN_ID")
from_email = conf.get("email", "from_email", fallback=None)

to_list = _get_email_address_list(to)
to_comma_separated = ", ".join(to_list)

return backend(
to_comma_separated,
subject,
html_content,
files=files,
dryrun=dryrun,
cc=cc,
bcc=bcc,
mime_subtype=mime_subtype,
mime_charset=mime_charset,
conn_id=backend_conn_id,
from_email=from_email,
custom_headers=custom_headers,
**kwargs,
)


def _send_task_error_email(
to: Iterable[str],
ti: RuntimeTaskInstance,
exception: BaseException | str | None,
log: Logger,
) -> None:
from airflow.utils.email import send_email

subject, content, err = _get_email_subject_content(task_instance=ti, exception=exception, log=log)
try:
send_email(to, subject, content)
_send_email(to, subject, content)
except Exception:
send_email(to, subject, err)
_send_email(to, subject, err)


def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):
Expand Down
38 changes: 38 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1986,6 +1986,44 @@ def test_xcom_clearing_without_keys_to_clear(self, create_runtime_ti, mock_super

mock_delete.assert_not_called()

def test_send_email_on_failure(self, create_runtime_ti, mock_supervisor_comms):
"""Test that _send_email is called when task fails with email_on_failure=True."""

class FailingOperator(BaseOperator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.email_on_failure = True
self.email = ["test@example.com"]
self.max_retries = 0 # No retries, fail immediately

def execute(self, context):
raise AirflowException("Task failed")

task = FailingOperator(task_id="failing_task")
runtime_ti = create_runtime_ti(task=task, dag_id="test_email_dag")
runtime_ti._ti_context_from_server = TIRunContext(
dag_run=runtime_ti._ti_context_from_server.dag_run,
task_reschedule_count=0,
max_tries=1,
should_retry=False,
)

with mock.patch("airflow.sdk.execution_time.task_runner._send_email") as mock_send_email:
state, error, _ = run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
finalize(
runtime_ti,
state,
context=runtime_ti.get_template_context(),
log=mock.MagicMock(),
error=error,
)

mock_send_email.assert_called_once()
args, kwargs = mock_send_email.call_args
assert args[0] == ["test@example.com"]
assert args[1]
assert args[2]


class TestXComAfterTaskExecution:
@pytest.mark.parametrize(
Expand Down