diff --git a/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py b/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py index d1125224ecf58..7b153ea02d361 100644 --- a/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py +++ b/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING from airflow.providers.common.compat.notifier import BaseNotifier +from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_1_PLUS from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook if TYPE_CHECKING: @@ -64,7 +65,11 @@ def __init__( retry_handlers: list[RetryHandler] | None = None, **kwargs, ): - super().__init__(**kwargs) + if AIRFLOW_V_3_1_PLUS: + # Support for passing context was added in 3.1.0 + super().__init__(**kwargs) + else: + super().__init__() self.slack_webhook_conn_id = slack_webhook_conn_id self.text = text self.attachments = attachments diff --git a/providers/smtp/pyproject.toml b/providers/smtp/pyproject.toml index ac31d92cb0f95..98bebb17ff874 100644 --- a/providers/smtp/pyproject.toml +++ b/providers/smtp/pyproject.toml @@ -59,6 +59,7 @@ requires-python = ">=3.10" dependencies = [ "apache-airflow>=2.10.0", "apache-airflow-providers-common-compat>=1.6.1", + "aiosmtplib>=0.1.6", ] [dependency-groups] diff --git a/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py b/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py index 53982137dab51..d0bd9fb17bded 100644 --- a/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py +++ b/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py @@ -18,7 +18,8 @@ """ Search in emails for a specific attachment and also to download it. -It uses the smtplib library that is already integrated in python 3. +It uses the smtplib library that is already integrated in python 3 for +synchronous connections or aiosmtplib for async connections. """ from __future__ import annotations @@ -33,16 +34,19 @@ from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from email.utils import formatdate +from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Any, cast +import aiosmtplib + from airflow.exceptions import AirflowException, AirflowNotFoundException from airflow.providers.smtp.version_compat import BaseHook if TYPE_CHECKING: try: from airflow.sdk import Connection - except ImportError: + except (ImportError, ModuleNotFoundError): from airflow.models.connection import Connection # type: ignore[assignment] @@ -71,15 +75,37 @@ def __init__(self, smtp_conn_id: str = default_conn_name, auth_type: str = "basi super().__init__() self.smtp_conn_id = smtp_conn_id self.smtp_connection: Connection | None = None - self.smtp_client: smtplib.SMTP_SSL | smtplib.SMTP | None = None + self._smtp_client: smtplib.SMTP_SSL | smtplib.SMTP | aiosmtplib.SMTP | None = None self._auth_type = auth_type self._access_token: str | None = None def __enter__(self) -> SmtpHook: return self.get_conn() + async def __aenter__(self) -> SmtpHook: + return await self.aget_conn() + def __exit__(self, exc_type, exc_val, exc_tb): - self.smtp_client.close() + self._smtp_client.close() + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self._smtp_client: + await self._smtp_client.quit() + + def _setup_oauth2(self) -> tuple[str, str]: + """ + Set up OAuth2 credentials and return token and user identity. + + :return: Tuple of (user_identity, access_token) + """ + if not self._access_token: + self._access_token = self._get_oauth2_token() + + user_identity = self.smtp_user or self.from_email + if user_identity is None: + raise AirflowException("smtp_user or from_email must be set for OAuth2 authentication") + + return user_identity, self._access_token def get_conn(self) -> SmtpHook: """ @@ -90,7 +116,7 @@ def get_conn(self) -> SmtpHook: :return: an authorized SmtpHook object. """ - if not self.smtp_client: + if not self._smtp_client: try: self.smtp_connection = self.get_connection(self.smtp_conn_id) except AirflowNotFoundException: @@ -98,13 +124,13 @@ def get_conn(self) -> SmtpHook: for attempt in range(1, self.smtp_retry_limit + 1): try: - self.smtp_client = self._build_client() + self._smtp_client = self._build_client() except smtplib.SMTPServerDisconnected: if attempt == self.smtp_retry_limit: raise AirflowException("Unable to connect to smtp server") else: if self.smtp_starttls: - self.smtp_client.starttls() + self._smtp_client.starttls() # choose auth if self._auth_type == "oauth2": @@ -115,41 +141,94 @@ def get_conn(self) -> SmtpHook: raise AirflowException( "smtp_user or from_email must be set for OAuth2 authentication" ) - self.smtp_client.auth( + self._smtp_client.auth( "XOAUTH2", lambda _=None: build_xoauth2_string(user_identity, self._access_token), ) elif self.smtp_user and self.smtp_password: - self.smtp_client.login(self.smtp_user, self.smtp_password) + self._smtp_client.login(self.smtp_user, self.smtp_password) break return self - def _build_client(self) -> smtplib.SMTP_SSL | smtplib.SMTP: - SMTP: type[smtplib.SMTP_SSL] | type[smtplib.SMTP] - if self.use_ssl: - SMTP = smtplib.SMTP_SSL - else: - SMTP = smtplib.SMTP + async def aget_conn(self) -> SmtpHook: + """ + Login to the smtp server (async). + + .. note:: Please call this Hook as context manager via `with` + to automatically open and close the connection to the smtp server. + + :return: an authorized SmtpHook object. + """ + if not self._smtp_client: + try: + self.smtp_connection = await self.aget_connection(self.smtp_conn_id) + except AirflowNotFoundException: + raise AirflowException("SMTP connection is not found.") + + for attempt in range(1, self.smtp_retry_limit + 1): + try: + async_client = await self._abuild_client() + self._smtp_client = async_client + except aiosmtplib.errors.SMTPServerDisconnected: + if attempt == self.smtp_retry_limit: + raise AirflowException("Unable to connect to smtp server") + else: + if self.smtp_starttls: + await async_client.starttls() + + if self.smtp_user and self.smtp_password: + await async_client.auth_login(self.smtp_user, self.smtp_password) + break + + return self + + def _build_client_kwargs(self, is_async: bool) -> dict[str, Any]: + """Build kwargs appropriate for sync or async SMTP client.""" + valid_contexts = (None, "default", "none") # Values accepted for ssl_context configuration + + kwargs: dict[str, Any] = {"timeout": self.timeout} - smtp_kwargs: dict[str, Any] = {"host": self.host} if self.port: - smtp_kwargs["port"] = self.port - smtp_kwargs["timeout"] = self.timeout - - if self.use_ssl: - ssl_context_string = self.ssl_context - if ssl_context_string is None or ssl_context_string == "default": - ssl_context = ssl.create_default_context() - elif ssl_context_string == "none": - ssl_context = None - else: - raise RuntimeError( - f"The connection extra field `ssl_context` must " - f"be set to 'default' or 'none' but it is set to '{ssl_context_string}'." - ) - smtp_kwargs["context"] = ssl_context - return SMTP(**smtp_kwargs) + kwargs["port"] = self.port + + if is_async: + kwargs["hostname"] = self.host + kwargs["use_tls"] = self.use_ssl + kwargs["start_tls"] = self.smtp_starttls if not self.use_ssl else None + else: + kwargs["host"] = self.host + if self.use_ssl: + if self.ssl_context not in valid_contexts: + raise RuntimeError( + f"The connection extra field `ssl_context` must " + f"be set to 'default' or 'none' but it is set to '{self.ssl_context}'." + ) + kwargs["context"] = None if self.ssl_context == "none" else ssl.create_default_context() + + return kwargs + + def _build_client(self) -> smtplib.SMTP_SSL | smtplib.SMTP: + """Build a synchronous SMTP client.""" + client: type[smtplib.SMTP_SSL] | type[smtplib.SMTP] = ( + smtplib.SMTP_SSL if self.use_ssl else smtplib.SMTP + ) + return client(**self._build_client_kwargs(is_async=False)) + + async def _abuild_client(self) -> aiosmtplib.SMTP: + """ + Build an asynchronous SMTP client. + + Unlike the synchronous client (which connects automatically when instantiated), + aiosmtplib requires explicit connect() and ehlo() calls. We handle those here + to keep the async implementation details contained and make aget_conn behavior + match get_conn. + """ + async_client = aiosmtplib.SMTP(**self._build_client_kwargs(is_async=True)) + await async_client.connect() + await async_client.ehlo() + + return async_client @classmethod def get_connection_form_widgets(cls) -> dict[str, Any]: @@ -198,15 +277,82 @@ def get_connection_form_widgets(cls) -> dict[str, Any]: def test_connection(self) -> tuple[bool, str]: """Test SMTP connectivity from UI.""" try: - smtp_client = self.get_conn().smtp_client + smtp_client = self.get_conn()._smtp_client if smtp_client: - status = smtp_client.noop()[0] + status = smtp_client.noop() if status == 250: return True, "Connection successfully tested" except Exception as e: return False, str(e) return False, "Failed to establish connection" + async def atest_connection(self) -> tuple[bool, str]: + """Test SMTP connectivity (async).""" + try: + smtp_client = (await self.aget_conn())._smtp_client + if smtp_client is None: + return False, "SMTP client not initialized" + + if isinstance(smtp_client, aiosmtplib.SMTP): + async_client: aiosmtplib.SMTP = smtp_client + response = await async_client.noop() + if response.code == 250: + return True, "Async connection successfully tested" + return False, f"Connection test failed with code: {response.code}" + + except Exception as e: + return False, str(e) + return False, "Failed to establish connection" + + def _build_message( + self, + to: str | Iterable[str], + subject: str | None = None, + html_content: str | None = None, + from_email: str | None = None, + files: list[str] | None = None, + cc: str | Iterable[str] | None = None, + bcc: str | Iterable[str] | None = None, + mime_subtype: str = "mixed", + mime_charset: str = "utf-8", + custom_headers: dict[str, Any] | None = None, + ) -> dict[str, Any]: + if not self._smtp_client: + raise AirflowException("The 'smtp_client' should be initialized before!") + + from_email = from_email or self.from_email + if not from_email: + raise AirflowException("You should provide `from_email` or define it in the connection.") + + if not subject: + if self.subject_template is None: + raise AirflowException( + "You should provide `subject` or define `subject_template` in the connection." + ) + subject = self._read_template(self.subject_template) + + if not html_content: + if self.html_content_template is None: + raise AirflowException( + "You should provide `html_content` or define `html_content_template` in the connection." + ) + html_content = self._read_template(self.html_content_template) + + mime_msg, recipients = self._build_mime_message( + mail_from=from_email, + to=to, + subject=subject, + html_content=html_content, + files=files, + cc=cc, + bcc=bcc, + mime_subtype=mime_subtype, + mime_charset=mime_charset, + custom_headers=custom_headers, + ) + + return {"mime_msg": mime_msg, "recipients": recipients, "from_email": from_email} + def send_email_smtp( self, *, @@ -246,27 +392,9 @@ def send_email_smtp( 'test@example.com', 'foo', 'Foo bar', ['/dev/null'], dryrun=True ) """ - if not self.smtp_client: - raise AirflowException("The 'smtp_client' should be initialized before!") - from_email = from_email or self.from_email - if not from_email: - raise AirflowException("You should provide `from_email` or define it in the connection.") - if not subject: - if self.subject_template is None: - raise AirflowException( - "You should provide `subject` or define `subject_template` in the connection." - ) - subject = self._read_template(self.subject_template) - if not html_content: - if self.html_content_template is None: - raise AirflowException( - "You should provide `html_content` or define `html_content_template` in the connection." - ) - html_content = self._read_template(self.html_content_template) - - mime_msg, recipients = self._build_mime_message( - mail_from=from_email, + msg = self._build_message( to=to, + from_email=from_email, subject=subject, html_content=html_content, files=files, @@ -277,16 +405,92 @@ def send_email_smtp( custom_headers=custom_headers, ) if not dryrun: + if self._smtp_client is None: + raise AirflowException("The SMTP client is not initialized") + # Casting here to make MyPy happy. + smtp_client = cast("smtplib.SMTP_SSL | smtplib.SMTP", self._smtp_client) + for attempt in range(1, self.smtp_retry_limit + 1): try: - self.smtp_client.sendmail( - from_addr=from_email, to_addrs=recipients, msg=mime_msg.as_string() + smtp_client.sendmail( + from_addr=msg["from_email"], + to_addrs=msg["recipients"], + msg=msg["mime_msg"].as_string(), ) - except smtplib.SMTPServerDisconnected as e: + break + except Exception as e: if attempt == self.smtp_retry_limit: raise e - else: + + async def asend_email_smtp( + self, + *, + to: str | Iterable[str], + subject: str | None = None, + html_content: str | None = None, + from_email: str | None = None, + files: list[str] | None = None, + dryrun: bool = False, + cc: str | Iterable[str] | None = None, + bcc: str | Iterable[str] | None = None, + mime_subtype: str = "mixed", + mime_charset: str = "utf-8", + custom_headers: dict[str, Any] | None = None, + **kwargs, + ) -> None: + """ + Send an email with html content. + + :param to: Recipient email address or list of addresses. + :param subject: Email subject. If it's None, the hook will check if there is a path to a subject + file provided in the connection, and raises an exception if not. + :param html_content: Email body in HTML format. If it's None, the hook will check if there is a path + to a html content file provided in the connection, and raises an exception if not. + :param from_email: Sender email address. If it's None, the hook will check if there is an email + provided in the connection, and raises an exception if not. + :param files: List of file paths to attach to the email. + :param dryrun: If True, the email will not be sent, but all other actions will be performed. + :param cc: Carbon copy recipient email address or list of addresses. + :param bcc: Blind carbon copy recipient email address or list of addresses. + :param mime_subtype: MIME subtype of the email. + :param mime_charset: MIME charset of the email. + :param custom_headers: Dictionary of custom headers to include in the email. + :param kwargs: Additional keyword arguments. + + >>> send_email_smtp( + 'test@example.com', 'foo', 'Foo bar', ['/dev/null'], dryrun=True + ) + """ + msg = self._build_message( + to=to, + subject=subject, + html_content=html_content, + from_email=from_email, + files=files, + cc=cc, + bcc=bcc, + mime_subtype=mime_subtype, + mime_charset=mime_charset, + custom_headers=custom_headers, + ) + if self._smtp_client is None: + raise AirflowException("The SMTP client is not initialized") + # Casting here to make MyPy happy. + smtp_client = cast("aiosmtplib.SMTP", self._smtp_client) + + if not dryrun: + for attempt in range(1, self.smtp_retry_limit + 1): + try: + # The async version of sendmail only supports positional arguments for some reason. + await smtp_client.sendmail( + msg["from_email"], + msg["recipients"], + msg["mime_msg"].as_string(), + ) break + except Exception as e: + if attempt == self.smtp_retry_limit: + raise e def _build_mime_message( self, @@ -420,7 +624,7 @@ def _get_oauth2_token(self) -> str: "auth_type='oauth2' but neither 'access_token' nor client credentials supplied in connection extra." ) - @property + @cached_property def conn(self) -> Connection: if not self.smtp_connection: raise AirflowException("The smtp connection should be loaded before!") diff --git a/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py b/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py index 49d463f8819e5..c77bc25e731b1 100644 --- a/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py +++ b/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py @@ -20,11 +20,16 @@ from collections.abc import Iterable from functools import cached_property from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any from airflow.providers.common.compat.notifier import BaseNotifier from airflow.providers.smtp.hooks.smtp import SmtpHook +from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS + +if TYPE_CHECKING: + from airflow.sdk import Context + class SmtpNotifier(BaseNotifier): """ @@ -80,8 +85,13 @@ def __init__( auth_type: str = "basic", *, template: str | None = None, + **kwargs, ): - super().__init__() + if AIRFLOW_V_3_1_PLUS: + # Support for passing context was added in 3.1.0 + super().__init__(**kwargs) + else: + super().__init__() self.smtp_conn_id = smtp_conn_id self.from_email = from_email self.to = to @@ -106,40 +116,61 @@ def hook(self) -> SmtpHook: """Smtp Events Hook.""" return SmtpHook(smtp_conn_id=self.smtp_conn_id, auth_type=self.auth_type) - def notify(self, context): + def _build_email_content(self, smtp: SmtpHook, context: Context): + fields_to_re_render = [] + if self.from_email is None: + if smtp.from_email is not None: + self.from_email = smtp.from_email + else: + raise ValueError("You should provide `from_email` or define it in the connection") + fields_to_re_render.append("from_email") + if self.subject is None: + smtp_default_templated_subject_path: str + if smtp.subject_template: + smtp_default_templated_subject_path = smtp.subject_template + else: + smtp_default_templated_subject_path = ( + Path(__file__).parent / "templates" / "email_subject.jinja2" + ).as_posix() + self.subject = self._read_template(smtp_default_templated_subject_path) + fields_to_re_render.append("subject") + if self.html_content is None: + smtp_default_templated_html_content_path: str + if smtp.html_content_template: + smtp_default_templated_html_content_path = smtp.html_content_template + else: + smtp_default_templated_html_content_path = ( + Path(__file__).parent / "templates" / "email.html" + ).as_posix() + self.html_content = self._read_template(smtp_default_templated_html_content_path) + fields_to_re_render.append("html_content") + if fields_to_re_render: + jinja_env = self.get_template_env(dag=context.get("dag")) + self._do_render_template_fields(self, fields_to_re_render, context, jinja_env, set()) + + def notify(self, context: Context): """Send a email via smtp server.""" - with self.hook as smtp: - fields_to_re_render = [] - if self.from_email is None: - if smtp.from_email is not None: - self.from_email = smtp.from_email - else: - raise ValueError("You should provide `from_email` or define it in the connection") - fields_to_re_render.append("from_email") - if self.subject is None: - smtp_default_templated_subject_path: str - if smtp.subject_template: - smtp_default_templated_subject_path = smtp.subject_template - else: - smtp_default_templated_subject_path = ( - Path(__file__).parent / "templates" / "email_subject.jinja2" - ).as_posix() - self.subject = self._read_template(smtp_default_templated_subject_path) - fields_to_re_render.append("subject") - if self.html_content is None: - smtp_default_templated_html_content_path: str - if smtp.html_content_template: - smtp_default_templated_html_content_path = smtp.html_content_template - else: - smtp_default_templated_html_content_path = ( - Path(__file__).parent / "templates" / "email.html" - ).as_posix() - self.html_content = self._read_template(smtp_default_templated_html_content_path) - fields_to_re_render.append("html_content") - if fields_to_re_render: - jinja_env = self.get_template_env(dag=context["dag"]) - self._do_render_template_fields(self, fields_to_re_render, context, jinja_env, set()) - smtp.send_email_smtp( + with self.hook as smtp_hook: + self._build_email_content(smtp_hook, context) + smtp_hook.send_email_smtp( + smtp_conn_id=self.smtp_conn_id, + from_email=self.from_email, + to=self.to, + subject=self.subject, + html_content=self.html_content, + files=self.files, + cc=self.cc, + bcc=self.bcc, + mime_subtype=self.mime_subtype, + mime_charset=self.mime_charset, + custom_headers=self.custom_headers, + ) + + async def async_notify(self, context: Context): + """Send a email via smtp server (async).""" + async with self.hook as smtp_hook: + self._build_email_content(smtp_hook, context) + await smtp_hook.asend_email_smtp( smtp_conn_id=self.smtp_conn_id, from_email=self.from_email, to=self.to, diff --git a/providers/smtp/tests/unit/smtp/hooks/test_smtp.py b/providers/smtp/tests/unit/smtp/hooks/test_smtp.py index fc09281a306fc..e3fad4d7ae3d0 100644 --- a/providers/smtp/tests/unit/smtp/hooks/test_smtp.py +++ b/providers/smtp/tests/unit/smtp/hooks/test_smtp.py @@ -22,17 +22,49 @@ import smtplib import tempfile from email.mime.application import MIMEApplication -from unittest.mock import Mock, call, patch +from unittest import mock +from unittest.mock import AsyncMock, Mock, call, patch +import aiosmtplib import pytest from airflow.exceptions import AirflowException -from airflow.models import Connection from airflow.providers.smtp.hooks.smtp import SmtpHook, build_xoauth2_string +from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS + +if AIRFLOW_V_3_1_PLUS: + from airflow.sdk import Connection +else: + from airflow.models.connection import Connection # type: ignore[assignment] + smtplib_string = "airflow.providers.smtp.hooks.smtp.smtplib" -TEST_EMAILS = ["test1@example.com", "test2@example.com"] +FROM_EMAIL = "from@example.com" +TO_EMAIL = "to@example.com" +TEST_EMAILS = [FROM_EMAIL, TO_EMAIL] + +CONN_TYPE = "smtp" +SMTP_HOST = "smtp.example.com" +SMTP_LOGIN = "smtp_user" +SMTP_PASSWORD = "smtp_password" +ACCESS_TOKEN = "test-token" + +CONN_ID_DEFAULT = "smtp_default" +CONN_ID_NONSSL = "smtp_nonssl" +CONN_ID_SSL_EXTRA = "smtp_ssl_extra" +CONN_ID_OAUTH = "smtp_oauth2" + +DEFAULT_PORT = 465 +NONSSL_PORT = 587 + +DEFAULT_TIMEOUT = 30 +DEFAULT_RETRY_LIMIT = 5 + +TEST_SUBJECT = "test subject" +TEST_BODY = "Test" + +SERVER_DISCONNECTED_ERROR = aiosmtplib.errors.SMTPServerDisconnected("Server disconnected") def _create_fake_smtp(mock_smtplib, use_ssl=True): @@ -53,110 +85,123 @@ class TestSmtpHook: def setup_connections(self, create_connection_without_db): create_connection_without_db( Connection( - conn_id="smtp_default", - conn_type="smtp", - host="smtp_server_address", - login="smtp_user", - password="smtp_password", - port=465, - extra=json.dumps(dict(from_email="from")), + conn_id=CONN_ID_DEFAULT, + conn_type=CONN_TYPE, + host=SMTP_HOST, + login=SMTP_LOGIN, + password=SMTP_PASSWORD, + port=DEFAULT_PORT, + extra=json.dumps(dict(from_email=FROM_EMAIL)), ) ) create_connection_without_db( Connection( - conn_id="smtp_nonssl", - conn_type="smtp", - host="smtp_server_address", - login="smtp_user", - password="smtp_password", - port=587, - extra=json.dumps(dict(disable_ssl=True, from_email="from")), + conn_id=CONN_ID_NONSSL, + conn_type=CONN_TYPE, + host=SMTP_HOST, + login=SMTP_LOGIN, + password=SMTP_PASSWORD, + port=NONSSL_PORT, + extra=json.dumps(dict(disable_ssl=True, from_email=FROM_EMAIL)), ) ) create_connection_without_db( Connection( - conn_id="smtp_oauth2", - conn_type="smtp", - host="smtp_server_address", - login="smtp_user", - password="smtp_password", - port=587, - extra=json.dumps(dict(disable_ssl=True, from_email="from", access_token="test-token")), + conn_id=CONN_ID_OAUTH, + conn_type=CONN_TYPE, + host=SMTP_HOST, + login=SMTP_LOGIN, + password=SMTP_PASSWORD, + port=NONSSL_PORT, + extra=json.dumps(dict(disable_ssl=True, from_email=FROM_EMAIL, access_token=ACCESS_TOKEN)), + ) + ) + create_connection_without_db( + Connection( + conn_id=CONN_ID_SSL_EXTRA, + conn_type=CONN_TYPE, + host=SMTP_HOST, + login=None, + password="None", + port=DEFAULT_PORT, + extra=json.dumps(dict(use_ssl=True, ssl_context="none", from_email=FROM_EMAIL)), ) ) + @pytest.mark.parametrize( + "conn_id, use_ssl, expected_port, create_context", + [ + pytest.param(CONN_ID_DEFAULT, True, DEFAULT_PORT, True, id="ssl-connection"), + pytest.param(CONN_ID_NONSSL, False, NONSSL_PORT, False, id="non-ssl-connection"), + ], + ) @patch(smtplib_string) @patch("ssl.create_default_context") - def test_connect_and_disconnect(self, create_default_context, mock_smtplib): - mock_conn = _create_fake_smtp(mock_smtplib) + def test_connect_and_disconnect( + self, create_default_context, mock_smtplib, conn_id, use_ssl, expected_port, create_context + ): + """Test sync connection with different configurations.""" + mock_conn = _create_fake_smtp(mock_smtplib, use_ssl=use_ssl) - with SmtpHook(): + with SmtpHook(smtp_conn_id=conn_id): pass - assert create_default_context.called - mock_smtplib.SMTP_SSL.assert_called_once_with( - host="smtp_server_address", port=465, timeout=30, context=create_default_context.return_value - ) - mock_conn.login.assert_called_once_with("smtp_user", "smtp_password") - assert mock_conn.close.call_count == 1 - @patch(smtplib_string) - def test_connect_and_disconnect_via_nonssl(self, mock_smtplib): - mock_conn = _create_fake_smtp(mock_smtplib, use_ssl=False) - - with SmtpHook(smtp_conn_id="smtp_nonssl"): - pass + if create_context: + assert create_default_context.called + mock_smtplib.SMTP_SSL.assert_called_once_with( + host=SMTP_HOST, + port=expected_port, + timeout=DEFAULT_TIMEOUT, + context=create_default_context.return_value, + ) + else: + mock_smtplib.SMTP.assert_called_once_with( + host=SMTP_HOST, + port=expected_port, + timeout=DEFAULT_TIMEOUT, + ) - mock_smtplib.SMTP.assert_called_once_with(host="smtp_server_address", port=587, timeout=30) - mock_conn.login.assert_called_once_with("smtp_user", "smtp_password") + mock_conn.login.assert_called_once_with(SMTP_LOGIN, SMTP_PASSWORD) assert mock_conn.close.call_count == 1 @patch(smtplib_string) def test_get_email_address_single_email(self, mock_smtplib): with SmtpHook() as smtp_hook: - assert smtp_hook._get_email_address_list("test1@example.com") == ["test1@example.com"] - - @patch(smtplib_string) - def test_get_email_address_comma_sep_string(self, mock_smtplib): - with SmtpHook() as smtp_hook: - assert smtp_hook._get_email_address_list("test1@example.com, test2@example.com") == TEST_EMAILS - - @patch(smtplib_string) - def test_get_email_address_colon_sep_string(self, mock_smtplib): - with SmtpHook() as smtp_hook: - assert smtp_hook._get_email_address_list("test1@example.com; test2@example.com") == TEST_EMAILS - - @patch(smtplib_string) - def test_get_email_address_list(self, mock_smtplib): - with SmtpHook() as smtp_hook: - assert ( - smtp_hook._get_email_address_list(["test1@example.com", "test2@example.com"]) == TEST_EMAILS - ) - + assert smtp_hook._get_email_address_list(FROM_EMAIL) == [FROM_EMAIL] + + @pytest.mark.parametrize( + "email_input", + [ + pytest.param(f"{FROM_EMAIL}, {TO_EMAIL}", id="comma_separated"), + pytest.param(f"{FROM_EMAIL}; {TO_EMAIL}", id="semicolon_separated"), + pytest.param([FROM_EMAIL, TO_EMAIL], id="list_input"), + pytest.param((FROM_EMAIL, TO_EMAIL), id="tuple_input"), + ], + ) @patch(smtplib_string) - def test_get_email_address_tuple(self, mock_smtplib): + def test_get_email_address_parsing(self, mock_smtplib, email_input): with SmtpHook() as smtp_hook: - assert ( - smtp_hook._get_email_address_list(("test1@example.com", "test2@example.com")) == TEST_EMAILS - ) - - @patch(smtplib_string) - def test_get_email_address_invalid_type(self, mock_smtplib): - with pytest.raises(TypeError): - with SmtpHook() as smtp_hook: - smtp_hook._get_email_address_list(1) - + assert smtp_hook._get_email_address_list(email_input) == TEST_EMAILS + + @pytest.mark.parametrize( + "invalid_input", + [ + pytest.param(1, id="invalid_scalar_type"), + pytest.param([FROM_EMAIL, 2], id="invalid_type_in_list"), + ], + ) @patch(smtplib_string) - def test_get_email_address_invalid_type_in_iterable(self, mock_smtplib): + def test_get_email_address_invalid_types(self, mock_smtplib, invalid_input): with pytest.raises(TypeError): with SmtpHook() as smtp_hook: - smtp_hook._get_email_address_list(["test1@example.com", 2]) + smtp_hook._get_email_address_list(invalid_input) @patch(smtplib_string) def test_build_mime_message(self, mock_smtplib): - mail_from = "from@example.com" - mail_to = "to@example.com" - subject = "test subject" - html_content = "Test" + mail_from = FROM_EMAIL + mail_to = TO_EMAIL + subject = TEST_SUBJECT + html_content = TEST_BODY custom_headers = {"Reply-To": "reply_to@example.com"} with SmtpHook() as smtp_hook: msg, recipients = smtp_hook._build_mime_message( @@ -181,15 +226,15 @@ def test_send_smtp(self, mock_smtplib): attachment.write(b"attachment") attachment.seek(0) smtp_hook.send_email_smtp( - to="to", subject="subject", html_content="content", files=[attachment.name] + to=TO_EMAIL, subject=TEST_SUBJECT, html_content=TEST_BODY, files=[attachment.name] ) assert mock_send_mime.called _, call_args = mock_send_mime.call_args - assert call_args["from_addr"] == "from" - assert call_args["to_addrs"] == ["to"] + assert call_args["from_addr"] == FROM_EMAIL + assert call_args["to_addrs"] == [TO_EMAIL] msg = call_args["msg"] - assert "Subject: subject" in msg - assert "From: from" in msg + assert f"Subject: {TEST_SUBJECT}" in msg + assert f"From: {FROM_EMAIL}" in msg filename = 'attachment; filename="' + os.path.basename(attachment.name) + '"' assert filename in msg mimeapp = MIMEApplication("attachment") @@ -199,148 +244,150 @@ def test_send_smtp(self, mock_smtplib): @patch(smtplib_string) def test_hook_conn(self, mock_smtplib, mock_hook_conn): mock_conn = Mock() - mock_conn.login = "user" - mock_conn.password = "password" - mock_conn.extra_dejson = { - "disable_ssl": False, - } + mock_conn.login = SMTP_LOGIN + mock_conn.password = SMTP_PASSWORD + mock_conn.extra_dejson = {"disable_ssl": False} mock_hook_conn.return_value = mock_conn smtp_client_mock = mock_smtplib.SMTP_SSL() with SmtpHook() as smtp_hook: - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from") - mock_hook_conn.assert_called_with("smtp_default") - smtp_client_mock.login.assert_called_once_with("user", "password") + smtp_hook.send_email_smtp( + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + from_email=FROM_EMAIL, + ) + + mock_hook_conn.assert_called_with(CONN_ID_DEFAULT) + smtp_client_mock.login.assert_called_once_with(SMTP_LOGIN, SMTP_PASSWORD) smtp_client_mock.sendmail.assert_called_once() assert smtp_client_mock.close.called + @pytest.mark.parametrize( + "conn_id, ssl_context, create_context_called, use_default_context", + [ + pytest.param(CONN_ID_DEFAULT, "default", True, True, id="default_context"), + pytest.param(CONN_ID_SSL_EXTRA, "none", False, False, id="none_context"), + pytest.param(CONN_ID_DEFAULT, "default", True, True, id="explicit_default_context"), + ], + ) @patch("smtplib.SMTP_SSL") @patch("smtplib.SMTP") @patch("ssl.create_default_context") - def test_send_mime_ssl(self, create_default_context, mock_smtp, mock_smtp_ssl): - mock_smtp_ssl.return_value = Mock() - with SmtpHook() as smtp_hook: - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from") - assert not mock_smtp.called - assert create_default_context.called - mock_smtp_ssl.assert_called_once_with( - host="smtp_server_address", port=465, timeout=30, context=create_default_context.return_value - ) - - @patch("smtplib.SMTP_SSL") - @patch("smtplib.SMTP") - @patch("ssl.create_default_context") - def test_send_mime_ssl_extra_none_context( - self, create_default_context, mock_smtp, mock_smtp_ssl, create_connection_without_db + def test_send_mime_ssl_context( + self, + create_default_context, + mock_smtp, + mock_smtp_ssl, + conn_id, + ssl_context, + create_context_called, + use_default_context, ): mock_smtp_ssl.return_value = Mock() - conn = Connection( - conn_id="smtp_ssl_extra", - conn_type="smtp", - host="smtp_server_address", - login=None, - password="None", - port=465, - extra=json.dumps(dict(use_ssl=True, ssl_context="none", from_email="from")), - ) - create_connection_without_db(conn) - with SmtpHook(smtp_conn_id="smtp_ssl_extra") as smtp_hook: - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from") - assert not mock_smtp.called - create_default_context.assert_not_called() - mock_smtp_ssl.assert_called_once_with(host="smtp_server_address", port=465, timeout=30, context=None) - @patch("smtplib.SMTP_SSL") - @patch("smtplib.SMTP") - @patch("ssl.create_default_context") - def test_send_mime_ssl_extra_default_context( - self, create_default_context, mock_smtp, mock_smtp_ssl, create_connection_without_db - ): - mock_smtp_ssl.return_value = Mock() - conn = Connection( - conn_id="smtp_ssl_extra", - conn_type="smtp", - host="smtp_server_address", - login=None, - password="None", - port=465, - extra=json.dumps(dict(use_ssl=True, ssl_context="default", from_email="from")), - ) - create_connection_without_db(conn) - with SmtpHook() as smtp_hook: - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from") + with SmtpHook(conn_id) as smtp_hook: + smtp_hook.send_email_smtp( + to=TO_EMAIL, subject=TEST_SUBJECT, html_content=TEST_BODY, from_email=FROM_EMAIL + ) + assert not mock_smtp.called - assert create_default_context.called + if use_default_context: + assert create_default_context.called + expected_context = create_default_context.return_value + else: + create_default_context.assert_not_called() + expected_context = None + mock_smtp_ssl.assert_called_once_with( - host="smtp_server_address", port=465, timeout=30, context=create_default_context.return_value + host=SMTP_HOST, port=DEFAULT_PORT, timeout=DEFAULT_TIMEOUT, context=expected_context ) @patch("smtplib.SMTP_SSL") @patch("smtplib.SMTP") @patch("ssl.create_default_context") - def test_send_mime_default_context( - self, create_default_context, mock_smtp, mock_smtp_ssl, create_connection_without_db - ): + def test_send_mime_ssl(self, create_default_context, mock_smtp, mock_smtp_ssl): mock_smtp_ssl.return_value = Mock() - conn = Connection( - conn_id="smtp_ssl_extra", - conn_type="smtp", - host="smtp_server_address", - login=None, - password="None", - port=465, - extra=json.dumps(dict(use_ssl=True, from_email="from")), - ) - create_connection_without_db(conn) + with SmtpHook() as smtp_hook: - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from") + smtp_hook.send_email_smtp( + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + from_email=FROM_EMAIL, + ) + assert not mock_smtp.called assert create_default_context.called mock_smtp_ssl.assert_called_once_with( - host="smtp_server_address", port=465, timeout=30, context=create_default_context.return_value + host=SMTP_HOST, + port=DEFAULT_PORT, + timeout=DEFAULT_TIMEOUT, + context=create_default_context.return_value, ) @patch("smtplib.SMTP_SSL") @patch("smtplib.SMTP") def test_send_mime_nossl(self, mock_smtp, mock_smtp_ssl): mock_smtp.return_value = Mock() - with SmtpHook(smtp_conn_id="smtp_nonssl") as smtp_hook: - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from") + + with SmtpHook(smtp_conn_id=CONN_ID_NONSSL) as smtp_hook: + smtp_hook.send_email_smtp( + to=TO_EMAIL, subject=TEST_SUBJECT, html_content=TEST_BODY, from_email=FROM_EMAIL + ) + assert not mock_smtp_ssl.called - mock_smtp.assert_called_once_with(host="smtp_server_address", port=587, timeout=30) + mock_smtp.assert_called_once_with(host=SMTP_HOST, port=NONSSL_PORT, timeout=DEFAULT_TIMEOUT) @patch("smtplib.SMTP") def test_send_mime_noauth(self, mock_smtp, create_connection_without_db): mock_smtp.return_value = Mock() conn = Connection( conn_id="smtp_noauth", - conn_type="smtp", - host="smtp_server_address", + conn_type=CONN_TYPE, + host=SMTP_HOST, login=None, password="None", - port=587, - extra=json.dumps(dict(disable_ssl=True, from_email="from")), + port=NONSSL_PORT, + extra=json.dumps(dict(disable_ssl=True, from_email=FROM_EMAIL)), ) create_connection_without_db(conn) with SmtpHook(smtp_conn_id="smtp_noauth") as smtp_hook: - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from") - mock_smtp.assert_called_once_with(host="smtp_server_address", port=587, timeout=30) + smtp_hook.send_email_smtp( + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + from_email=FROM_EMAIL, + ) + mock_smtp.assert_called_once_with(host=SMTP_HOST, port=NONSSL_PORT, timeout=DEFAULT_TIMEOUT) assert not mock_smtp.login.called @patch("smtplib.SMTP_SSL") @patch("smtplib.SMTP") def test_send_mime_dryrun(self, mock_smtp, mock_smtp_ssl): with SmtpHook() as smtp_hook: - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", dryrun=True) + smtp_hook.send_email_smtp( + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + dryrun=True, + ) + assert not mock_smtp.sendmail.called assert not mock_smtp_ssl.sendmail.called @patch("smtplib.SMTP_SSL") def test_send_mime_ssl_complete_failure(self, mock_smtp_ssl): mock_smtp_ssl().sendmail.side_effect = smtplib.SMTPServerDisconnected() + with SmtpHook() as smtp_hook: with pytest.raises(smtplib.SMTPServerDisconnected): - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content") - assert mock_smtp_ssl().sendmail.call_count == 5 + smtp_hook.send_email_smtp( + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + ) + + assert mock_smtp_ssl().sendmail.call_count == DEFAULT_RETRY_LIMIT @patch("email.message.Message.as_string") @patch("smtplib.SMTP_SSL") @@ -350,10 +397,14 @@ def test_send_mime_partial_failure(self, mock_smtp_ssl, mime_message_mock): side_effects = [smtplib.SMTPServerDisconnected(), smtplib.SMTPServerDisconnected(), final_mock] mock_smtp_ssl.side_effect = side_effects with SmtpHook() as smtp_hook: - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content") + smtp_hook.send_email_smtp( + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + ) assert mock_smtp_ssl.call_count == side_effects.index(final_mock) + 1 assert final_mock.starttls.called - final_mock.sendmail.assert_called_once_with(from_addr="from", to_addrs=["to"], msg="msg") + final_mock.sendmail.assert_called_once_with(from_addr=FROM_EMAIL, to_addrs=[TO_EMAIL], msg="msg") assert final_mock.close.called @patch("smtplib.SMTP_SSL") @@ -366,18 +417,20 @@ def test_send_mime_custom_timeout_retrylimit( custom_timeout = 60 fake_conn = Connection( conn_id="mock_conn", - conn_type="smtp", - host="smtp_server_address", - login="smtp_user", - password="smtp_password", - port=465, - extra=json.dumps(dict(from_email="from", timeout=custom_timeout, retry_limit=custom_retry_limit)), + conn_type=CONN_TYPE, + host=SMTP_HOST, + login=SMTP_LOGIN, + password=SMTP_PASSWORD, + port=DEFAULT_PORT, + extra=json.dumps( + dict(from_email=FROM_EMAIL, timeout=custom_timeout, retry_limit=custom_retry_limit) + ), ) create_connection_without_db(fake_conn) with SmtpHook(smtp_conn_id="mock_conn") as smtp_hook: with pytest.raises(smtplib.SMTPServerDisconnected): - smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content") + smtp_hook.send_email_smtp(to=TO_EMAIL, subject=TEST_SUBJECT, html_content=TEST_BODY) expected_call = call( host=fake_conn.host, @@ -393,18 +446,18 @@ def test_send_mime_custom_timeout_retrylimit( def test_oauth2_auth_called(self, mock_smtplib): mock_conn = _create_fake_smtp(mock_smtplib, use_ssl=False) - with SmtpHook(smtp_conn_id="smtp_oauth2", auth_type="oauth2") as smtp_hook: + with SmtpHook(smtp_conn_id=CONN_ID_OAUTH, auth_type="oauth2") as smtp_hook: smtp_hook.send_email_smtp( - to="to@example.com", - subject="subject", - html_content="content", - from_email="from", + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + from_email=FROM_EMAIL, ) assert mock_conn.auth.called args, _ = mock_conn.auth.call_args assert args[0] == "XOAUTH2" - assert build_xoauth2_string("smtp_user", "test-token") == args[1]() + assert build_xoauth2_string(SMTP_LOGIN, ACCESS_TOKEN) == args[1]() @patch(smtplib_string) def test_oauth2_missing_token_raises(self, mock_smtplib, create_connection_without_db): @@ -413,22 +466,187 @@ def test_oauth2_missing_token_raises(self, mock_smtplib, create_connection_witho create_connection_without_db( Connection( conn_id="smtp_oauth2_empty", - conn_type="smtp", - host="smtp_server_address", - login="smtp_user", - password="smtp_password", - port=587, - extra=json.dumps(dict(disable_ssl=True, from_email="from")), + conn_type=CONN_TYPE, + host=SMTP_HOST, + login=SMTP_LOGIN, + password=SMTP_PASSWORD, + port=NONSSL_PORT, + extra=json.dumps(dict(disable_ssl=True, from_email=FROM_EMAIL)), ) ) with pytest.raises(AirflowException): with SmtpHook(smtp_conn_id="smtp_oauth2_empty", auth_type="oauth2") as h: h.send_email_smtp( - to="to@example.com", - subject="subject", - html_content="content", - from_email="from", + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + from_email=FROM_EMAIL, ) assert not mock_conn.auth.called + + +@pytest.mark.asyncio +@pytest.mark.skipif(not AIRFLOW_V_3_1_PLUS, reason="Async support was added to BaseNotifier in 3.1.0") +class TestSmtpHookAsync: + """Tests for async functionality in SmtpHook.""" + + @pytest.fixture(autouse=True) + def setup_connections(self, create_connection_without_db): + create_connection_without_db( + Connection( + conn_id=CONN_ID_DEFAULT, + conn_type=CONN_TYPE, + host=SMTP_HOST, + login=SMTP_LOGIN, + password=SMTP_PASSWORD, + port=DEFAULT_PORT, + extra=json.dumps(dict(from_email=FROM_EMAIL, ssl_context="default")), + ) + ) + create_connection_without_db( + Connection( + conn_id=CONN_ID_NONSSL, + conn_type=CONN_TYPE, + host=SMTP_HOST, + login=SMTP_LOGIN, + password=SMTP_PASSWORD, + port=NONSSL_PORT, + extra=json.dumps(dict(disable_ssl=True, from_email=FROM_EMAIL)), + ) + ) + + @pytest.fixture + def mock_smtp_client(self): + """Create a mock SMTP client with async capabilities.""" + mock_client = AsyncMock(spec=aiosmtplib.SMTP) + mock_client.starttls = AsyncMock() + mock_client.auth_login = AsyncMock() + mock_client.sendmail = AsyncMock() + mock_client.quit = AsyncMock() + return mock_client + + @pytest.fixture + def mock_smtp(self, mock_smtp_client): + """Set up the SMTP mock with context manager.""" + with mock.patch("airflow.providers.smtp.hooks.smtp.aiosmtplib.SMTP") as mock_smtp: + mock_smtp.return_value = mock_smtp_client + yield mock_smtp + + @pytest.fixture + def mock_get_connection(self): + """Mock the async connection retrieval.""" + with mock.patch("airflow.sdk.bases.hook.BaseHook.aget_connection") as mock_conn: + + async def async_get_connection(conn_id): + from airflow.sdk.definitions.connection import Connection + + return Connection.from_json(os.environ[f"AIRFLOW_CONN_{conn_id.upper()}"]) + + mock_conn.side_effect = async_get_connection + yield mock_conn + + @staticmethod + def _create_fake_async_smtp(mock_smtp): + mock_client = AsyncMock(spec=aiosmtplib.SMTP) + mock_client.starttls = AsyncMock() + mock_client.auth_login = AsyncMock() + mock_client.sendmail = AsyncMock() + mock_client.quit = AsyncMock() + mock_smtp.return_value = mock_client + return mock_client + + @pytest.mark.parametrize( + "conn_id, expected_port, expected_ssl", + [ + pytest.param(CONN_ID_NONSSL, NONSSL_PORT, False, id="non-ssl-connection"), + pytest.param(CONN_ID_DEFAULT, DEFAULT_PORT, True, id="ssl-connection"), + ], + ) + async def test_async_connection( + self, mock_smtp, mock_smtp_client, mock_get_connection, conn_id, expected_port, expected_ssl + ): + """Test async connection with different configurations.""" + async with SmtpHook(smtp_conn_id=conn_id) as hook: + assert hook is not None + + mock_smtp.assert_called_once_with( + hostname=SMTP_HOST, + port=expected_port, + timeout=DEFAULT_TIMEOUT, + use_tls=expected_ssl, + start_tls=None if expected_ssl else True, + ) + + if expected_ssl: + assert mock_smtp_client.starttls.await_count == 1 + + assert mock_smtp_client.auth_login.await_count == 1 + mock_smtp_client.auth_login.assert_awaited_once_with(SMTP_LOGIN, SMTP_PASSWORD) + + @pytest.mark.asyncio + async def test_async_send_email(self, mock_smtp, mock_smtp_client, mock_get_connection): + """Test async email sending functionality.""" + async with SmtpHook() as hook: + await hook.asend_email_smtp( + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + ) + + assert mock_smtp_client.sendmail.called + # The async version of sendmail only supports positional arguments + # for some reason, so we have to check these by positional args + call_args = mock_smtp_client.sendmail.await_args.args + assert call_args[0] == FROM_EMAIL # sender is first positional arg + assert call_args[1] == [TO_EMAIL] # recipients is the second positional arg + assert f"Subject: {TEST_SUBJECT}" in call_args[2] # message is the third positional arg + + @pytest.mark.parametrize( + "side_effect, expected_calls, should_raise", + [ + pytest.param( + [SERVER_DISCONNECTED_ERROR, SERVER_DISCONNECTED_ERROR, None], + 3, + False, + id="success_after_retries", + ), + pytest.param(SERVER_DISCONNECTED_ERROR, DEFAULT_RETRY_LIMIT, True, id="max_retries_exceeded"), + ], + ) + @pytest.mark.asyncio + async def test_async_send_email_retries( + self, mock_smtp, mock_smtp_client, mock_get_connection, side_effect, expected_calls, should_raise + ): + mock_smtp_client.sendmail.side_effect = side_effect + + if should_raise: + with pytest.raises(aiosmtplib.errors.SMTPServerDisconnected): + async with SmtpHook() as hook: + await hook.asend_email_smtp( + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + ) + else: + async with SmtpHook() as hook: + await hook.asend_email_smtp( + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + ) + + assert mock_smtp_client.sendmail.call_count == expected_calls + + async def test_async_send_email_dryrun(self, mock_smtp, mock_smtp_client, mock_get_connection): + """Test async email sending in dryrun mode.""" + async with SmtpHook() as hook: + await hook.asend_email_smtp( + to=TO_EMAIL, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + dryrun=True, + ) + + mock_smtp_client.sendmail.assert_not_awaited() diff --git a/providers/smtp/tests/unit/smtp/notifications/test_smtp.py b/providers/smtp/tests/unit/smtp/notifications/test_smtp.py index 4dfffab36e03e..6f8fcf60ac848 100644 --- a/providers/smtp/tests/unit/smtp/notifications/test_smtp.py +++ b/providers/smtp/tests/unit/smtp/notifications/test_smtp.py @@ -18,126 +18,157 @@ from __future__ import annotations import tempfile +from dataclasses import dataclass from unittest import mock +from unittest.mock import AsyncMock -from airflow.providers.smtp.hooks.smtp import SmtpHook -from airflow.providers.smtp.notifications.smtp import ( - SmtpNotifier, - send_smtp_notification, -) +import pytest -SMTP_API_DEFAULT_CONN_ID = SmtpHook.default_conn_name +from airflow.providers.smtp.notifications.smtp import SmtpNotifier, send_smtp_notification +from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS -NUM_TRY = 0 +TRY_NUMBER = 0 + +SMTP_CONN_ID = "smtp_default" +SMTP_AUTH_TYPE = "oauth2" + +# Standard settings +DEFAULT_EMAIL_PARAMS = { + "mime_subtype": "mixed", + "mime_charset": "utf-8", + "files": None, + "cc": None, + "bcc": None, + "custom_headers": None, +} + +# DAG settings +TEST_DAG_ID = "test_dag" +TEST_TASK_ID = "test_task" +TEST_TASK_STATE = None +TEST_RUN_ID = "test_run" + +# Jinja template patterns +DAG_ID_TEMPLATE_STRING = "{{dag.dag_id}}" +TI_TEMPLATE_STRING = "{{ti.task_id}}" + +SENDER_EMAIL_SUFFIX = "sender@test.com" +RECEIVER_EMAIL_SUFFIX = "receiver@test.com" +# Base test values +TEST_SENDER = f"test_{SENDER_EMAIL_SUFFIX}" +TEST_RECEIVER = f"test_{RECEIVER_EMAIL_SUFFIX}" +TEST_SUBJECT = "subject" +TEST_BODY = "body" + +NOTIFIER_DEFAULT_PARAMS = { + "from_email": TEST_SENDER, + "to": TEST_RECEIVER, + "subject": TEST_SUBJECT, + "html_content": TEST_BODY, +} + + +# Templated versions +@dataclass(frozen=True) +class TemplatedString: + template: str + + @property + def rendered(self) -> str: + return self.template.replace(DAG_ID_TEMPLATE_STRING, TEST_DAG_ID).replace( + TI_TEMPLATE_STRING, TEST_TASK_ID + ) + + +# DAG-based templates +TEMPLATED_SENDER = TemplatedString(f"{DAG_ID_TEMPLATE_STRING}_{SENDER_EMAIL_SUFFIX}") +TEMPLATED_RECEIVER = TemplatedString(f"{DAG_ID_TEMPLATE_STRING}_{RECEIVER_EMAIL_SUFFIX}") +TEMPLATED_SUBJECT = TemplatedString(f"{TEST_SUBJECT} {DAG_ID_TEMPLATE_STRING}") +TEMPLATED_BODY = TemplatedString(f"{TEST_BODY} {DAG_ID_TEMPLATE_STRING}") + +# Task-based templates +TEMPLATED_TI_SUBJECT = TemplatedString(f"{TEST_SUBJECT} {TI_TEMPLATE_STRING}") +TEMPLATED_TI_SENDER = TemplatedString(f"{TI_TEMPLATE_STRING}_{SENDER_EMAIL_SUFFIX}") class TestSmtpNotifier: @mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") def test_notifier(_self, mock_smtphook_hook, create_dag_without_db): - notifier = send_smtp_notification( - from_email="test_sender@test.com", - to="test_reciver@test.com", - subject="subject", - html_content="body", - ) - notifier({"dag": create_dag_without_db("test_notifier")}) + notifier = send_smtp_notification(**NOTIFIER_DEFAULT_PARAMS) + notifier({"dag": create_dag_without_db(TEST_DAG_ID)}) mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with( - from_email="test_sender@test.com", - to="test_reciver@test.com", - subject="subject", - html_content="body", - smtp_conn_id="smtp_default", - files=None, - cc=None, - bcc=None, - mime_subtype="mixed", - mime_charset="utf-8", - custom_headers=None, + **NOTIFIER_DEFAULT_PARAMS, + smtp_conn_id=SMTP_CONN_ID, + **DEFAULT_EMAIL_PARAMS, ) @mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") def test_notifier_with_notifier_class(self, mock_smtphook_hook, create_dag_without_db): - notifier = SmtpNotifier( - from_email="test_sender@test.com", - to="test_reciver@test.com", - subject="subject", - html_content="body", - ) - notifier({"dag": create_dag_without_db("test_notifier")}) + notifier = SmtpNotifier(**NOTIFIER_DEFAULT_PARAMS) + notifier({"dag": create_dag_without_db(TEST_DAG_ID)}) mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with( - from_email="test_sender@test.com", - to="test_reciver@test.com", - subject="subject", - html_content="body", - smtp_conn_id="smtp_default", - files=None, - cc=None, - bcc=None, - mime_subtype="mixed", - mime_charset="utf-8", - custom_headers=None, + from_email=TEST_SENDER, + to=TEST_RECEIVER, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + smtp_conn_id=SMTP_CONN_ID, + **DEFAULT_EMAIL_PARAMS, ) @mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") def test_notifier_templated(self, mock_smtphook_hook, create_dag_without_db): notifier = SmtpNotifier( - from_email="test_sender@test.com {{dag.dag_id}}", - to="test_reciver@test.com {{dag.dag_id}}", - subject="subject {{dag.dag_id}}", - html_content="body {{dag.dag_id}}", + from_email=TEMPLATED_SENDER.template, + to=TEMPLATED_RECEIVER.template, + subject=TEMPLATED_SUBJECT.template, + html_content=TEMPLATED_BODY.template, ) - context = {"dag": create_dag_without_db("test_notifier")} + context = {"dag": create_dag_without_db(TEST_DAG_ID)} + notifier(context) + mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with( - from_email="test_sender@test.com test_notifier", - to="test_reciver@test.com test_notifier", - subject="subject test_notifier", - html_content="body test_notifier", - smtp_conn_id="smtp_default", - files=None, - cc=None, - bcc=None, - mime_subtype="mixed", - mime_charset="utf-8", - custom_headers=None, + from_email=TEMPLATED_SENDER.rendered, + to=TEMPLATED_RECEIVER.rendered, + subject=TEMPLATED_SUBJECT.rendered, + html_content=TEMPLATED_BODY.rendered, + smtp_conn_id=SMTP_CONN_ID, + **DEFAULT_EMAIL_PARAMS, ) @mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") def test_notifier_with_defaults(self, mock_smtphook_hook, create_dag_without_db, mock_task_instance): # TODO: we can use create_runtime_ti fixture in place of mock_task_instance once provider has minimum AF to Airflow 3.0+ mock_ti = mock_task_instance( - dag_id="test_dag", - task_id="op", - run_id="test", - try_number=NUM_TRY, + dag_id=TEST_DAG_ID, + task_id=TEST_TASK_ID, + run_id=TEST_RUN_ID, + try_number=TRY_NUMBER, max_tries=0, - state=None, + state=TEST_TASK_STATE, ) - context = {"dag": create_dag_without_db("test_dag"), "ti": mock_ti} + context = {"dag": create_dag_without_db(TEST_DAG_ID), "ti": mock_ti} notifier = SmtpNotifier( - from_email="any email", - to="test_reciver@test.com", + from_email=TEST_SENDER, + to=TEST_RECEIVER, ) mock_smtphook_hook.return_value.__enter__.return_value.subject_template = None mock_smtphook_hook.return_value.__enter__.return_value.html_content_template = None + notifier(context) + mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with( - from_email="any email", - to="test_reciver@test.com", - subject="DAG test_dag - Task op - Run ID test in State None", + from_email=TEST_SENDER, + to=TEST_RECEIVER, + subject=f"DAG {TEST_DAG_ID} - Task {TEST_TASK_ID} - Run ID {TEST_RUN_ID} in State {TEST_TASK_STATE}", html_content=mock.ANY, - smtp_conn_id="smtp_default", - files=None, - cc=None, - bcc=None, - mime_subtype="mixed", - mime_charset="utf-8", - custom_headers=None, + smtp_conn_id=SMTP_CONN_ID, + **DEFAULT_EMAIL_PARAMS, ) content = mock_smtphook_hook.return_value.__enter__().send_email_smtp.call_args.kwargs["html_content"] - assert f"{NUM_TRY} of 1" in content + assert f"{TRY_NUMBER} of 1" in content @mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") def test_notifier_with_nondefault_connection_extra( @@ -145,58 +176,159 @@ def test_notifier_with_nondefault_connection_extra( ): # TODO: we can use create_runtime_ti fixture in place of mock_task_instance once provider has minimum AF to Airflow 3.0+ ti = mock_task_instance( - dag_id="test_dag", - task_id="op", - run_id="test_run", - try_number=NUM_TRY, + dag_id=TEST_DAG_ID, + task_id=TEST_TASK_ID, + run_id=TEST_RUN_ID, + try_number=TRY_NUMBER, max_tries=0, - state=None, + state=TEST_TASK_STATE, ) - context = {"dag": create_dag_without_db("test_dag"), "ti": ti} + context = {"dag": create_dag_without_db(TEST_DAG_ID), "ti": ti} with ( tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_subject, tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_content, ): - f_subject.write("Task {{ ti.task_id }} failed") + f_subject.write(TEMPLATED_TI_SUBJECT.template) f_subject.flush() - f_content.write("Mock content goes here") + f_content.write(TEST_BODY) f_content.flush() - mock_smtphook_hook.return_value.__enter__.return_value.from_email = "{{ ti.task_id }}@test.com" + mock_smtphook_hook.return_value.__enter__.return_value.from_email = TEMPLATED_TI_SENDER.template mock_smtphook_hook.return_value.__enter__.return_value.subject_template = f_subject.name mock_smtphook_hook.return_value.__enter__.return_value.html_content_template = f_content.name - notifier = SmtpNotifier( - to="test_reciver@test.com", - ) + + notifier = SmtpNotifier(to=TEST_RECEIVER) notifier(context) + mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with( - from_email="op@test.com", - to="test_reciver@test.com", - subject="Task op failed", - html_content="Mock content goes here", - smtp_conn_id="smtp_default", - files=None, - cc=None, - bcc=None, - mime_subtype="mixed", - mime_charset="utf-8", - custom_headers=None, + from_email=TEMPLATED_TI_SENDER.rendered, + to=TEST_RECEIVER, + subject=TEMPLATED_TI_SUBJECT.rendered, + html_content=TEST_BODY, + smtp_conn_id=SMTP_CONN_ID, + **DEFAULT_EMAIL_PARAMS, ) @mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") def test_notifier_oauth2_passes_auth_type(self, mock_smtphook_hook, create_dag_without_db): + notifier = SmtpNotifier(**NOTIFIER_DEFAULT_PARAMS, auth_type=SMTP_AUTH_TYPE) + + notifier({"dag": create_dag_without_db(TEST_DAG_ID)}) + + mock_smtphook_hook.assert_called_once_with(smtp_conn_id=SMTP_CONN_ID, auth_type=SMTP_AUTH_TYPE) + + +@pytest.mark.asyncio +@pytest.mark.skipif(not AIRFLOW_V_3_1_PLUS, reason="Async support was added to BaseNotifier in 3.1.0") +class TestSmtpNotifierAsync: + @pytest.fixture + def mock_smtp_client(self): + """Create a mock SMTP object with async capabilities.""" + mock_smtp = AsyncMock() + mock_smtp.asend_email_smtp = AsyncMock() + return mock_smtp + + @pytest.fixture + def mock_smtp_hook(self, mock_smtp_client): + """Set up the SMTP hook with async context manager.""" + with mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") as mock_hook: + mock_hook.return_value.__aenter__ = AsyncMock(return_value=mock_smtp_client) + yield mock_hook + + @pytest.mark.asyncio + async def test_async_notifier(self, mock_smtp_hook, mock_smtp_client, create_dag_without_db): notifier = SmtpNotifier( - from_email="test_sender@test.com", - to="test_reciver@test.com", - auth_type="oauth2", - subject="subject", - html_content="body", + **NOTIFIER_DEFAULT_PARAMS, context={"dag": create_dag_without_db(TEST_DAG_ID)} ) + await notifier.async_notify({"dag": create_dag_without_db(TEST_DAG_ID)}) - notifier({"dag": create_dag_without_db("test_notifier")}) + mock_smtp_client.asend_email_smtp.assert_called_once_with( + smtp_conn_id=SMTP_CONN_ID, + **NOTIFIER_DEFAULT_PARAMS, + **DEFAULT_EMAIL_PARAMS, + ) - mock_smtphook_hook.assert_called_once_with( - smtp_conn_id="smtp_default", - auth_type="oauth2", + async def test_async_notifier_with_notifier_class( + self, mock_smtp_hook, mock_smtp_client, create_dag_without_db + ): + notifier = SmtpNotifier( + **NOTIFIER_DEFAULT_PARAMS, context={"dag": create_dag_without_db(TEST_DAG_ID)} ) + + await notifier + + mock_smtp_client.asend_email_smtp.assert_called_once_with( + smtp_conn_id=SMTP_CONN_ID, + **NOTIFIER_DEFAULT_PARAMS, + **DEFAULT_EMAIL_PARAMS, + ) + + async def test_async_notifier_templated(self, mock_smtp_hook, mock_smtp_client, create_dag_without_db): + notifier = SmtpNotifier( + from_email=TEMPLATED_SENDER.template, + to=TEMPLATED_RECEIVER.template, + subject=TEMPLATED_SUBJECT.template, + html_content=TEMPLATED_BODY.template, + context={"dag": create_dag_without_db(TEST_DAG_ID)}, + ) + + await notifier + + mock_smtp_client.asend_email_smtp.assert_called_once_with( + smtp_conn_id=SMTP_CONN_ID, + from_email=TEMPLATED_SENDER.rendered, + to=TEMPLATED_RECEIVER.rendered, + subject=TEMPLATED_SUBJECT.rendered, + html_content=TEMPLATED_BODY.rendered, + **DEFAULT_EMAIL_PARAMS, + ) + + async def test_async_notifier_with_defaults( + self, mock_smtp_hook, mock_smtp_client, create_dag_without_db, mock_task_instance + ): + mock_smtp_client.subject_template = None + mock_smtp_client.html_content_template = None + mock_smtp_client.from_email = None + + notifier = SmtpNotifier( + **NOTIFIER_DEFAULT_PARAMS, context={"dag": create_dag_without_db(TEST_DAG_ID)} + ) + + await notifier + + mock_smtp_client.asend_email_smtp.assert_called_once_with( + smtp_conn_id=SMTP_CONN_ID, + **NOTIFIER_DEFAULT_PARAMS, + **DEFAULT_EMAIL_PARAMS, + ) + + async def test_async_notifier_with_nondefault_connection_extra( + self, mock_smtp_hook, mock_smtp_client, create_dag_without_db, mock_task_instance + ): + with ( + tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_subject, + tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_content, + ): + f_subject.write(TEST_SUBJECT) + f_subject.flush() + + f_content.write(TEST_BODY) + f_content.flush() + + mock_smtp_client.from_email = TEST_SENDER + mock_smtp_client.subject_template = f_subject.name + mock_smtp_client.html_content_template = f_content.name + + notifier = SmtpNotifier(to=TEST_RECEIVER, context={"dag": create_dag_without_db(TEST_DAG_ID)}) + + await notifier + + mock_smtp_client.asend_email_smtp.assert_called_once_with( + smtp_conn_id=SMTP_CONN_ID, + from_email=TEST_SENDER, + to=TEST_RECEIVER, + subject=TEST_SUBJECT, + html_content=TEST_BODY, + **DEFAULT_EMAIL_PARAMS, + )