From 53da449fb61c548e07617ad28a58b61dc322735a Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Thu, 28 Aug 2025 11:56:16 -0700 Subject: [PATCH 01/10] Add Async support for SMTP Notifier --- providers/smtp/pyproject.toml | 1 + .../src/airflow/providers/smtp/hooks/smtp.py | 320 ++++++++++++++---- .../providers/smtp/notifications/smtp.py | 92 +++-- .../smtp/tests/unit/smtp/hooks/test_smtp.py | 209 +++++++++++- .../unit/smtp/notifications/test_smtp.py | 186 ++++++++++ 5 files changed, 700 insertions(+), 108 deletions(-) diff --git a/providers/smtp/pyproject.toml b/providers/smtp/pyproject.toml index ac31d92cb0f95..98bebb17ff874 100644 --- a/providers/smtp/pyproject.toml +++ b/providers/smtp/pyproject.toml @@ -59,6 +59,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.10.0", "apache-airflow-providers-common-compat>=1.6.1", + "aiosmtplib>=0.1.6", ] [dependency-groups] diff --git a/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py b/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py index 53982137dab51..e9dd2347b9375 100644 --- a/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py +++ b/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py @@ -18,7 +18,8 @@ """ Search in emails for a specific attachment and also to download it. -It uses the smtplib library that is already integrated in python 3. +It uses the smtplib library that is already integrated in python 3 for +synchronous connections or aiosmtplib for async connections. """ from __future__ import annotations @@ -33,9 +34,12 @@ from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from email.utils import formatdate +from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Any, cast +import aiosmtplib + from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.providers.smtp.version_compat import BaseHook @@ -71,15 +75,37 @@ def __init__(self, smtp_conn_id: str = default_conn_name, auth_type: str = "basi super().__init__() self.smtp_conn_id = smtp_conn_id self.smtp_connection: Connection | None = None - self.smtp_client: smtplib.SMTP_SSL | smtplib.SMTP | None = None + self._smtp_client: smtplib.SMTP_SSL | smtplib.SMTP | aiosmtplib.SMTP | None = None self._auth_type = auth_type self._access_token: str | None = None def __enter__(self) -> SmtpHook: return self.get_conn() + async def __aenter__(self) -> SmtpHook: + return await self.aget_conn() + def __exit__(self, exc_type, exc_val, exc_tb): - self.smtp_client.close() + self._smtp_client.close() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self._smtp_client: + await self._smtp_client.quit() + + def _setup_oauth2(self) -> tuple[str, str]: + """ + Set up OAuth2 credentials and return token and user identity. + + :return: Tuple of (user_identity, access_token) + """ + if not self._access_token: + self._access_token = self._get_oauth2_token() + + user_identity = self.smtp_user or self.from_email + if user_identity is None: + raise AirflowException("smtp_user or from_email must be set for OAuth2 authentication") + + return user_identity, self._access_token def get_conn(self) -> SmtpHook: """ @@ -90,7 +116,7 @@ def get_conn(self) -> SmtpHook: :return: an authorized SmtpHook object. """ - if not self.smtp_client: + if not self._smtp_client: try: self.smtp_connection = self.get_connection(self.smtp_conn_id) except AirflowNotFoundException: @@ -98,13 +124,13 @@ def get_conn(self) -> SmtpHook: for attempt in range(1, self.smtp_retry_limit + 1): try: - self.smtp_client = self._build_client() + self._smtp_client = self._build_client() except smtplib.SMTPServerDisconnected: if attempt == self.smtp_retry_limit: raise AirflowException("Unable to connect to smtp server") else: if self.smtp_starttls: - self.smtp_client.starttls() + self._smtp_client.starttls() # choose auth if self._auth_type == "oauth2": @@ -115,41 +141,94 @@ def get_conn(self) -> SmtpHook: raise AirflowException( "smtp_user or from_email must be set for OAuth2 authentication" ) - self.smtp_client.auth( + self._smtp_client.auth( "XOAUTH2", lambda _=None: build_xoauth2_string(user_identity, self._access_token), ) elif self.smtp_user and self.smtp_password: - self.smtp_client.login(self.smtp_user, self.smtp_password) + self._smtp_client.login(self.smtp_user, self.smtp_password) break return self - def _build_client(self) -> smtplib.SMTP_SSL | smtplib.SMTP: - SMTP: type[smtplib.SMTP_SSL] | type[smtplib.SMTP] - if self.use_ssl: - SMTP = smtplib.SMTP_SSL - else: - SMTP = smtplib.SMTP + async def aget_conn(self) -> SmtpHook: + """ + Login to the smtp server (async). + + .. note:: Please call this Hook as context manager via `with` + to automatically open and close the connection to the smtp server. + + :return: an authorized SmtpHook object. + """ + if not self._smtp_client: + try: + self.smtp_connection = await self.aget_connection(self.smtp_conn_id) + except AirflowNotFoundException: + raise AirflowException("SMTP connection is not found.") + + for attempt in range(1, self.smtp_retry_limit + 1): + try: + async_client = await self._abuild_client() + self._smtp_client = async_client + except aiosmtplib.errors.SMTPServerDisconnected: + if attempt == self.smtp_retry_limit: + raise AirflowException("Unable to connect to smtp server") + else: + if self.smtp_starttls: + await async_client.starttls() + + if self.smtp_user and self.smtp_password: + await async_client.auth_login(self.smtp_user, self.smtp_password) + break + + return self + + def _build_client_kwargs(self, is_async: bool) -> dict[str, Any]: + """Build kwargs appropriate for sync or async SMTP client.""" + valid_contexts = (None, "default", "none") # Values accepted for ssl_context configuration + + kwargs: dict[str, Any] = {"timeout": self.timeout} - smtp_kwargs: dict[str, Any] = {"host": self.host} if self.port: - smtp_kwargs["port"] = self.port - smtp_kwargs["timeout"] = self.timeout - - if self.use_ssl: - ssl_context_string = self.ssl_context - if ssl_context_string is None or ssl_context_string == "default": - ssl_context = ssl.create_default_context() - elif ssl_context_string == "none": - ssl_context = None - else: - raise RuntimeError( - f"The connection extra field `ssl_context` must " - f"be set to 'default' or 'none' but it is set to '{ssl_context_string}'." - ) - smtp_kwargs["context"] = ssl_context - return SMTP(**smtp_kwargs) + kwargs["port"] = self.port + + if is_async: + kwargs["hostname"] = self.host + kwargs["use_tls"] = self.use_ssl + kwargs["start_tls"] = self.smtp_starttls if not self.use_ssl else None + else: + kwargs["host"] = self.host + if self.use_ssl: + if self.ssl_context not in valid_contexts: + raise RuntimeError( + f"The connection extra field `ssl_context` must " + f"be set to 'default' or 'none' but it is set to '{self.ssl_context}'." + ) + kwargs["context"] = None if self.ssl_context == "none" else ssl.create_default_context() + + return kwargs + + def _build_client(self) -> smtplib.SMTP_SSL | smtplib.SMTP: + """Build a synchronous SMTP client.""" + client: type[smtplib.SMTP_SSL] | type[smtplib.SMTP] = ( + smtplib.SMTP_SSL if self.use_ssl else smtplib.SMTP + ) + return client(**self._build_client_kwargs(is_async=False)) + + async def _abuild_client(self) -> aiosmtplib.SMTP: + """ + Build an asynchronous SMTP client. + + Unlike the synchronous client (which connects automatically when instantiated), + aiosmtplib requires explicit connect() and ehlo() calls. We handle those here + to keep the async implementation details contained and make aget_conn behavior + match get_conn. + """ + async_client = aiosmtplib.SMTP(**self._build_client_kwargs(is_async=True)) + await async_client.connect() + await async_client.ehlo() + + return async_client @classmethod def get_connection_form_widgets(cls) -> dict[str, Any]: @@ -198,15 +277,82 @@ def get_connection_form_widgets(cls) -> dict[str, Any]: def test_connection(self) -> tuple[bool, str]: """Test SMTP connectivity from UI.""" try: - smtp_client = self.get_conn().smtp_client + smtp_client = self.get_conn()._smtp_client if smtp_client: - status = smtp_client.noop()[0] + status = smtp_client.noop() if status == 250: return True, "Connection successfully tested" except Exception as e: return False, str(e) return False, "Failed to establish connection" + async def atest_connection(self) -> tuple[bool, str]: + """Test SMTP connectivity (async).""" + try: + smtp_client = (await self.aget_conn())._smtp_client + if smtp_client is None: + return False, "SMTP client not initialized" + + if isinstance(smtp_client, aiosmtplib.SMTP): + async_client: aiosmtplib.SMTP = smtp_client + response = await async_client.noop() + if response.code == 250: + return True, "Async connection successfully tested" + return False, f"Connection test failed with code: {response.code}" + + except Exception as e: + return False, str(e) + return False, "Failed to establish connection" + + def _build_message( + self, + to: str | Iterable[str], + subject: str | None = None, + html_content: str | None = None, + from_email: str | None = None, + files: list[str] | None = None, + cc: str | Iterable[str] | None = None, + bcc: str | Iterable[str] | None = None, + mime_subtype: str = "mixed", + mime_charset: str = "utf-8", + custom_headers: dict[str, Any] | None = None, + ) -> dict[str, Any]: + if not self._smtp_client: + raise AirflowException("The 'smtp_client' should be initialized before!") + + from_email = from_email or self.from_email + if not from_email: + raise AirflowException("You should provide `from_email` or define it in the connection.") + + if not subject: + if self.subject_template is None: + raise AirflowException( + "You should provide `subject` or define `subject_template` in the connection." + ) + subject = self._read_template(self.subject_template) + + if not html_content: + if self.html_content_template is None: + raise AirflowException( + "You should provide `html_content` or define `html_content_template` in the connection." + ) + html_content = self._read_template(self.html_content_template) + + mime_msg, recipients = self._build_mime_message( + mail_from=from_email, + to=to, + subject=subject, + html_content=html_content, + files=files, + cc=cc, + bcc=bcc, + mime_subtype=mime_subtype, + mime_charset=mime_charset, + custom_headers=custom_headers, + ) + + return {"mime_msg": mime_msg, "recipients": recipients, "from_email": from_email} + def send_email_smtp( self, *, @@ -246,27 +392,9 @@ def send_email_smtp( 'test@example.com', 'foo', 'Foo bar', ['/dev/null'], dryrun=True ) """ - if not self.smtp_client: - raise AirflowException("The 'smtp_client' should be initialized before!") - from_email = from_email or self.from_email - if not from_email: - raise AirflowException("You should provide `from_email` or define it in the connection.") - if not subject: - if self.subject_template is None: - raise AirflowException( - "You should provide `subject` or define `subject_template` in the connection." - ) - subject = self._read_template(self.subject_template) - if not html_content: - if self.html_content_template is None: - raise AirflowException( - "You should provide `html_content` or define `html_content_template` in the connection." - ) - html_content = self._read_template(self.html_content_template) - - mime_msg, recipients = self._build_mime_message( - mail_from=from_email, + msg = self._build_message( to=to, + from_email=from_email, subject=subject, html_content=html_content, files=files, @@ -277,16 +405,92 @@ def send_email_smtp( custom_headers=custom_headers, ) if not dryrun: + if self._smtp_client is None: + raise AirflowException("The SMTP client is not initialized") + # Casting here to make MyPy happy. + smtp_client = cast("smtplib.SMTP_SSL | smtplib.SMTP", self._smtp_client) + for attempt in range(1, self.smtp_retry_limit + 1): try: - self.smtp_client.sendmail( - from_addr=from_email, to_addrs=recipients, msg=mime_msg.as_string() + smtp_client.sendmail( + from_addr=msg["from_email"], + to_addrs=msg["recipients"], + msg=msg["mime_msg"].as_string(), ) - except smtplib.SMTPServerDisconnected as e: + break + except Exception as e: if attempt == self.smtp_retry_limit: raise e - else: + + async def asend_email_smtp( + self, + *, + to: str | Iterable[str], + subject: str | None = None, + html_content: str | None = None, + from_email: str | None = None, + files: list[str] | None = None, + dryrun: bool = False, + cc: str | Iterable[str] | None = None, + bcc: str | Iterable[str] | None = None, + mime_subtype: str = "mixed", + mime_charset: str = "utf-8", + custom_headers: dict[str, Any] | None = None, + **kwargs, + ) -> None: + """ + Send an email with html content. + + :param to: Recipient email address or list of addresses. + :param subject: Email subject. If it's None, the hook will check if there is a path to a subject + file provided in the connection, and raises an exception if not. + :param html_content: Email body in HTML format. If it's None, the hook will check if there is a path + to a html content file provided in the connection, and raises an exception if not. + :param from_email: Sender email address. If it's None, the hook will check if there is an email + provided in the connection, and raises an exception if not. + :param files: List of file paths to attach to the email. + :param dryrun: If True, the email will not be sent, but all other actions will be performed. + :param cc: Carbon copy recipient email address or list of addresses. + :param bcc: Blind carbon copy recipient email address or list of addresses. + :param mime_subtype: MIME subtype of the email. + :param mime_charset: MIME charset of the email. + :param custom_headers: Dictionary of custom headers to include in the email. + :param kwargs: Additional keyword arguments. + + >>> send_email_smtp( + 'test@example.com', 'foo', 'Foo bar', ['/dev/null'], dryrun=True + ) + """ + msg = self._build_message( + to=to, + subject=subject, + html_content=html_content, + from_email=from_email, + files=files, + cc=cc, + bcc=bcc, + mime_subtype=mime_subtype, + mime_charset=mime_charset, + custom_headers=custom_headers, + ) + if self._smtp_client is None: + raise AirflowException("The SMTP client is not initialized") + # Casting here to make MyPy happy. + smtp_client = cast("aiosmtplib.SMTP", self._smtp_client) + + if not dryrun: + for attempt in range(1, self.smtp_retry_limit + 1): + try: + # The async version of sendmail only supports positional arguments for some reason. + await smtp_client.sendmail( + msg["from_email"], + msg["recipients"], + msg["mime_msg"].as_string(), + ) break + except Exception as e: + if attempt == self.smtp_retry_limit: + raise e def _build_mime_message( self, @@ -420,7 +624,7 @@ def _get_oauth2_token(self) -> str: "auth_type='oauth2' but neither 'access_token' nor client credentials supplied in connection extra." ) - @property + @cached_property def conn(self) -> Connection: if not self.smtp_connection: raise AirflowException("The smtp connection should be loaded before!") diff --git a/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py b/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py index 49d463f8819e5..26e57311cb45c 100644 --- a/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py +++ b/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py @@ -20,11 +20,14 @@ from collections.abc import Iterable from functools import cached_property from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any from airflow.providers.common.compat.notifier import BaseNotifier from airflow.providers.smtp.hooks.smtp import SmtpHook +if TYPE_CHECKING: + from airflow.sdk import Context + class SmtpNotifier(BaseNotifier): """ @@ -80,8 +83,9 @@ def __init__( auth_type: str = "basic", *, template: str | None = None, + **kwargs, ): - super().__init__() + super().__init__(**kwargs) self.smtp_conn_id = smtp_conn_id self.from_email = from_email self.to = to @@ -106,39 +110,42 @@ def hook(self) -> SmtpHook: """Smtp Events Hook.""" return SmtpHook(smtp_conn_id=self.smtp_conn_id, auth_type=self.auth_type) - def notify(self, context): + def _build_email_content(self, smtp: SmtpHook, context: Context): + fields_to_re_render = [] + if self.from_email is None: + if smtp.from_email is not None: + self.from_email = smtp.from_email + else: + raise ValueError("You should provide `from_email` or define it in the connection") + fields_to_re_render.append("from_email") + if self.subject is None: + smtp_default_templated_subject_path: str + if smtp.subject_template: + smtp_default_templated_subject_path = smtp.subject_template + else: + smtp_default_templated_subject_path = ( + Path(__file__).parent / "templates" / "email_subject.jinja2" + ).as_posix() + self.subject = self._read_template(smtp_default_templated_subject_path) + fields_to_re_render.append("subject") + if self.html_content is None: + smtp_default_templated_html_content_path: str + if smtp.html_content_template: + smtp_default_templated_html_content_path = smtp.html_content_template + else: + smtp_default_templated_html_content_path = ( + Path(__file__).parent / "templates" / "email.html" + ).as_posix() + self.html_content = self._read_template(smtp_default_templated_html_content_path) + fields_to_re_render.append("html_content") + if fields_to_re_render: + jinja_env = self.get_template_env(dag=context["dag"]) + self._do_render_template_fields(self, fields_to_re_render, context, jinja_env, set()) + + def notify(self, context: Context): """Send a email via smtp server.""" with self.hook as smtp: - fields_to_re_render = [] - if self.from_email is None: - if smtp.from_email is not None: - self.from_email = smtp.from_email - else: - raise ValueError("You should provide `from_email` or define it in the connection") - fields_to_re_render.append("from_email") - if self.subject is None: - smtp_default_templated_subject_path: str - if smtp.subject_template: - smtp_default_templated_subject_path = smtp.subject_template - else: - smtp_default_templated_subject_path = ( - Path(__file__).parent / "templates" / "email_subject.jinja2" - ).as_posix() - self.subject = self._read_template(smtp_default_templated_subject_path) - fields_to_re_render.append("subject") - if self.html_content is None: - smtp_default_templated_html_content_path: str - if smtp.html_content_template: - smtp_default_templated_html_content_path = smtp.html_content_template - else: - smtp_default_templated_html_content_path = ( - Path(__file__).parent / "templates" / "email.html" - ).as_posix() - self.html_content = self._read_template(smtp_default_templated_html_content_path) - fields_to_re_render.append("html_content") - if fields_to_re_render: - jinja_env = self.get_template_env(dag=context["dag"]) - self._do_render_template_fields(self, fields_to_re_render, context, jinja_env, set()) + self._build_email_content(smtp, context) smtp.send_email_smtp( smtp_conn_id=self.smtp_conn_id, from_email=self.from_email, @@ -153,5 +160,24 @@ def notify(self, context): custom_headers=self.custom_headers, ) + async def async_notify(self, context: Context): + """Send a email via smtp server (async).""" + async with self.hook as smtp: + # TODO: Context is not yet available, uncomment this once implemented + # self._build_email_content(smtp, context) + await smtp.asend_email_smtp( + smtp_conn_id=self.smtp_conn_id, + from_email=self.from_email, + to=self.to, + subject=self.subject, + html_content=self.html_content, + files=self.files, + cc=self.cc, + bcc=self.bcc, + mime_subtype=self.mime_subtype, + mime_charset=self.mime_charset, + custom_headers=self.custom_headers, + ) + send_smtp_notification = SmtpNotifier diff --git a/providers/smtp/tests/unit/smtp/hooks/test_smtp.py b/providers/smtp/tests/unit/smtp/hooks/test_smtp.py index fc09281a306fc..9e8b54a8ecb09 100644 --- a/providers/smtp/tests/unit/smtp/hooks/test_smtp.py +++ b/providers/smtp/tests/unit/smtp/hooks/test_smtp.py @@ -22,8 +22,10 @@ import smtplib import tempfile from email.mime.application import MIMEApplication -from unittest.mock import Mock, call, patch +from unittest import mock +from unittest.mock import AsyncMock, Mock, call, patch +import aiosmtplib import pytest from airflow.exceptions import AirflowException @@ -85,28 +87,39 @@ def setup_connections(self, create_connection_without_db): ) ) + @pytest.mark.parametrize( + "conn_id, use_ssl, expected_port, create_context", + [ + pytest.param("smtp_default", True, 465, True, id="ssl-connection"), + pytest.param("smtp_nonssl", False, 587, False, id="non-ssl-connection"), + ], + ) @patch(smtplib_string) @patch("ssl.create_default_context") - def test_connect_and_disconnect(self, create_default_context, mock_smtplib): - mock_conn = _create_fake_smtp(mock_smtplib) + def test_connect_and_disconnect( + self, create_default_context, mock_smtplib, conn_id, use_ssl, expected_port, create_context + ): + """Test sync connection with different configurations.""" + mock_conn = _create_fake_smtp(mock_smtplib, use_ssl=use_ssl) - with SmtpHook(): + with SmtpHook(smtp_conn_id=conn_id): pass - assert create_default_context.called - mock_smtplib.SMTP_SSL.assert_called_once_with( - host="smtp_server_address", port=465, timeout=30, context=create_default_context.return_value - ) - mock_conn.login.assert_called_once_with("smtp_user", "smtp_password") - assert mock_conn.close.call_count == 1 - - @patch(smtplib_string) - def test_connect_and_disconnect_via_nonssl(self, mock_smtplib): - mock_conn = _create_fake_smtp(mock_smtplib, use_ssl=False) - with SmtpHook(smtp_conn_id="smtp_nonssl"): - pass + if create_context: + assert create_default_context.called + mock_smtplib.SMTP_SSL.assert_called_once_with( + host="smtp_server_address", + port=expected_port, + timeout=30, + context=create_default_context.return_value, + ) + else: + mock_smtplib.SMTP.assert_called_once_with( + host="smtp_server_address", + port=expected_port, + timeout=30, + ) - mock_smtplib.SMTP.assert_called_once_with(host="smtp_server_address", port=587, timeout=30) mock_conn.login.assert_called_once_with("smtp_user", "smtp_password") assert mock_conn.close.call_count == 1 @@ -432,3 +445,165 @@ def test_oauth2_missing_token_raises(self, mock_smtplib, create_connection_witho ) assert not mock_conn.auth.called + + +@pytest.mark.asyncio +class TestSmtpHookAsync: + """Tests for async functionality in SmtpHook.""" + + @pytest.fixture(autouse=True) + def setup_connections(self, create_connection_without_db): + create_connection_without_db( + Connection( + conn_id="smtp_default", + conn_type="smtp", + host="smtp_server_address", + login="smtp_user", + password="smtp_password", + port=465, + extra=json.dumps(dict(from_email="from", ssl_context="default")), + ) + ) + create_connection_without_db( + Connection( + conn_id="smtp_nonssl", + conn_type="smtp", + host="smtp_server_address", + login="smtp_user", + password="smtp_password", + port=587, + extra=json.dumps(dict(disable_ssl=True, from_email="from")), + ) + ) + + @pytest.fixture + def mock_smtp_client(self): + """Create a mock SMTP client with async capabilities.""" + mock_client = AsyncMock(spec=aiosmtplib.SMTP) + mock_client.starttls = AsyncMock() + mock_client.auth_login = AsyncMock() + mock_client.sendmail = AsyncMock() + mock_client.quit = AsyncMock() + return mock_client + + @pytest.fixture + def mock_smtp(self, mock_smtp_client): + """Set up the SMTP mock with context manager.""" + with mock.patch("airflow.providers.smtp.hooks.smtp.aiosmtplib.SMTP") as mock_smtp: + mock_smtp.return_value = mock_smtp_client + yield mock_smtp + + @pytest.fixture + def mock_get_connection(self): + """Mock the async connection retrieval.""" + with mock.patch("airflow.sdk.bases.hook.BaseHook.aget_connection") as mock_conn: + + async def async_get_connection(conn_id): + from airflow.sdk.definitions.connection import Connection + + return Connection.from_json(os.environ[f"AIRFLOW_CONN_{conn_id.upper()}"]) + + mock_conn.side_effect = async_get_connection + yield mock_conn + + @staticmethod + def _create_fake_async_smtp(mock_smtp): + mock_client = AsyncMock(spec=aiosmtplib.SMTP) + mock_client.starttls = AsyncMock() + mock_client.auth_login = AsyncMock() + mock_client.sendmail = AsyncMock() + mock_client.quit = AsyncMock() + mock_smtp.return_value = mock_client + return mock_client + + @pytest.mark.parametrize( + "conn_id, expected_port, expected_ssl", + [ + pytest.param("smtp_nonssl", 587, False, id="non-ssl-connection"), + pytest.param("smtp_default", 465, True, id="ssl-connection"), + ], + ) + async def test_async_connection( + self, mock_smtp, mock_smtp_client, mock_get_connection, conn_id, expected_port, expected_ssl + ): + """Test async connection with different configurations.""" + async with SmtpHook(smtp_conn_id=conn_id) as hook: + assert hook is not None + + mock_smtp.assert_called_once_with( + hostname="smtp_server_address", + port=expected_port, + timeout=30, + use_tls=expected_ssl, + start_tls=None if expected_ssl else True, + ) + + if expected_ssl: + assert mock_smtp_client.starttls.await_count == 1 + + assert mock_smtp_client.auth_login.await_count == 1 + mock_smtp_client.auth_login.assert_awaited_once_with("smtp_user", "smtp_password") + + @pytest.mark.asyncio + async def test_async_send_email(self, mock_smtp, mock_smtp_client, mock_get_connection): + """Test async email sending functionality.""" + async with SmtpHook() as hook: + await hook.asend_email_smtp( + to="to@example.com", + subject="test subject", + html_content="test content", + ) + + assert mock_smtp_client.sendmail.called + # The async version of sendmail only supports positional arguments + # for some reason, so we have to check these by positional args + call_args = mock_smtp_client.sendmail.await_args.args + assert call_args[0] == "from" # sender is first positional arg + assert call_args[1] == ["to@example.com"] # recipients is the second positional arg + assert "Subject: test subject" in call_args[2] # message is the third positional arg + + @pytest.mark.asyncio + async def test_async_send_email_with_retries(self, mock_smtp, mock_smtp_client, mock_get_connection): + """Test async email sending with connection retries.""" + mock_smtp_client.sendmail.side_effect = [ + aiosmtplib.errors.SMTPServerDisconnected("Server disconnected"), + aiosmtplib.errors.SMTPServerDisconnected("Server disconnected"), + None, # Success on third try + ] + + async with SmtpHook() as hook: + await hook.asend_email_smtp( + to="to@example.com", + subject="test subject", + html_content="test content", + ) + + assert mock_smtp_client.sendmail.await_count == 3 + + async def test_async_send_email_max_retries(self, mock_smtp, mock_smtp_client, mock_get_connection): + """Test async email sending with max retries exceeded.""" + mock_smtp_client.sendmail.side_effect = aiosmtplib.errors.SMTPServerDisconnected( + "Server disconnected" + ) + + with pytest.raises(aiosmtplib.errors.SMTPServerDisconnected): + async with SmtpHook() as hook: + await hook.asend_email_smtp( + to="to@example.com", + subject="test subject", + html_content="test content", + ) + + assert mock_smtp_client.sendmail.await_count == 5 # Default retry limit + + async def test_async_send_email_dryrun(self, mock_smtp, mock_smtp_client, mock_get_connection): + """Test async email sending in dryrun mode.""" + async with SmtpHook() as hook: + await hook.asend_email_smtp( + to="to@example.com", + subject="test subject", + html_content="test content", + dryrun=True, + ) + + mock_smtp_client.sendmail.assert_not_awaited() diff --git a/providers/smtp/tests/unit/smtp/notifications/test_smtp.py b/providers/smtp/tests/unit/smtp/notifications/test_smtp.py index 4dfffab36e03e..c1624586ad0de 100644 --- a/providers/smtp/tests/unit/smtp/notifications/test_smtp.py +++ b/providers/smtp/tests/unit/smtp/notifications/test_smtp.py @@ -19,6 +19,9 @@ import tempfile from unittest import mock +from unittest.mock import AsyncMock + +import pytest from airflow.providers.smtp.hooks.smtp import SmtpHook from airflow.providers.smtp.notifications.smtp import ( @@ -200,3 +203,186 @@ def test_notifier_oauth2_passes_auth_type(self, mock_smtphook_hook, create_dag_w smtp_conn_id="smtp_default", auth_type="oauth2", ) + + +class TestSmtpNotifierAsync: + @pytest.fixture + def mock_smtp_client(self): + """Create a mock SMTP object with async capabilities.""" + mock_smtp = AsyncMock() + mock_smtp.asend_email_smtp = AsyncMock() + return mock_smtp + + @pytest.fixture + def mock_smtp_hook(self, mock_smtp_client): + """Set up the SMTP hook with async context manager.""" + with mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") as mock_hook: + mock_hook.return_value.__aenter__ = AsyncMock(return_value=mock_smtp_client) + yield mock_hook + + @pytest.mark.asyncio + async def test_async_notifier(self, mock_smtp_hook, mock_smtp_client, create_dag_without_db): + notifier = SmtpNotifier( + from_email="test_sender@test.com", + to="test_reciver@test.com", + subject="subject", + html_content="body", + ) + await notifier.async_notify({"dag": create_dag_without_db("test_notifier")}) + + mock_smtp_client.asend_email_smtp.assert_called_once_with( + smtp_conn_id="smtp_default", + from_email="test_sender@test.com", + to="test_reciver@test.com", + subject="subject", + html_content="body", + files=None, + cc=None, + bcc=None, + mime_subtype="mixed", + mime_charset="utf-8", + custom_headers=None, + ) + + @pytest.mark.asyncio + async def test_async_notifier_with_notifier_class( + self, mock_smtp_hook, mock_smtp_client, create_dag_without_db + ): + notifier = SmtpNotifier( + from_email="test_sender@test.com", + to="test_reciver@test.com", + subject="subject", + html_content="body", + context={"dag": create_dag_without_db("test_notifier")}, + ) + + await notifier + + mock_smtp_client.asend_email_smtp.assert_called_once_with( + smtp_conn_id="smtp_default", + from_email="test_sender@test.com", + to="test_reciver@test.com", + subject="subject", + html_content="body", + files=None, + cc=None, + bcc=None, + mime_subtype="mixed", + mime_charset="utf-8", + custom_headers=None, + ) + + @pytest.mark.asyncio + async def test_async_notifier_templated(self, mock_smtp_hook, mock_smtp_client, create_dag_without_db): + notifier = SmtpNotifier( + from_email="test_sender@test.com {{dag.dag_id}}", + to="test_reciver@test.com {{dag.dag_id}}", + subject="subject {{dag.dag_id}}", + html_content="body {{dag.dag_id}}", + context={"dag": create_dag_without_db("test_notifier")}, + ) + + await notifier + + mock_smtp_client.asend_email_smtp.assert_called_once_with( + smtp_conn_id="smtp_default", + from_email="test_sender@test.com test_notifier", + to="test_reciver@test.com test_notifier", + subject="subject test_notifier", + html_content="body test_notifier", + files=None, + cc=None, + bcc=None, + mime_subtype="mixed", + mime_charset="utf-8", + custom_headers=None, + ) + + @pytest.mark.asyncio + async def test_async_notifier_with_defaults( + self, mock_smtp_hook, mock_smtp_client, create_dag_without_db, mock_task_instance + ): + mock_smtp_client.subject_template = None + mock_smtp_client.html_content_template = None + mock_smtp_client.from_email = None + + mock_ti = mock_task_instance( + dag_id="test_dag", + task_id="op", + run_id="test", + try_number=NUM_TRY, + max_tries=0, + state=None, + ) + + notifier = SmtpNotifier( + from_email="any email", + to="test_reciver@test.com", + context={"dag": create_dag_without_db("test_dag"), "ti": mock_ti}, + ) + + await notifier + + mock_smtp_client.asend_email_smtp.assert_called_once_with( + smtp_conn_id="smtp_default", + from_email="any email", + to="test_reciver@test.com", + subject="DAG test_dag - Task op - Run ID test in State None", + html_content=mock.ANY, + files=None, + cc=None, + bcc=None, + mime_subtype="mixed", + mime_charset="utf-8", + custom_headers=None, + ) + content = mock_smtp_client.asend_email_smtp.call_args.kwargs["html_content"] + assert f"{NUM_TRY} of 1" in content + + @pytest.mark.asyncio + async def test_async_notifier_with_nondefault_connection_extra( + self, mock_smtp_hook, mock_smtp_client, create_dag_without_db, mock_task_instance + ): + ti = mock_task_instance( + dag_id="test_dag", + task_id="op", + run_id="test_run", + try_number=NUM_TRY, + max_tries=0, + state=None, + ) + + with ( + tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_subject, + tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_content, + ): + f_subject.write("Task {{ ti.task_id }} failed") + f_subject.flush() + + f_content.write("Mock content goes here") + f_content.flush() + + mock_smtp_client.from_email = "{{ ti.task_id }}@test.com" + mock_smtp_client.subject_template = f_subject.name + mock_smtp_client.html_content_template = f_content.name + + notifier = SmtpNotifier( + to="test_reciver@test.com", + context={"dag": create_dag_without_db("test_dag"), "ti": ti}, + ) + + await notifier + + mock_smtp_client.asend_email_smtp.assert_called_once_with( + smtp_conn_id="smtp_default", + from_email="op@test.com", + to="test_reciver@test.com", + subject="Task op failed", + html_content="Mock content goes here", + files=None, + cc=None, + bcc=None, + mime_subtype="mixed", + mime_charset="utf-8", + custom_headers=None, + ) From 45c1aa5fa6bf474b27c7befe74ab95cabef1fe9e Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Fri, 5 Sep 2025 14:01:52 -0700 Subject: [PATCH 02/10] Fix tests and remove context form async support for now --- .../providers/smtp/notifications/smtp.py | 11 ++++-- .../unit/smtp/notifications/test_smtp.py | 38 ++++--------------- 2 files changed, 15 insertions(+), 34 deletions(-) diff --git a/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py b/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py index 26e57311cb45c..c6a00ef90d981 100644 --- a/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py +++ b/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py @@ -110,7 +110,9 @@ def hook(self) -> SmtpHook: """Smtp Events Hook.""" return SmtpHook(smtp_conn_id=self.smtp_conn_id, auth_type=self.auth_type) - def _build_email_content(self, smtp: SmtpHook, context: Context): + def _build_email_content(self, smtp: SmtpHook, context: Context, use_templates: bool = True): + # TODO: use_templates is temporary until templating on the Triggerer is sorted out. + fields_to_re_render = [] if self.from_email is None: if smtp.from_email is not None: @@ -138,7 +140,7 @@ def _build_email_content(self, smtp: SmtpHook, context: Context): ).as_posix() self.html_content = self._read_template(smtp_default_templated_html_content_path) fields_to_re_render.append("html_content") - if fields_to_re_render: + if fields_to_re_render and use_templates: jinja_env = self.get_template_env(dag=context["dag"]) self._do_render_template_fields(self, fields_to_re_render, context, jinja_env, set()) @@ -163,8 +165,9 @@ def notify(self, context: Context): async def async_notify(self, context: Context): """Send a email via smtp server (async).""" async with self.hook as smtp: - # TODO: Context is not yet available, uncomment this once implemented - # self._build_email_content(smtp, context) + # TODO: use_templates is temporary until templating on the Triggerer is sorted out. + # Once that iks done, we can remove that flag. + self._build_email_content(smtp, context, use_templates=False) await smtp.asend_email_smtp( smtp_conn_id=self.smtp_conn_id, from_email=self.from_email, diff --git a/providers/smtp/tests/unit/smtp/notifications/test_smtp.py b/providers/smtp/tests/unit/smtp/notifications/test_smtp.py index c1624586ad0de..7eee8cbd701e2 100644 --- a/providers/smtp/tests/unit/smtp/notifications/test_smtp.py +++ b/providers/smtp/tests/unit/smtp/notifications/test_smtp.py @@ -306,19 +306,11 @@ async def test_async_notifier_with_defaults( mock_smtp_client.html_content_template = None mock_smtp_client.from_email = None - mock_ti = mock_task_instance( - dag_id="test_dag", - task_id="op", - run_id="test", - try_number=NUM_TRY, - max_tries=0, - state=None, - ) - notifier = SmtpNotifier( from_email="any email", to="test_reciver@test.com", - context={"dag": create_dag_without_db("test_dag"), "ti": mock_ti}, + subject="subject", + html_content="body", ) await notifier @@ -327,7 +319,7 @@ async def test_async_notifier_with_defaults( smtp_conn_id="smtp_default", from_email="any email", to="test_reciver@test.com", - subject="DAG test_dag - Task op - Run ID test in State None", + subject="subject", html_content=mock.ANY, files=None, cc=None, @@ -336,48 +328,34 @@ async def test_async_notifier_with_defaults( mime_charset="utf-8", custom_headers=None, ) - content = mock_smtp_client.asend_email_smtp.call_args.kwargs["html_content"] - assert f"{NUM_TRY} of 1" in content @pytest.mark.asyncio async def test_async_notifier_with_nondefault_connection_extra( self, mock_smtp_hook, mock_smtp_client, create_dag_without_db, mock_task_instance ): - ti = mock_task_instance( - dag_id="test_dag", - task_id="op", - run_id="test_run", - try_number=NUM_TRY, - max_tries=0, - state=None, - ) - with ( tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_subject, tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_content, ): - f_subject.write("Task {{ ti.task_id }} failed") + f_subject.write("Connection Default Subject") f_subject.flush() f_content.write("Mock content goes here") f_content.flush() - mock_smtp_client.from_email = "{{ ti.task_id }}@test.com" + mock_smtp_client.from_email = "connection_default@test.com" mock_smtp_client.subject_template = f_subject.name mock_smtp_client.html_content_template = f_content.name - notifier = SmtpNotifier( - to="test_reciver@test.com", - context={"dag": create_dag_without_db("test_dag"), "ti": ti}, - ) + notifier = SmtpNotifier(to="test_reciver@test.com") await notifier mock_smtp_client.asend_email_smtp.assert_called_once_with( smtp_conn_id="smtp_default", - from_email="op@test.com", + from_email="connection_default@test.com", to="test_reciver@test.com", - subject="Task op failed", + subject="Connection Default Subject", html_content="Mock content goes here", files=None, cc=None, From 3bd94f5a7de1d8267d6cb91e0dbaed09102abda0 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Fri, 5 Sep 2025 16:29:18 -0700 Subject: [PATCH 03/10] overengineer the tests to remove magic strings --- .../smtp/tests/unit/smtp/hooks/test_smtp.py | 506 ++++++++++-------- .../unit/smtp/notifications/test_smtp.py | 359 +++++++------ 2 files changed, 452 insertions(+), 413 deletions(-) diff --git a/providers/smtp/tests/unit/smtp/hooks/test_smtp.py b/providers/smtp/tests/unit/smtp/hooks/test_smtp.py index 9e8b54a8ecb09..397c1412d2766 100644 --- a/providers/smtp/tests/unit/smtp/hooks/test_smtp.py +++ b/providers/smtp/tests/unit/smtp/hooks/test_smtp.py @@ -34,7 +34,31 @@ smtplib_string = "airflow.providers.smtp.hooks.smtp.smtplib" -TEST_EMAILS = ["test1@example.com", "test2@example.com"] +FROM_EMAIL = "from@example.com" +TO_EMAIL = "to@example.com" +TEST_EMAILS = [FROM_EMAIL, TO_EMAIL] + +CONN_TYPE = "smtp" +SMTP_HOST = "smtp.example.com" +SMTP_LOGIN = "smtp_user" +SMTP_PASSWORD = "smtp_password" +ACCESS_TOKEN = "test-token" + +CONN_ID_DEFAULT = "smtp_default" +CONN_ID_NONSSL = "smtp_nonssl" +CONN_ID_SSL_EXTRA = "smtp_ssl_extra" +CONN_ID_OAUTH = "smtp_oauth2" + +DEFAULT_PORT = 465 +NONSSL_PORT = 587 + +DEFAULT_TIMEOUT = 30 +DEFAULT_RETRY_LIMIT = 5 + +TEST_SUBJECT = "test subject" +TEST_BODY = "Test" + +SERVER_DISCONNECTED_ERROR = aiosmtplib.errors.SMTPServerDisconnected("Server disconnected") def _create_fake_smtp(mock_smtplib, use_ssl=True): @@ -55,43 +79,54 @@ class TestSmtpHook: def setup_connections(self, create_connection_without_db): create_connection_without_db( Connection( - conn_id="smtp_default", - conn_type="smtp", - host="smtp_server_address", - login="smtp_user", - password="smtp_password", - port=465, - extra=json.dumps(dict(from_email="from")), + conn_id=CONN_ID_DEFAULT, + conn_type=CONN_TYPE, + host=SMTP_HOST, + login=SMTP_LOGIN, + password=SMTP_PASSWORD, + port=DEFAULT_PORT, + extra=json.dumps(dict(from_email=FROM_EMAIL)), ) ) create_connection_without_db( Connection( - conn_id="smtp_nonssl", - conn_type="smtp", - host="smtp_server_address", - login="smtp_user", - password="smtp_password", - port=587, - extra=json.dumps(dict(disable_ssl=True, from_email="from")), + conn_id=CONN_ID_NONSSL, + conn_type=CONN_TYPE, + host=SMTP_HOST, + login=SMTP_LOGIN, + password=SMTP_PASSWORD, + port=NONSSL_PORT, + extra=json.dumps(dict(disable_ssl=True, from_email=FROM_EMAIL)), ) ) create_connection_without_db( Connection( - conn_id="smtp_oauth2", - conn_type="smtp", - host="smtp_server_address", - login="smtp_user", - password="smtp_password", - port=587, - extra=json.dumps(dict(disable_ssl=True, from_email="from", access_token="test-token")), + conn_id=CONN_ID_OAUTH, + conn_type=CONN_TYPE, + host=SMTP_HOST, + login=SMTP_LOGIN, + password=SMTP_PASSWORD, + port=NONSSL_PORT, + extra=json.dumps(dict(disable_ssl=True, from_email=FROM_EMAIL, access_token=ACCESS_TOKEN)), + ) + ) + create_connection_without_db( + Connection( + conn_id=CONN_ID_SSL_EXTRA, + conn_type=CONN_TYPE, + host=SMTP_HOST, + login=None, + password="None", + port=DEFAULT_PORT, + extra=json.dumps(dict(use_ssl=True, ssl_context="none", from_email=FROM_EMAIL)), ) ) @pytest.mark.parametrize( "conn_id, use_ssl, expected_port, create_context", [ - pytest.param("smtp_default", True, 465, True, id="ssl-connection"), - pytest.param("smtp_nonssl", False, 587, False, id="non-ssl-connection"), + pytest.param(CONN_ID_DEFAULT, True, DEFAULT_PORT, True, id="ssl-connection"), + pytest.param(CONN_ID_NONSSL, False, NONSSL_PORT, False, id="non-ssl-connection"), ], ) @patch(smtplib_string) @@ -108,68 +143,59 @@ def test_connect_and_disconnect( if create_context: assert create_default_context.called mock_smtplib.SMTP_SSL.assert_called_once_with( - host="smtp_server_address", + host=SMTP_HOST, port=expected_port, - timeout=30, + timeout=DEFAULT_TIMEOUT, context=create_default_context.return_value, ) else: mock_smtplib.SMTP.assert_called_once_with( - host="smtp_server_address", + host=SMTP_HOST, port=expected_port, - timeout=30, + timeout=DEFAULT_TIMEOUT, ) - mock_conn.login.assert_called_once_with("smtp_user", "smtp_password") + mock_conn.login.assert_called_once_with(SMTP_LOGIN, SMTP_PASSWORD) assert mock_conn.close.call_count == 1 @patch(smtplib_string) def test_get_email_address_single_email(self, mock_smtplib): with SmtpHook() as smtp_hook: - assert smtp_hook._get_email_address_list("test1@example.com") == ["test1@example.com"] - - @patch(smtplib_string) - def test_get_email_address_comma_sep_string(self, mock_smtplib): - with SmtpHook() as smtp_hook: - assert smtp_hook._get_email_address_list("test1@example.com, test2@example.com") == TEST_EMAILS - - @patch(smtplib_string) - def test_get_email_address_colon_sep_string(self, mock_smtplib): - with SmtpHook() as smtp_hook: - assert smtp_hook._get_email_address_list("test1@example.com; test2@example.com") == TEST_EMAILS - - @patch(smtplib_string) - def test_get_email_address_list(self, mock_smtplib): - with SmtpHook() as smtp_hook: - assert ( - smtp_hook._get_email_address_list(["test1@example.com", "test2@example.com"]) == TEST_EMAILS - ) + assert smtp_hook._get_email_address_list(FROM_EMAIL) == [FROM_EMAIL] + @pytest.mark.parametrize( + "email_input", + [ + pytest.param(f"{FROM_EMAIL}, {TO_EMAIL}", id="comma_separated"), + pytest.param(f"{FROM_EMAIL}; {TO_EMAIL}", id="semicolon_separated"), + pytest.param([FROM_EMAIL, TO_EMAIL], id="list_input"), + pytest.param((FROM_EMAIL, TO_EMAIL), id="tuple_input"), + ], + ) @patch(smtplib_string) - def test_get_email_address_tuple(self, mock_smtplib): + def test_get_email_address_parsing(self, mock_smtplib, email_input): with SmtpHook() as smtp_hook: - assert ( - smtp_hook._get_email_address_list(("test1@example.com", "test2@example.com")) == TEST_EMAILS - ) - - @patch(smtplib_string) - def test_get_email_address_invalid_type(self, mock_smtplib): - with pytest.raises(TypeError): - with SmtpHook() as smtp_hook: - smtp_hook._get_email_address_list(1) + assert smtp_hook._get_email_address_list(email_input) == TEST_EMAILS + @pytest.mark.parametrize( + "invalid_input", + [ + pytest.param(1, id="invalid_scalar_type"), + pytest.param([FROM_EMAIL, 2], id="invalid_type_in_list"), + ], + ) @patch(smtplib_string) - def test_get_email_address_invalid_type_in_iterable(self, mock_smtplib): + def test_get_email_address_invalid_types(self, mock_smtplib, invalid_input): with pytest.raises(TypeError): with SmtpHook() as smtp_hook: - smtp_hook._get_email_address_list(["test1@example.com", 2]) + smtp_hook._get_email_address_list(invalid_input) @patch(smtplib_string) def test_build_mime_message(self, mock_smtplib): - mail_from = "from@example.com" - mail_to = "to@example.com" - subject = "test subject" - html_content = "Test" + mail_from = FROM_EMAIL + mail_to = TO_EMAIL + subject = TEST_SUBJECT + html_content = TEST_BODY custom_headers = {"Reply-To": "reply_to@example.com"} with SmtpHook() as smtp_hook: msg, recipients = smtp_hook._build_mime_message( @@ -194,15 +220,15 @@ def test_send_smtp(self, mock_smtplib): attachment.write(b"attachment") attachment.seek(0) smtp_hook.send_email_smtp( - to="to", subject="subject", html_content="content", files=[attachment.name] + to=TO_EMAIL, subject=TEST_SUBJECT, html_content=TEST_BODY, files=[attachment.name] ) assert mock_send_mime.called _, call_args = mock_send_mime.call_args - assert call_args["from_addr"] == "from" - assert call_args["to_addrs"] == ["to"] + assert call_args["from_addr"] == FROM_EMAIL + assert call_args["to_addrs"] == [TO_EMAIL] msg = call_args["msg"] - assert "Subject: subject" in msg - assert "From: from" in msg + assert f"Subject: {TEST_SUBJECT}" in msg + assert f"From: {FROM_EMAIL}" in msg filename = 'attachment; filename="' + os.path.basename(attachment.name) + '"' assert filename in msg mimeapp = MIMEApplication("attachment") @@ -212,148 +238,150 @@ def test_send_smtp(self, mock_smtplib): @patch(smtplib_string) def test_hook_conn(self, mock_smtplib, mock_hook_conn): mock_conn = Mock() - mock_conn.login = "user" - mock_conn.password = "password" - mock_conn.extra_dejson = { - "disable_ssl": False, - } + mock_conn.login = SMTP_LOGIN + mock_conn.password = SMTP_PASSWORD + mock_conn.extra_dejson = {"disable_ssl": False} mock_hook_conn.return_value = mock_conn smtp_client_mock = mock_smtplib.SMTP_SSL() with SmtpHook() as smtp_hook: - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from") - mock_hook_conn.assert_called_with("smtp_default") - smtp_client_mock.login.assert_called_once_with("user", "password") + smtp_hook.send_email_smtp( + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + from_email=FROM_EMAIL, + ) + + mock_hook_conn.assert_called_with(CONN_ID_DEFAULT) + smtp_client_mock.login.assert_called_once_with(SMTP_LOGIN, SMTP_PASSWORD) smtp_client_mock.sendmail.assert_called_once() assert smtp_client_mock.close.called + @pytest.mark.parametrize( + "conn_id, ssl_context, create_context_called, use_default_context", + [ + pytest.param(CONN_ID_DEFAULT, "default", True, True, id="default_context"), + pytest.param(CONN_ID_SSL_EXTRA, "none", False, False, id="none_context"), + pytest.param(CONN_ID_DEFAULT, "default", True, True, id="explicit_default_context"), + ], + ) @patch("smtplib.SMTP_SSL") @patch("smtplib.SMTP") @patch("ssl.create_default_context") - def test_send_mime_ssl(self, create_default_context, mock_smtp, mock_smtp_ssl): - mock_smtp_ssl.return_value = Mock() - with SmtpHook() as smtp_hook: - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from") - assert not mock_smtp.called - assert create_default_context.called - mock_smtp_ssl.assert_called_once_with( - host="smtp_server_address", port=465, timeout=30, context=create_default_context.return_value - ) - - @patch("smtplib.SMTP_SSL") - @patch("smtplib.SMTP") - @patch("ssl.create_default_context") - def test_send_mime_ssl_extra_none_context( - self, create_default_context, mock_smtp, mock_smtp_ssl, create_connection_without_db + def test_send_mime_ssl_context( + self, + create_default_context, + mock_smtp, + mock_smtp_ssl, + conn_id, + ssl_context, + create_context_called, + use_default_context, ): mock_smtp_ssl.return_value = Mock() - conn = Connection( - conn_id="smtp_ssl_extra", - conn_type="smtp", - host="smtp_server_address", - login=None, - password="None", - port=465, - extra=json.dumps(dict(use_ssl=True, ssl_context="none", from_email="from")), - ) - create_connection_without_db(conn) - with SmtpHook(smtp_conn_id="smtp_ssl_extra") as smtp_hook: - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from") - assert not mock_smtp.called - create_default_context.assert_not_called() - mock_smtp_ssl.assert_called_once_with(host="smtp_server_address", port=465, timeout=30, context=None) - @patch("smtplib.SMTP_SSL") - @patch("smtplib.SMTP") - @patch("ssl.create_default_context") - def test_send_mime_ssl_extra_default_context( - self, create_default_context, mock_smtp, mock_smtp_ssl, create_connection_without_db - ): - mock_smtp_ssl.return_value = Mock() - conn = Connection( - conn_id="smtp_ssl_extra", - conn_type="smtp", - host="smtp_server_address", - login=None, - password="None", - port=465, - extra=json.dumps(dict(use_ssl=True, ssl_context="default", from_email="from")), - ) - create_connection_without_db(conn) - with SmtpHook() as smtp_hook: - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from") + with SmtpHook(conn_id) as smtp_hook: + smtp_hook.send_email_smtp( + to=TO_EMAIL, subject=TEST_SUBJECT, html_content=TEST_BODY, from_email=FROM_EMAIL + ) + assert not mock_smtp.called - assert create_default_context.called + if use_default_context: + assert create_default_context.called + expected_context = create_default_context.return_value + else: + create_default_context.assert_not_called() + expected_context = None + mock_smtp_ssl.assert_called_once_with( - host="smtp_server_address", port=465, timeout=30, context=create_default_context.return_value + host=SMTP_HOST, port=DEFAULT_PORT, timeout=DEFAULT_TIMEOUT, context=expected_context ) @patch("smtplib.SMTP_SSL") @patch("smtplib.SMTP") @patch("ssl.create_default_context") - def test_send_mime_default_context( - self, create_default_context, mock_smtp, mock_smtp_ssl, create_connection_without_db - ): + def test_send_mime_ssl(self, create_default_context, mock_smtp, mock_smtp_ssl): mock_smtp_ssl.return_value = Mock() - conn = Connection( - conn_id="smtp_ssl_extra", - conn_type="smtp", - host="smtp_server_address", - login=None, - password="None", - port=465, - extra=json.dumps(dict(use_ssl=True, from_email="from")), - ) - create_connection_without_db(conn) + with SmtpHook() as smtp_hook: - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from") + smtp_hook.send_email_smtp( + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + from_email=FROM_EMAIL, + ) + assert not mock_smtp.called assert create_default_context.called mock_smtp_ssl.assert_called_once_with( - host="smtp_server_address", port=465, timeout=30, context=create_default_context.return_value + host=SMTP_HOST, + port=DEFAULT_PORT, + timeout=DEFAULT_TIMEOUT, + context=create_default_context.return_value, ) @patch("smtplib.SMTP_SSL") @patch("smtplib.SMTP") def test_send_mime_nossl(self, mock_smtp, mock_smtp_ssl): mock_smtp.return_value = Mock() - with SmtpHook(smtp_conn_id="smtp_nonssl") as smtp_hook: - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from") + + with SmtpHook(smtp_conn_id=CONN_ID_NONSSL) as smtp_hook: + smtp_hook.send_email_smtp( + to=TO_EMAIL, subject=TEST_SUBJECT, html_content=TEST_BODY, from_email=FROM_EMAIL + ) + assert not mock_smtp_ssl.called - mock_smtp.assert_called_once_with(host="smtp_server_address", port=587, timeout=30) + mock_smtp.assert_called_once_with(host=SMTP_HOST, port=NONSSL_PORT, timeout=DEFAULT_TIMEOUT) @patch("smtplib.SMTP") def test_send_mime_noauth(self, mock_smtp, create_connection_without_db): mock_smtp.return_value = Mock() conn = Connection( conn_id="smtp_noauth", - conn_type="smtp", - host="smtp_server_address", + conn_type=CONN_TYPE, + host=SMTP_HOST, login=None, password="None", - port=587, - extra=json.dumps(dict(disable_ssl=True, from_email="from")), + port=NONSSL_PORT, + extra=json.dumps(dict(disable_ssl=True, from_email=FROM_EMAIL)), ) create_connection_without_db(conn) with SmtpHook(smtp_conn_id="smtp_noauth") as smtp_hook: - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from") - mock_smtp.assert_called_once_with(host="smtp_server_address", port=587, timeout=30) + smtp_hook.send_email_smtp( + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + from_email=FROM_EMAIL, + ) + mock_smtp.assert_called_once_with(host=SMTP_HOST, port=NONSSL_PORT, timeout=DEFAULT_TIMEOUT) assert not mock_smtp.login.called @patch("smtplib.SMTP_SSL") @patch("smtplib.SMTP") def test_send_mime_dryrun(self, mock_smtp, mock_smtp_ssl): with SmtpHook() as smtp_hook: - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", dryrun=True) + smtp_hook.send_email_smtp( + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + dryrun=True, + ) + assert not mock_smtp.sendmail.called assert not mock_smtp_ssl.sendmail.called @patch("smtplib.SMTP_SSL") def test_send_mime_ssl_complete_failure(self, mock_smtp_ssl): mock_smtp_ssl().sendmail.side_effect = smtplib.SMTPServerDisconnected() + with SmtpHook() as smtp_hook: with pytest.raises(smtplib.SMTPServerDisconnected): - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content") - assert mock_smtp_ssl().sendmail.call_count == 5 + smtp_hook.send_email_smtp( + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + ) + + assert mock_smtp_ssl().sendmail.call_count == DEFAULT_RETRY_LIMIT @patch("email.message.Message.as_string") @patch("smtplib.SMTP_SSL") @@ -363,10 +391,14 @@ def test_send_mime_partial_failure(self, mock_smtp_ssl, mime_message_mock): side_effects = [smtplib.SMTPServerDisconnected(), smtplib.SMTPServerDisconnected(), final_mock] mock_smtp_ssl.side_effect = side_effects with SmtpHook() as smtp_hook: - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content") + smtp_hook.send_email_smtp( + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + ) assert mock_smtp_ssl.call_count == side_effects.index(final_mock) + 1 assert final_mock.starttls.called - final_mock.sendmail.assert_called_once_with(from_addr="from", to_addrs=["to"], msg="msg") + final_mock.sendmail.assert_called_once_with(from_addr=FROM_EMAIL, to_addrs=[TO_EMAIL], msg="msg") assert final_mock.close.called @patch("smtplib.SMTP_SSL") @@ -379,18 +411,20 @@ def test_send_mime_custom_timeout_retrylimit( custom_timeout = 60 fake_conn = Connection( conn_id="mock_conn", - conn_type="smtp", - host="smtp_server_address", - login="smtp_user", - password="smtp_password", - port=465, - extra=json.dumps(dict(from_email="from", timeout=custom_timeout, retry_limit=custom_retry_limit)), + conn_type=CONN_TYPE, + host=SMTP_HOST, + login=SMTP_LOGIN, + password=SMTP_PASSWORD, + port=DEFAULT_PORT, + extra=json.dumps( + dict(from_email=FROM_EMAIL, timeout=custom_timeout, retry_limit=custom_retry_limit) + ), ) create_connection_without_db(fake_conn) with SmtpHook(smtp_conn_id="mock_conn") as smtp_hook: with pytest.raises(smtplib.SMTPServerDisconnected): - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content") + smtp_hook.send_email_smtp(to=TO_EMAIL, subject=TEST_SUBJECT, html_content=TEST_BODY) expected_call = call( host=fake_conn.host, @@ -406,18 +440,18 @@ def test_send_mime_custom_timeout_retrylimit( def test_oauth2_auth_called(self, mock_smtplib): mock_conn = _create_fake_smtp(mock_smtplib, use_ssl=False) - with SmtpHook(smtp_conn_id="smtp_oauth2", auth_type="oauth2") as smtp_hook: + with SmtpHook(smtp_conn_id=CONN_ID_OAUTH, auth_type="oauth2") as smtp_hook: smtp_hook.send_email_smtp( - to="to@example.com", - subject="subject", - html_content="content", - from_email="from", + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + from_email=FROM_EMAIL, ) assert mock_conn.auth.called args, _ = mock_conn.auth.call_args assert args[0] == "XOAUTH2" - assert build_xoauth2_string("smtp_user", "test-token") == args[1]() + assert build_xoauth2_string(SMTP_LOGIN, ACCESS_TOKEN) == args[1]() @patch(smtplib_string) def test_oauth2_missing_token_raises(self, mock_smtplib, create_connection_without_db): @@ -426,22 +460,22 @@ def test_oauth2_missing_token_raises(self, mock_smtplib, create_connection_witho create_connection_without_db( Connection( conn_id="smtp_oauth2_empty", - conn_type="smtp", - host="smtp_server_address", - login="smtp_user", - password="smtp_password", - port=587, - extra=json.dumps(dict(disable_ssl=True, from_email="from")), + conn_type=CONN_TYPE, + host=SMTP_HOST, + login=SMTP_LOGIN, + password=SMTP_PASSWORD, + port=NONSSL_PORT, + extra=json.dumps(dict(disable_ssl=True, from_email=FROM_EMAIL)), ) ) with pytest.raises(AirflowException): with SmtpHook(smtp_conn_id="smtp_oauth2_empty", auth_type="oauth2") as h: h.send_email_smtp( - to="to@example.com", - subject="subject", - html_content="content", - from_email="from", + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + from_email=FROM_EMAIL, ) assert not mock_conn.auth.called @@ -455,24 +489,24 @@ class TestSmtpHookAsync: def setup_connections(self, create_connection_without_db): create_connection_without_db( Connection( - conn_id="smtp_default", - conn_type="smtp", - host="smtp_server_address", - login="smtp_user", - password="smtp_password", - port=465, - extra=json.dumps(dict(from_email="from", ssl_context="default")), + conn_id=CONN_ID_DEFAULT, + conn_type=CONN_TYPE, + host=SMTP_HOST, + login=SMTP_LOGIN, + password=SMTP_PASSWORD, + port=DEFAULT_PORT, + extra=json.dumps(dict(from_email=FROM_EMAIL, ssl_context="default")), ) ) create_connection_without_db( Connection( - conn_id="smtp_nonssl", - conn_type="smtp", - host="smtp_server_address", - login="smtp_user", - password="smtp_password", - port=587, - extra=json.dumps(dict(disable_ssl=True, from_email="from")), + conn_id=CONN_ID_NONSSL, + conn_type=CONN_TYPE, + host=SMTP_HOST, + login=SMTP_LOGIN, + password=SMTP_PASSWORD, + port=NONSSL_PORT, + extra=json.dumps(dict(disable_ssl=True, from_email=FROM_EMAIL)), ) ) @@ -519,8 +553,8 @@ def _create_fake_async_smtp(mock_smtp): @pytest.mark.parametrize( "conn_id, expected_port, expected_ssl", [ - pytest.param("smtp_nonssl", 587, False, id="non-ssl-connection"), - pytest.param("smtp_default", 465, True, id="ssl-connection"), + pytest.param(CONN_ID_NONSSL, NONSSL_PORT, False, id="non-ssl-connection"), + pytest.param(CONN_ID_DEFAULT, DEFAULT_PORT, True, id="ssl-connection"), ], ) async def test_async_connection( @@ -531,9 +565,9 @@ async def test_async_connection( assert hook is not None mock_smtp.assert_called_once_with( - hostname="smtp_server_address", + hostname=SMTP_HOST, port=expected_port, - timeout=30, + timeout=DEFAULT_TIMEOUT, use_tls=expected_ssl, start_tls=None if expected_ssl else True, ) @@ -542,67 +576,69 @@ async def test_async_connection( assert mock_smtp_client.starttls.await_count == 1 assert mock_smtp_client.auth_login.await_count == 1 - mock_smtp_client.auth_login.assert_awaited_once_with("smtp_user", "smtp_password") + mock_smtp_client.auth_login.assert_awaited_once_with(SMTP_LOGIN, SMTP_PASSWORD) @pytest.mark.asyncio async def test_async_send_email(self, mock_smtp, mock_smtp_client, mock_get_connection): """Test async email sending functionality.""" async with SmtpHook() as hook: await hook.asend_email_smtp( - to="to@example.com", - subject="test subject", - html_content="test content", + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, ) assert mock_smtp_client.sendmail.called # The async version of sendmail only supports positional arguments # for some reason, so we have to check these by positional args call_args = mock_smtp_client.sendmail.await_args.args - assert call_args[0] == "from" # sender is first positional arg - assert call_args[1] == ["to@example.com"] # recipients is the second positional arg - assert "Subject: test subject" in call_args[2] # message is the third positional arg + assert call_args[0] == FROM_EMAIL # sender is first positional arg + assert call_args[1] == [TO_EMAIL] # recipients is the second positional arg + assert f"Subject: {TEST_SUBJECT}" in call_args[2] # message is the third positional arg + @pytest.mark.parametrize( + "side_effect, expected_calls, should_raise", + [ + pytest.param( + [SERVER_DISCONNECTED_ERROR, SERVER_DISCONNECTED_ERROR, None], + 3, + False, + id="success_after_retries", + ), + pytest.param(SERVER_DISCONNECTED_ERROR, DEFAULT_RETRY_LIMIT, True, id="max_retries_exceeded"), + ], + ) @pytest.mark.asyncio - async def test_async_send_email_with_retries(self, mock_smtp, mock_smtp_client, mock_get_connection): - """Test async email sending with connection retries.""" - mock_smtp_client.sendmail.side_effect = [ - aiosmtplib.errors.SMTPServerDisconnected("Server disconnected"), - aiosmtplib.errors.SMTPServerDisconnected("Server disconnected"), - None, # Success on third try - ] - - async with SmtpHook() as hook: - await hook.asend_email_smtp( - to="to@example.com", - subject="test subject", - html_content="test content", - ) - - assert mock_smtp_client.sendmail.await_count == 3 - - async def test_async_send_email_max_retries(self, mock_smtp, mock_smtp_client, mock_get_connection): - """Test async email sending with max retries exceeded.""" - mock_smtp_client.sendmail.side_effect = aiosmtplib.errors.SMTPServerDisconnected( - "Server disconnected" - ) - - with pytest.raises(aiosmtplib.errors.SMTPServerDisconnected): + async def test_async_send_email_retries( + self, mock_smtp, mock_smtp_client, mock_get_connection, side_effect, expected_calls, should_raise + ): + mock_smtp_client.sendmail.side_effect = side_effect + + if should_raise: + with pytest.raises(aiosmtplib.errors.SMTPServerDisconnected): + async with SmtpHook() as hook: + await hook.asend_email_smtp( + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + ) + else: async with SmtpHook() as hook: await hook.asend_email_smtp( - to="to@example.com", - subject="test subject", - html_content="test content", + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, ) - assert mock_smtp_client.sendmail.await_count == 5 # Default retry limit + assert mock_smtp_client.sendmail.call_count == expected_calls async def test_async_send_email_dryrun(self, mock_smtp, mock_smtp_client, mock_get_connection): """Test async email sending in dryrun mode.""" async with SmtpHook() as hook: await hook.asend_email_smtp( - to="to@example.com", - subject="test subject", - html_content="test content", + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, dryrun=True, ) diff --git a/providers/smtp/tests/unit/smtp/notifications/test_smtp.py b/providers/smtp/tests/unit/smtp/notifications/test_smtp.py index 7eee8cbd701e2..33b1b1dd75617 100644 --- a/providers/smtp/tests/unit/smtp/notifications/test_smtp.py +++ b/providers/smtp/tests/unit/smtp/notifications/test_smtp.py @@ -18,129 +18,162 @@ from __future__ import annotations import tempfile +from dataclasses import dataclass from unittest import mock from unittest.mock import AsyncMock import pytest -from airflow.providers.smtp.hooks.smtp import SmtpHook from airflow.providers.smtp.notifications.smtp import ( SmtpNotifier, send_smtp_notification, ) -SMTP_API_DEFAULT_CONN_ID = SmtpHook.default_conn_name +TRY_NUMBER = 0 + +SMTP_CONN_ID = "smtp_default" +SMTP_AUTH_TYPE = "oauth2" + +# Standard settings +DEFAULT_EMAIL_PARAMS = { + "mime_subtype": "mixed", + "mime_charset": "utf-8", + "files": None, + "cc": None, + "bcc": None, + "custom_headers": None, +} + +# DAG settings +TEST_DAG_ID = "test_dag" +TEST_TASK_ID = "test_task" +TEST_TASK_STATE = None +TEST_RUN_ID = "test_run" + +# Jinja template patterns +DAG_ID_TEMPLATE_STRING = "{{dag.dag_id}}" +TI_TEMPLATE_STRING = "{{ti.task_id}}" + +SENDER_EMAIL_SUFFIX = "sender@test.com" +RECEIVER_EMAIL_SUFFIX = "receiver@test.com" +# Base test values +TEST_SENDER = f"test_{SENDER_EMAIL_SUFFIX}" +TEST_RECEIVER = f"test_{RECEIVER_EMAIL_SUFFIX}" +TEST_SUBJECT = "subject" +TEST_BODY = "body" + + +# Templated versions +@dataclass(frozen=True) +class TemplatedString: + template: str + + @property + def rendered(self) -> str: + return self.template.replace(DAG_ID_TEMPLATE_STRING, TEST_DAG_ID).replace( + TI_TEMPLATE_STRING, TEST_TASK_ID + ) + +# DAG-based templates +TEMPLATED_SENDER = TemplatedString(f"{DAG_ID_TEMPLATE_STRING}_{SENDER_EMAIL_SUFFIX}") +TEMPLATED_RECEIVER = TemplatedString(f"{DAG_ID_TEMPLATE_STRING}_{RECEIVER_EMAIL_SUFFIX}") +TEMPLATED_SUBJECT = TemplatedString(f"{TEST_SUBJECT} {DAG_ID_TEMPLATE_STRING}") +TEMPLATED_BODY = TemplatedString(f"{TEST_BODY} {DAG_ID_TEMPLATE_STRING}") -NUM_TRY = 0 +# Task-based templates +TEMPLATED_TI_SUBJECT = TemplatedString(f"{TEST_SUBJECT} {TI_TEMPLATE_STRING}") +TEMPLATED_TI_SENDER = TemplatedString(f"{TI_TEMPLATE_STRING}_{SENDER_EMAIL_SUFFIX}") class TestSmtpNotifier: @mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") def test_notifier(_self, mock_smtphook_hook, create_dag_without_db): notifier = send_smtp_notification( - from_email="test_sender@test.com", - to="test_reciver@test.com", - subject="subject", - html_content="body", + from_email=TEST_SENDER, + to=TEST_RECEIVER, + subject=TEST_SUBJECT, + html_content=TEST_BODY, ) - notifier({"dag": create_dag_without_db("test_notifier")}) + notifier({"dag": create_dag_without_db(TEST_DAG_ID)}) mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with( - from_email="test_sender@test.com", - to="test_reciver@test.com", - subject="subject", - html_content="body", - smtp_conn_id="smtp_default", - files=None, - cc=None, - bcc=None, - mime_subtype="mixed", - mime_charset="utf-8", - custom_headers=None, + from_email=TEST_SENDER, + to=TEST_RECEIVER, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + smtp_conn_id=SMTP_CONN_ID, + **DEFAULT_EMAIL_PARAMS, ) @mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") def test_notifier_with_notifier_class(self, mock_smtphook_hook, create_dag_without_db): notifier = SmtpNotifier( - from_email="test_sender@test.com", - to="test_reciver@test.com", - subject="subject", - html_content="body", + from_email=TEST_SENDER, + to=TEST_RECEIVER, + subject=TEST_SUBJECT, + html_content=TEST_BODY, ) - notifier({"dag": create_dag_without_db("test_notifier")}) + notifier({"dag": create_dag_without_db(TEST_DAG_ID)}) mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with( - from_email="test_sender@test.com", - to="test_reciver@test.com", - subject="subject", - html_content="body", - smtp_conn_id="smtp_default", - files=None, - cc=None, - bcc=None, - mime_subtype="mixed", - mime_charset="utf-8", - custom_headers=None, + from_email=TEST_SENDER, + to=TEST_RECEIVER, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + smtp_conn_id=SMTP_CONN_ID, + **DEFAULT_EMAIL_PARAMS, ) @mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") def test_notifier_templated(self, mock_smtphook_hook, create_dag_without_db): notifier = SmtpNotifier( - from_email="test_sender@test.com {{dag.dag_id}}", - to="test_reciver@test.com {{dag.dag_id}}", - subject="subject {{dag.dag_id}}", - html_content="body {{dag.dag_id}}", + from_email=TEMPLATED_SENDER.template, + to=TEMPLATED_RECEIVER.template, + subject=TEMPLATED_SUBJECT.template, + html_content=TEMPLATED_BODY.template, ) - context = {"dag": create_dag_without_db("test_notifier")} + context = {"dag": create_dag_without_db(TEST_DAG_ID)} notifier(context) mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with( - from_email="test_sender@test.com test_notifier", - to="test_reciver@test.com test_notifier", - subject="subject test_notifier", - html_content="body test_notifier", - smtp_conn_id="smtp_default", - files=None, - cc=None, - bcc=None, - mime_subtype="mixed", - mime_charset="utf-8", - custom_headers=None, + from_email=TEMPLATED_SENDER.rendered, + to=TEMPLATED_RECEIVER.rendered, + subject=TEMPLATED_SUBJECT.rendered, + html_content=TEMPLATED_BODY.rendered, + smtp_conn_id=SMTP_CONN_ID, + **DEFAULT_EMAIL_PARAMS, ) @mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") def test_notifier_with_defaults(self, mock_smtphook_hook, create_dag_without_db, mock_task_instance): # TODO: we can use create_runtime_ti fixture in place of mock_task_instance once provider has minimum AF to Airflow 3.0+ mock_ti = mock_task_instance( - dag_id="test_dag", - task_id="op", - run_id="test", - try_number=NUM_TRY, + dag_id=TEST_DAG_ID, + task_id=TEST_TASK_ID, + run_id=TEST_RUN_ID, + try_number=TRY_NUMBER, max_tries=0, - state=None, + state=TEST_TASK_STATE, ) - context = {"dag": create_dag_without_db("test_dag"), "ti": mock_ti} + context = {"dag": create_dag_without_db(TEST_DAG_ID), "ti": mock_ti} notifier = SmtpNotifier( - from_email="any email", - to="test_reciver@test.com", + from_email=TEST_SENDER, + to=TEST_RECEIVER, ) mock_smtphook_hook.return_value.__enter__.return_value.subject_template = None mock_smtphook_hook.return_value.__enter__.return_value.html_content_template = None + notifier(context) + mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with( - from_email="any email", - to="test_reciver@test.com", - subject="DAG test_dag - Task op - Run ID test in State None", + from_email=TEST_SENDER, + to=TEST_RECEIVER, + subject=f"DAG {TEST_DAG_ID} - Task {TEST_TASK_ID} - Run ID {TEST_RUN_ID} in State {TEST_TASK_STATE}", html_content=mock.ANY, - smtp_conn_id="smtp_default", - files=None, - cc=None, - bcc=None, - mime_subtype="mixed", - mime_charset="utf-8", - custom_headers=None, + smtp_conn_id=SMTP_CONN_ID, + **DEFAULT_EMAIL_PARAMS, ) content = mock_smtphook_hook.return_value.__enter__().send_email_smtp.call_args.kwargs["html_content"] - assert f"{NUM_TRY} of 1" in content + assert f"{TRY_NUMBER} of 1" in content @mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") def test_notifier_with_nondefault_connection_extra( @@ -148,60 +181,55 @@ def test_notifier_with_nondefault_connection_extra( ): # TODO: we can use create_runtime_ti fixture in place of mock_task_instance once provider has minimum AF to Airflow 3.0+ ti = mock_task_instance( - dag_id="test_dag", - task_id="op", - run_id="test_run", - try_number=NUM_TRY, + dag_id=TEST_DAG_ID, + task_id=TEST_TASK_ID, + run_id=TEST_RUN_ID, + try_number=TRY_NUMBER, max_tries=0, - state=None, + state=TEST_TASK_STATE, ) - context = {"dag": create_dag_without_db("test_dag"), "ti": ti} + context = {"dag": create_dag_without_db(TEST_DAG_ID), "ti": ti} with ( tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_subject, tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_content, ): - f_subject.write("Task {{ ti.task_id }} failed") + f_subject.write(TEMPLATED_TI_SUBJECT.template) f_subject.flush() - f_content.write("Mock content goes here") + f_content.write(TEST_BODY) f_content.flush() - mock_smtphook_hook.return_value.__enter__.return_value.from_email = "{{ ti.task_id }}@test.com" + mock_smtphook_hook.return_value.__enter__.return_value.from_email = TEMPLATED_TI_SENDER.template mock_smtphook_hook.return_value.__enter__.return_value.subject_template = f_subject.name mock_smtphook_hook.return_value.__enter__.return_value.html_content_template = f_content.name - notifier = SmtpNotifier( - to="test_reciver@test.com", - ) + + notifier = SmtpNotifier(to=TEST_RECEIVER) notifier(context) + mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with( - from_email="op@test.com", - to="test_reciver@test.com", - subject="Task op failed", - html_content="Mock content goes here", - smtp_conn_id="smtp_default", - files=None, - cc=None, - bcc=None, - mime_subtype="mixed", - mime_charset="utf-8", - custom_headers=None, + from_email=TEMPLATED_TI_SENDER.rendered, + to=TEST_RECEIVER, + subject=TEMPLATED_TI_SUBJECT.rendered, + html_content=TEST_BODY, + smtp_conn_id=SMTP_CONN_ID, + **DEFAULT_EMAIL_PARAMS, ) @mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") def test_notifier_oauth2_passes_auth_type(self, mock_smtphook_hook, create_dag_without_db): notifier = SmtpNotifier( - from_email="test_sender@test.com", - to="test_reciver@test.com", - auth_type="oauth2", - subject="subject", - html_content="body", + from_email=TEST_SENDER, + to=TEST_RECEIVER, + auth_type=SMTP_AUTH_TYPE, + subject=TEST_SUBJECT, + html_content=TEST_BODY, ) - notifier({"dag": create_dag_without_db("test_notifier")}) + notifier({"dag": create_dag_without_db(TEST_DAG_ID)}) mock_smtphook_hook.assert_called_once_with( - smtp_conn_id="smtp_default", - auth_type="oauth2", + smtp_conn_id=SMTP_CONN_ID, + auth_type=SMTP_AUTH_TYPE, ) @@ -223,25 +251,20 @@ def mock_smtp_hook(self, mock_smtp_client): @pytest.mark.asyncio async def test_async_notifier(self, mock_smtp_hook, mock_smtp_client, create_dag_without_db): notifier = SmtpNotifier( - from_email="test_sender@test.com", - to="test_reciver@test.com", - subject="subject", - html_content="body", + from_email=TEST_SENDER, + to=TEST_RECEIVER, + subject=TEST_SUBJECT, + html_content=TEST_BODY, ) - await notifier.async_notify({"dag": create_dag_without_db("test_notifier")}) + await notifier.async_notify({"dag": create_dag_without_db(TEST_DAG_ID)}) mock_smtp_client.asend_email_smtp.assert_called_once_with( - smtp_conn_id="smtp_default", - from_email="test_sender@test.com", - to="test_reciver@test.com", - subject="subject", - html_content="body", - files=None, - cc=None, - bcc=None, - mime_subtype="mixed", - mime_charset="utf-8", - custom_headers=None, + smtp_conn_id=SMTP_CONN_ID, + from_email=TEST_SENDER, + to=TEST_RECEIVER, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + **DEFAULT_EMAIL_PARAMS, ) @pytest.mark.asyncio @@ -249,53 +272,43 @@ async def test_async_notifier_with_notifier_class( self, mock_smtp_hook, mock_smtp_client, create_dag_without_db ): notifier = SmtpNotifier( - from_email="test_sender@test.com", - to="test_reciver@test.com", - subject="subject", - html_content="body", - context={"dag": create_dag_without_db("test_notifier")}, + from_email=TEST_SENDER, + to=TEST_RECEIVER, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + context={"dag": create_dag_without_db(TEST_DAG_ID)}, ) await notifier mock_smtp_client.asend_email_smtp.assert_called_once_with( - smtp_conn_id="smtp_default", - from_email="test_sender@test.com", - to="test_reciver@test.com", - subject="subject", - html_content="body", - files=None, - cc=None, - bcc=None, - mime_subtype="mixed", - mime_charset="utf-8", - custom_headers=None, + smtp_conn_id=SMTP_CONN_ID, + from_email=TEST_SENDER, + to=TEST_RECEIVER, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + **DEFAULT_EMAIL_PARAMS, ) @pytest.mark.asyncio async def test_async_notifier_templated(self, mock_smtp_hook, mock_smtp_client, create_dag_without_db): notifier = SmtpNotifier( - from_email="test_sender@test.com {{dag.dag_id}}", - to="test_reciver@test.com {{dag.dag_id}}", - subject="subject {{dag.dag_id}}", - html_content="body {{dag.dag_id}}", - context={"dag": create_dag_without_db("test_notifier")}, + from_email=TEMPLATED_SENDER.template, + to=TEMPLATED_RECEIVER.template, + subject=TEMPLATED_SUBJECT.template, + html_content=TEMPLATED_BODY.template, + context={"dag": create_dag_without_db(TEST_DAG_ID)}, ) await notifier mock_smtp_client.asend_email_smtp.assert_called_once_with( - smtp_conn_id="smtp_default", - from_email="test_sender@test.com test_notifier", - to="test_reciver@test.com test_notifier", - subject="subject test_notifier", - html_content="body test_notifier", - files=None, - cc=None, - bcc=None, - mime_subtype="mixed", - mime_charset="utf-8", - custom_headers=None, + smtp_conn_id=SMTP_CONN_ID, + from_email=TEMPLATED_SENDER.rendered, + to=TEMPLATED_RECEIVER.rendered, + subject=TEMPLATED_SUBJECT.rendered, + html_content=TEMPLATED_BODY.rendered, + **DEFAULT_EMAIL_PARAMS, ) @pytest.mark.asyncio @@ -307,26 +320,21 @@ async def test_async_notifier_with_defaults( mock_smtp_client.from_email = None notifier = SmtpNotifier( - from_email="any email", - to="test_reciver@test.com", - subject="subject", - html_content="body", + from_email=TEST_SENDER, + to=TEST_RECEIVER, + subject=TEST_SUBJECT, + html_content=TEST_BODY, ) await notifier mock_smtp_client.asend_email_smtp.assert_called_once_with( - smtp_conn_id="smtp_default", - from_email="any email", - to="test_reciver@test.com", - subject="subject", + smtp_conn_id=SMTP_CONN_ID, + from_email=TEST_SENDER, + to=TEST_RECEIVER, + subject=TEST_SUBJECT, html_content=mock.ANY, - files=None, - cc=None, - bcc=None, - mime_subtype="mixed", - mime_charset="utf-8", - custom_headers=None, + **DEFAULT_EMAIL_PARAMS, ) @pytest.mark.asyncio @@ -337,30 +345,25 @@ async def test_async_notifier_with_nondefault_connection_extra( tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_subject, tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_content, ): - f_subject.write("Connection Default Subject") + f_subject.write(TEST_SUBJECT) f_subject.flush() - f_content.write("Mock content goes here") + f_content.write(TEST_BODY) f_content.flush() - mock_smtp_client.from_email = "connection_default@test.com" + mock_smtp_client.from_email = TEST_SENDER mock_smtp_client.subject_template = f_subject.name mock_smtp_client.html_content_template = f_content.name - notifier = SmtpNotifier(to="test_reciver@test.com") + notifier = SmtpNotifier(to=TEST_RECEIVER) await notifier mock_smtp_client.asend_email_smtp.assert_called_once_with( - smtp_conn_id="smtp_default", - from_email="connection_default@test.com", - to="test_reciver@test.com", - subject="Connection Default Subject", - html_content="Mock content goes here", - files=None, - cc=None, - bcc=None, - mime_subtype="mixed", - mime_charset="utf-8", - custom_headers=None, + smtp_conn_id=SMTP_CONN_ID, + from_email=TEST_SENDER, + to=TEST_RECEIVER, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + **DEFAULT_EMAIL_PARAMS, ) From aee9962e795b4742e40dabf41ab8456e6de7e26f Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Fri, 5 Sep 2025 20:50:37 -0700 Subject: [PATCH 04/10] fix imports --- providers/smtp/src/airflow/providers/smtp/hooks/smtp.py | 2 +- providers/smtp/tests/unit/smtp/notifications/test_smtp.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py b/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py index e9dd2347b9375..d0bd9fb17bded 100644 --- a/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py +++ b/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py @@ -46,7 +46,7 @@ if TYPE_CHECKING: try: from airflow.sdk import Connection - except ImportError: + except (ImportError, ModuleNotFoundError): from airflow.models.connection import Connection # type: ignore[assignment] diff --git a/providers/smtp/tests/unit/smtp/notifications/test_smtp.py b/providers/smtp/tests/unit/smtp/notifications/test_smtp.py index 33b1b1dd75617..367913c50dbbd 100644 --- a/providers/smtp/tests/unit/smtp/notifications/test_smtp.py +++ b/providers/smtp/tests/unit/smtp/notifications/test_smtp.py @@ -24,10 +24,7 @@ import pytest -from airflow.providers.smtp.notifications.smtp import ( - SmtpNotifier, - send_smtp_notification, -) +from airflow.providers.smtp.notifications.smtp import SmtpNotifier, send_smtp_notification TRY_NUMBER = 0 From 7c9fd4ef22fd6f04f1cd443fe89b3707af9be9b9 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 8 Sep 2025 12:19:27 -0700 Subject: [PATCH 05/10] fix imports and some small style changes --- .../providers/smtp/notifications/smtp.py | 4 +- .../smtp/tests/unit/smtp/hooks/test_smtp.py | 6 +- .../unit/smtp/notifications/test_smtp.py | 74 ++++++------------- 3 files changed, 28 insertions(+), 56 deletions(-) diff --git a/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py b/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py index c6a00ef90d981..bca6f866d327e 100644 --- a/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py +++ b/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py @@ -112,7 +112,7 @@ def hook(self) -> SmtpHook: def _build_email_content(self, smtp: SmtpHook, context: Context, use_templates: bool = True): # TODO: use_templates is temporary until templating on the Triggerer is sorted out. - + # Once that is done, we can remove that flag. fields_to_re_render = [] if self.from_email is None: if smtp.from_email is not None: @@ -166,7 +166,7 @@ async def async_notify(self, context: Context): """Send a email via smtp server (async).""" async with self.hook as smtp: # TODO: use_templates is temporary until templating on the Triggerer is sorted out. - # Once that iks done, we can remove that flag. + # Once that is done, we can remove that flag. self._build_email_content(smtp, context, use_templates=False) await smtp.asend_email_smtp( smtp_conn_id=self.smtp_conn_id, diff --git a/providers/smtp/tests/unit/smtp/hooks/test_smtp.py b/providers/smtp/tests/unit/smtp/hooks/test_smtp.py index 397c1412d2766..166aff4461e29 100644 --- a/providers/smtp/tests/unit/smtp/hooks/test_smtp.py +++ b/providers/smtp/tests/unit/smtp/hooks/test_smtp.py @@ -29,9 +29,13 @@ import pytest from airflow.exceptions import AirflowException -from airflow.models import Connection from airflow.providers.smtp.hooks.smtp import SmtpHook, build_xoauth2_string +try: + from airflow.sdk import Connection +except (ImportError, ModuleNotFoundError): + from airflow.models.connection import Connection # type: ignore[assignment] + smtplib_string = "airflow.providers.smtp.hooks.smtp.smtplib" FROM_EMAIL = "from@example.com" diff --git a/providers/smtp/tests/unit/smtp/notifications/test_smtp.py b/providers/smtp/tests/unit/smtp/notifications/test_smtp.py index 367913c50dbbd..a0c4a4a8b2e12 100644 --- a/providers/smtp/tests/unit/smtp/notifications/test_smtp.py +++ b/providers/smtp/tests/unit/smtp/notifications/test_smtp.py @@ -59,6 +59,13 @@ TEST_SUBJECT = "subject" TEST_BODY = "body" +NOTIFIER_DEFAULT_PARAMS = { + "from_email": TEST_SENDER, + "to": TEST_RECEIVER, + "subject": TEST_SUBJECT, + "html_content": TEST_BODY, +} + # Templated versions @dataclass(frozen=True) @@ -86,30 +93,17 @@ def rendered(self) -> str: class TestSmtpNotifier: @mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") def test_notifier(_self, mock_smtphook_hook, create_dag_without_db): - notifier = send_smtp_notification( - from_email=TEST_SENDER, - to=TEST_RECEIVER, - subject=TEST_SUBJECT, - html_content=TEST_BODY, - ) + notifier = send_smtp_notification(**NOTIFIER_DEFAULT_PARAMS) notifier({"dag": create_dag_without_db(TEST_DAG_ID)}) mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with( - from_email=TEST_SENDER, - to=TEST_RECEIVER, - subject=TEST_SUBJECT, - html_content=TEST_BODY, + **NOTIFIER_DEFAULT_PARAMS, smtp_conn_id=SMTP_CONN_ID, **DEFAULT_EMAIL_PARAMS, ) @mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") def test_notifier_with_notifier_class(self, mock_smtphook_hook, create_dag_without_db): - notifier = SmtpNotifier( - from_email=TEST_SENDER, - to=TEST_RECEIVER, - subject=TEST_SUBJECT, - html_content=TEST_BODY, - ) + notifier = SmtpNotifier(**NOTIFIER_DEFAULT_PARAMS) notifier({"dag": create_dag_without_db(TEST_DAG_ID)}) mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with( from_email=TEST_SENDER, @@ -129,7 +123,9 @@ def test_notifier_templated(self, mock_smtphook_hook, create_dag_without_db): html_content=TEMPLATED_BODY.template, ) context = {"dag": create_dag_without_db(TEST_DAG_ID)} + notifier(context) + mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with( from_email=TEMPLATED_SENDER.rendered, to=TEMPLATED_RECEIVER.rendered, @@ -214,20 +210,11 @@ def test_notifier_with_nondefault_connection_extra( @mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") def test_notifier_oauth2_passes_auth_type(self, mock_smtphook_hook, create_dag_without_db): - notifier = SmtpNotifier( - from_email=TEST_SENDER, - to=TEST_RECEIVER, - auth_type=SMTP_AUTH_TYPE, - subject=TEST_SUBJECT, - html_content=TEST_BODY, - ) + notifier = SmtpNotifier(**NOTIFIER_DEFAULT_PARAMS, auth_type=SMTP_AUTH_TYPE) notifier({"dag": create_dag_without_db(TEST_DAG_ID)}) - mock_smtphook_hook.assert_called_once_with( - smtp_conn_id=SMTP_CONN_ID, - auth_type=SMTP_AUTH_TYPE, - ) + mock_smtphook_hook.assert_called_once_with(smtp_conn_id=SMTP_CONN_ID, auth_type=SMTP_AUTH_TYPE) class TestSmtpNotifierAsync: @@ -248,19 +235,13 @@ def mock_smtp_hook(self, mock_smtp_client): @pytest.mark.asyncio async def test_async_notifier(self, mock_smtp_hook, mock_smtp_client, create_dag_without_db): notifier = SmtpNotifier( - from_email=TEST_SENDER, - to=TEST_RECEIVER, - subject=TEST_SUBJECT, - html_content=TEST_BODY, + **NOTIFIER_DEFAULT_PARAMS, context={"dag": create_dag_without_db(TEST_DAG_ID)} ) await notifier.async_notify({"dag": create_dag_without_db(TEST_DAG_ID)}) mock_smtp_client.asend_email_smtp.assert_called_once_with( smtp_conn_id=SMTP_CONN_ID, - from_email=TEST_SENDER, - to=TEST_RECEIVER, - subject=TEST_SUBJECT, - html_content=TEST_BODY, + **NOTIFIER_DEFAULT_PARAMS, **DEFAULT_EMAIL_PARAMS, ) @@ -269,21 +250,14 @@ async def test_async_notifier_with_notifier_class( self, mock_smtp_hook, mock_smtp_client, create_dag_without_db ): notifier = SmtpNotifier( - from_email=TEST_SENDER, - to=TEST_RECEIVER, - subject=TEST_SUBJECT, - html_content=TEST_BODY, - context={"dag": create_dag_without_db(TEST_DAG_ID)}, + **NOTIFIER_DEFAULT_PARAMS, context={"dag": create_dag_without_db(TEST_DAG_ID)} ) await notifier mock_smtp_client.asend_email_smtp.assert_called_once_with( smtp_conn_id=SMTP_CONN_ID, - from_email=TEST_SENDER, - to=TEST_RECEIVER, - subject=TEST_SUBJECT, - html_content=TEST_BODY, + **NOTIFIER_DEFAULT_PARAMS, **DEFAULT_EMAIL_PARAMS, ) @@ -317,20 +291,14 @@ async def test_async_notifier_with_defaults( mock_smtp_client.from_email = None notifier = SmtpNotifier( - from_email=TEST_SENDER, - to=TEST_RECEIVER, - subject=TEST_SUBJECT, - html_content=TEST_BODY, + **NOTIFIER_DEFAULT_PARAMS, context={"dag": create_dag_without_db(TEST_DAG_ID)} ) await notifier mock_smtp_client.asend_email_smtp.assert_called_once_with( smtp_conn_id=SMTP_CONN_ID, - from_email=TEST_SENDER, - to=TEST_RECEIVER, - subject=TEST_SUBJECT, - html_content=mock.ANY, + **NOTIFIER_DEFAULT_PARAMS, **DEFAULT_EMAIL_PARAMS, ) @@ -352,7 +320,7 @@ async def test_async_notifier_with_nondefault_connection_extra( mock_smtp_client.subject_template = f_subject.name mock_smtp_client.html_content_template = f_content.name - notifier = SmtpNotifier(to=TEST_RECEIVER) + notifier = SmtpNotifier(to=TEST_RECEIVER, context={"dag": create_dag_without_db(TEST_DAG_ID)}) await notifier From e7f147a52bfa35d4a427a43a549384c790192f94 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 8 Sep 2025 12:32:23 -0700 Subject: [PATCH 06/10] Add backcompat checks --- .../airflow/providers/slack/notifications/slack_webhook.py | 7 ++++++- providers/smtp/tests/unit/smtp/hooks/test_smtp.py | 2 ++ providers/smtp/tests/unit/smtp/notifications/test_smtp.py | 7 +++---- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py b/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py index d1125224ecf58..40dfc245e8aca 100644 --- a/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py +++ b/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING from airflow.providers.common.compat.notifier import BaseNotifier +from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_1_PLUS from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook if TYPE_CHECKING: @@ -64,7 +65,11 @@ def __init__( retry_handlers: list[RetryHandler] | None = None, **kwargs, ): - super().__init__(**kwargs) + if AIRFLOW_V_3_1_PLUS: + # Support for passing contest was added in 3.1.0 + super().__init__(**kwargs) + else: + super().__init__() self.slack_webhook_conn_id = slack_webhook_conn_id self.text = text self.attachments = attachments diff --git a/providers/smtp/tests/unit/smtp/hooks/test_smtp.py b/providers/smtp/tests/unit/smtp/hooks/test_smtp.py index 166aff4461e29..fdedc5647d924 100644 --- a/providers/smtp/tests/unit/smtp/hooks/test_smtp.py +++ b/providers/smtp/tests/unit/smtp/hooks/test_smtp.py @@ -30,6 +30,7 @@ from airflow.exceptions import AirflowException from airflow.providers.smtp.hooks.smtp import SmtpHook, build_xoauth2_string +from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS try: from airflow.sdk import Connection @@ -486,6 +487,7 @@ def test_oauth2_missing_token_raises(self, mock_smtplib, create_connection_witho @pytest.mark.asyncio +@pytest.mark.skipif(not AIRFLOW_V_3_1_PLUS, reason="Async support was added to BaseNotifier in 3.1.0") class TestSmtpHookAsync: """Tests for async functionality in SmtpHook.""" diff --git a/providers/smtp/tests/unit/smtp/notifications/test_smtp.py b/providers/smtp/tests/unit/smtp/notifications/test_smtp.py index a0c4a4a8b2e12..6ae892a292724 100644 --- a/providers/smtp/tests/unit/smtp/notifications/test_smtp.py +++ b/providers/smtp/tests/unit/smtp/notifications/test_smtp.py @@ -25,6 +25,7 @@ import pytest from airflow.providers.smtp.notifications.smtp import SmtpNotifier, send_smtp_notification +from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS TRY_NUMBER = 0 @@ -217,6 +218,8 @@ def test_notifier_oauth2_passes_auth_type(self, mock_smtphook_hook, create_dag_w mock_smtphook_hook.assert_called_once_with(smtp_conn_id=SMTP_CONN_ID, auth_type=SMTP_AUTH_TYPE) +@pytest.mark.asyncio +@pytest.mark.skipif(not AIRFLOW_V_3_1_PLUS, reason="Async support was added to BaseNotifier in 3.1.0") class TestSmtpNotifierAsync: @pytest.fixture def mock_smtp_client(self): @@ -245,7 +248,6 @@ async def test_async_notifier(self, mock_smtp_hook, mock_smtp_client, create_dag **DEFAULT_EMAIL_PARAMS, ) - @pytest.mark.asyncio async def test_async_notifier_with_notifier_class( self, mock_smtp_hook, mock_smtp_client, create_dag_without_db ): @@ -261,7 +263,6 @@ async def test_async_notifier_with_notifier_class( **DEFAULT_EMAIL_PARAMS, ) - @pytest.mark.asyncio async def test_async_notifier_templated(self, mock_smtp_hook, mock_smtp_client, create_dag_without_db): notifier = SmtpNotifier( from_email=TEMPLATED_SENDER.template, @@ -282,7 +283,6 @@ async def test_async_notifier_templated(self, mock_smtp_hook, mock_smtp_client, **DEFAULT_EMAIL_PARAMS, ) - @pytest.mark.asyncio async def test_async_notifier_with_defaults( self, mock_smtp_hook, mock_smtp_client, create_dag_without_db, mock_task_instance ): @@ -302,7 +302,6 @@ async def test_async_notifier_with_defaults( **DEFAULT_EMAIL_PARAMS, ) - @pytest.mark.asyncio async def test_async_notifier_with_nondefault_connection_extra( self, mock_smtp_hook, mock_smtp_client, create_dag_without_db, mock_task_instance ): From 1310ddc60cd7b46d77ed54a286a968c7079e0492 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 8 Sep 2025 14:33:12 -0700 Subject: [PATCH 07/10] static checks --- providers/smtp/tests/unit/smtp/hooks/test_smtp.py | 1 + providers/smtp/tests/unit/smtp/notifications/test_smtp.py | 1 + 2 files changed, 2 insertions(+) diff --git a/providers/smtp/tests/unit/smtp/hooks/test_smtp.py b/providers/smtp/tests/unit/smtp/hooks/test_smtp.py index fdedc5647d924..894f121c52758 100644 --- a/providers/smtp/tests/unit/smtp/hooks/test_smtp.py +++ b/providers/smtp/tests/unit/smtp/hooks/test_smtp.py @@ -30,6 +30,7 @@ from airflow.exceptions import AirflowException from airflow.providers.smtp.hooks.smtp import SmtpHook, build_xoauth2_string + from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS try: diff --git a/providers/smtp/tests/unit/smtp/notifications/test_smtp.py b/providers/smtp/tests/unit/smtp/notifications/test_smtp.py index 6ae892a292724..6f8fcf60ac848 100644 --- a/providers/smtp/tests/unit/smtp/notifications/test_smtp.py +++ b/providers/smtp/tests/unit/smtp/notifications/test_smtp.py @@ -25,6 +25,7 @@ import pytest from airflow.providers.smtp.notifications.smtp import SmtpNotifier, send_smtp_notification + from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS TRY_NUMBER = 0 From 4b30ad37ee8156e9a2cc9d77d22e7166d7e6b59f Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 9 Sep 2025 11:11:00 -0700 Subject: [PATCH 08/10] fix imports --- providers/smtp/tests/unit/smtp/hooks/test_smtp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/smtp/tests/unit/smtp/hooks/test_smtp.py b/providers/smtp/tests/unit/smtp/hooks/test_smtp.py index 894f121c52758..e3fad4d7ae3d0 100644 --- a/providers/smtp/tests/unit/smtp/hooks/test_smtp.py +++ b/providers/smtp/tests/unit/smtp/hooks/test_smtp.py @@ -33,9 +33,9 @@ from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS -try: +if AIRFLOW_V_3_1_PLUS: from airflow.sdk import Connection -except (ImportError, ModuleNotFoundError): +else: from airflow.models.connection import Connection # type: ignore[assignment] smtplib_string = "airflow.providers.smtp.hooks.smtp.smtplib" From 2e822c725c212ae396391ce94a2cbf396f27923a Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 10 Sep 2025 17:12:51 -0700 Subject: [PATCH 09/10] PR fixes --- .../slack/notifications/slack_webhook.py | 2 +- .../providers/smtp/notifications/smtp.py | 29 ++++++++++--------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py b/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py index 40dfc245e8aca..7b153ea02d361 100644 --- a/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py +++ b/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py @@ -66,7 +66,7 @@ def __init__( **kwargs, ): if AIRFLOW_V_3_1_PLUS: - # Support for passing contest was added in 3.1.0 + # Support for passing context was added in 3.1.0 super().__init__(**kwargs) else: super().__init__() diff --git a/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py b/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py index bca6f866d327e..81f19855d9a84 100644 --- a/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py +++ b/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py @@ -24,6 +24,7 @@ from airflow.providers.common.compat.notifier import BaseNotifier from airflow.providers.smtp.hooks.smtp import SmtpHook +from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS if TYPE_CHECKING: from airflow.sdk import Context @@ -85,7 +86,11 @@ def __init__( template: str | None = None, **kwargs, ): - super().__init__(**kwargs) + if AIRFLOW_V_3_1_PLUS: + # Support for passing context was added in 3.1.0 + super().__init__(**kwargs) + else: + super().__init__() self.smtp_conn_id = smtp_conn_id self.from_email = from_email self.to = to @@ -110,9 +115,7 @@ def hook(self) -> SmtpHook: """Smtp Events Hook.""" return SmtpHook(smtp_conn_id=self.smtp_conn_id, auth_type=self.auth_type) - def _build_email_content(self, smtp: SmtpHook, context: Context, use_templates: bool = True): - # TODO: use_templates is temporary until templating on the Triggerer is sorted out. - # Once that is done, we can remove that flag. + def _build_email_content(self, smtp: SmtpHook, context: Context): fields_to_re_render = [] if self.from_email is None: if smtp.from_email is not None: @@ -140,15 +143,15 @@ def _build_email_content(self, smtp: SmtpHook, context: Context, use_templates: ).as_posix() self.html_content = self._read_template(smtp_default_templated_html_content_path) fields_to_re_render.append("html_content") - if fields_to_re_render and use_templates: - jinja_env = self.get_template_env(dag=context["dag"]) + if fields_to_re_render: + jinja_env = self.get_template_env(dag=context.get("dag")) self._do_render_template_fields(self, fields_to_re_render, context, jinja_env, set()) def notify(self, context: Context): """Send a email via smtp server.""" - with self.hook as smtp: - self._build_email_content(smtp, context) - smtp.send_email_smtp( + with self.hook as smtp_hook: + self._build_email_content(smtp_hook, context) + smtp_hook.send_email_smtp( smtp_conn_id=self.smtp_conn_id, from_email=self.from_email, to=self.to, @@ -164,11 +167,9 @@ def notify(self, context: Context): async def async_notify(self, context: Context): """Send a email via smtp server (async).""" - async with self.hook as smtp: - # TODO: use_templates is temporary until templating on the Triggerer is sorted out. - # Once that is done, we can remove that flag. - self._build_email_content(smtp, context, use_templates=False) - await smtp.asend_email_smtp( + async with self.hook as smtp_hook: + self._build_email_content(smtp_hook, context) + await smtp_hook.asend_email_smtp( smtp_conn_id=self.smtp_conn_id, from_email=self.from_email, to=self.to, From b61a98d30bc8e4885dc7ecefb5767bf518a688df Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 10 Sep 2025 17:27:37 -0700 Subject: [PATCH 10/10] static fixes --- providers/smtp/src/airflow/providers/smtp/notifications/smtp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py b/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py index 81f19855d9a84..c77bc25e731b1 100644 --- a/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py +++ b/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py @@ -24,6 +24,7 @@ from airflow.providers.common.compat.notifier import BaseNotifier from airflow.providers.smtp.hooks.smtp import SmtpHook + from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS if TYPE_CHECKING: