diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 02d5eb2399677..d54c600ff0436 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -1261,19 +1261,107 @@ def render(key: str, content: str) -> str: return subject, html_content, html_content_err +def _get_email_address_list(addresses: str | Iterable[str]) -> list[str]: + """ + Return a list of email addresses from the provided input. + + :param addresses: A string or iterable of strings containing email addresses. + :return: A list of email addresses. + :raises TypeError: If the input is not a string or iterable of strings. + """ + import collections.abc + import re + + def _get_email_list_from_str(addresses: str) -> list[str]: + """ + Extract a list of email addresses from a string. + + The string can contain multiple email addresses separated + by any of the following delimiters: ',' or ';'. + + :param addresses: A string containing one or more email addresses. + :return: A list of email addresses. + """ + pattern = r"\s*[,;]\s*" + return re.split(pattern, addresses) + + if isinstance(addresses, str): + return _get_email_list_from_str(addresses) + if isinstance(addresses, collections.abc.Iterable): + if not all(isinstance(item, str) for item in addresses): + raise TypeError("The items in your iterable must be strings.") + return list(addresses) + raise TypeError(f"Unexpected argument type: Received '{type(addresses).__name__}'.") + + +def _send_email( + to: list[str] | Iterable[str], + subject: str, + html_content: str, + files: list[str] | None = None, + dryrun: bool = False, + cc: str | Iterable[str] | None = None, + bcc: str | Iterable[str] | None = None, + mime_subtype: str = "mixed", + mime_charset: str = "utf-8", + conn_id: str | None = None, + custom_headers: dict[str, Any] | None = None, + **kwargs, +) -> None: + """ + Send an email using the backend specified in the *EMAIL_BACKEND* configuration option. + + :param to: A list or iterable of email addresses to send the email to. + :param subject: The subject of the email. + :param html_content: The content of the email in HTML format. + :param files: A list of paths to files to attach to the email. + :param dryrun: If *True*, the email will not actually be sent. Default: *False*. + :param cc: A string or iterable of strings containing email addresses to send a copy of the email to. + :param bcc: A string or iterable of strings containing email addresses to send a + blind carbon copy of the email to. + :param mime_subtype: The subtype of the MIME message. Default: "mixed". + :param mime_charset: The charset of the email. Default: "utf-8". + :param conn_id: The connection ID to use for the backend. If not provided, the default connection + specified in the *EMAIL_CONN_ID* configuration option will be used. + :param custom_headers: A dictionary of additional headers to add to the MIME message. + No validations are run on these values, and they should be able to be encoded. + :param kwargs: Additional keyword arguments to pass to the backend. + """ + backend = conf.getimport("email", "EMAIL_BACKEND") + backend_conn_id = conn_id or conf.get("email", "EMAIL_CONN_ID") + from_email = conf.get("email", "from_email", fallback=None) + + to_list = _get_email_address_list(to) + to_comma_separated = ", ".join(to_list) + + return backend( + to_comma_separated, + subject, + html_content, + files=files, + dryrun=dryrun, + cc=cc, + bcc=bcc, + mime_subtype=mime_subtype, + mime_charset=mime_charset, + conn_id=backend_conn_id, + from_email=from_email, + custom_headers=custom_headers, + **kwargs, + ) + + def _send_task_error_email( to: Iterable[str], ti: RuntimeTaskInstance, exception: BaseException | str | None, log: Logger, ) -> None: - from airflow.utils.email import send_email - subject, content, err = _get_email_subject_content(task_instance=ti, exception=exception, log=log) try: - send_email(to, subject, content) + _send_email(to, subject, content) except Exception: - send_email(to, subject, err) + _send_email(to, subject, err) def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger): diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 65c236a051a2f..7a3745fbe4be2 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -1986,6 +1986,44 @@ def test_xcom_clearing_without_keys_to_clear(self, create_runtime_ti, mock_super mock_delete.assert_not_called() + def test_send_email_on_failure(self, create_runtime_ti, mock_supervisor_comms): + """Test that _send_email is called when task fails with email_on_failure=True.""" + + class FailingOperator(BaseOperator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.email_on_failure = True + self.email = ["test@example.com"] + self.max_retries = 0 # No retries, fail immediately + + def execute(self, context): + raise AirflowException("Task failed") + + task = FailingOperator(task_id="failing_task") + runtime_ti = create_runtime_ti(task=task, dag_id="test_email_dag") + runtime_ti._ti_context_from_server = TIRunContext( + dag_run=runtime_ti._ti_context_from_server.dag_run, + task_reschedule_count=0, + max_tries=1, + should_retry=False, + ) + + with mock.patch("airflow.sdk.execution_time.task_runner._send_email") as mock_send_email: + state, error, _ = run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock()) + finalize( + runtime_ti, + state, + context=runtime_ti.get_template_context(), + log=mock.MagicMock(), + error=error, + ) + + mock_send_email.assert_called_once() + args, kwargs = mock_send_email.call_args + assert args[0] == ["test@example.com"] + assert args[1] + assert args[2] + class TestXComAfterTaskExecution: @pytest.mark.parametrize(