diff --git a/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py b/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py
index d1125224ecf58..7b153ea02d361 100644
--- a/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py
+++ b/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py
@@ -21,6 +21,7 @@
from typing import TYPE_CHECKING
from airflow.providers.common.compat.notifier import BaseNotifier
+from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_1_PLUS
from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook
if TYPE_CHECKING:
@@ -64,7 +65,11 @@ def __init__(
retry_handlers: list[RetryHandler] | None = None,
**kwargs,
):
- super().__init__(**kwargs)
+ if AIRFLOW_V_3_1_PLUS:
+ # Support for passing context was added in 3.1.0
+ super().__init__(**kwargs)
+ else:
+ super().__init__()
self.slack_webhook_conn_id = slack_webhook_conn_id
self.text = text
self.attachments = attachments
diff --git a/providers/smtp/pyproject.toml b/providers/smtp/pyproject.toml
index ac31d92cb0f95..98bebb17ff874 100644
--- a/providers/smtp/pyproject.toml
+++ b/providers/smtp/pyproject.toml
@@ -59,6 +59,7 @@ requires-python = ">=3.10"
dependencies = [
"apache-airflow>=2.10.0",
"apache-airflow-providers-common-compat>=1.6.1",
+ "aiosmtplib>=0.1.6",
]
[dependency-groups]
diff --git a/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py b/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py
index 53982137dab51..d0bd9fb17bded 100644
--- a/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py
+++ b/providers/smtp/src/airflow/providers/smtp/hooks/smtp.py
@@ -18,7 +18,8 @@
"""
Search in emails for a specific attachment and also to download it.
-It uses the smtplib library that is already integrated in python 3.
+It uses the smtplib library that is already integrated in python 3 for
+synchronous connections or aiosmtplib for async connections.
"""
from __future__ import annotations
@@ -33,16 +34,19 @@
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.utils import formatdate
+from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast
+import aiosmtplib
+
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.smtp.version_compat import BaseHook
if TYPE_CHECKING:
try:
from airflow.sdk import Connection
- except ImportError:
+ except (ImportError, ModuleNotFoundError):
from airflow.models.connection import Connection # type: ignore[assignment]
@@ -71,15 +75,37 @@ def __init__(self, smtp_conn_id: str = default_conn_name, auth_type: str = "basi
super().__init__()
self.smtp_conn_id = smtp_conn_id
self.smtp_connection: Connection | None = None
- self.smtp_client: smtplib.SMTP_SSL | smtplib.SMTP | None = None
+ self._smtp_client: smtplib.SMTP_SSL | smtplib.SMTP | aiosmtplib.SMTP | None = None
self._auth_type = auth_type
self._access_token: str | None = None
def __enter__(self) -> SmtpHook:
return self.get_conn()
+ async def __aenter__(self) -> SmtpHook:
+ return await self.aget_conn()
+
def __exit__(self, exc_type, exc_val, exc_tb):
- self.smtp_client.close()
+ self._smtp_client.close()
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ if self._smtp_client:
+ await self._smtp_client.quit()
+
+ def _setup_oauth2(self) -> tuple[str, str]:
+ """
+ Set up OAuth2 credentials and return token and user identity.
+
+ :return: Tuple of (user_identity, access_token)
+ """
+ if not self._access_token:
+ self._access_token = self._get_oauth2_token()
+
+ user_identity = self.smtp_user or self.from_email
+ if user_identity is None:
+ raise AirflowException("smtp_user or from_email must be set for OAuth2 authentication")
+
+ return user_identity, self._access_token
def get_conn(self) -> SmtpHook:
"""
@@ -90,7 +116,7 @@ def get_conn(self) -> SmtpHook:
:return: an authorized SmtpHook object.
"""
- if not self.smtp_client:
+ if not self._smtp_client:
try:
self.smtp_connection = self.get_connection(self.smtp_conn_id)
except AirflowNotFoundException:
@@ -98,13 +124,13 @@ def get_conn(self) -> SmtpHook:
for attempt in range(1, self.smtp_retry_limit + 1):
try:
- self.smtp_client = self._build_client()
+ self._smtp_client = self._build_client()
except smtplib.SMTPServerDisconnected:
if attempt == self.smtp_retry_limit:
raise AirflowException("Unable to connect to smtp server")
else:
if self.smtp_starttls:
- self.smtp_client.starttls()
+ self._smtp_client.starttls()
# choose auth
if self._auth_type == "oauth2":
@@ -115,41 +141,94 @@ def get_conn(self) -> SmtpHook:
raise AirflowException(
"smtp_user or from_email must be set for OAuth2 authentication"
)
- self.smtp_client.auth(
+ self._smtp_client.auth(
"XOAUTH2",
lambda _=None: build_xoauth2_string(user_identity, self._access_token),
)
elif self.smtp_user and self.smtp_password:
- self.smtp_client.login(self.smtp_user, self.smtp_password)
+ self._smtp_client.login(self.smtp_user, self.smtp_password)
break
return self
- def _build_client(self) -> smtplib.SMTP_SSL | smtplib.SMTP:
- SMTP: type[smtplib.SMTP_SSL] | type[smtplib.SMTP]
- if self.use_ssl:
- SMTP = smtplib.SMTP_SSL
- else:
- SMTP = smtplib.SMTP
+ async def aget_conn(self) -> SmtpHook:
+ """
+ Login to the smtp server (async).
+
+ .. note:: Please call this Hook as context manager via `with`
+ to automatically open and close the connection to the smtp server.
+
+ :return: an authorized SmtpHook object.
+ """
+ if not self._smtp_client:
+ try:
+ self.smtp_connection = await self.aget_connection(self.smtp_conn_id)
+ except AirflowNotFoundException:
+ raise AirflowException("SMTP connection is not found.")
+
+ for attempt in range(1, self.smtp_retry_limit + 1):
+ try:
+ async_client = await self._abuild_client()
+ self._smtp_client = async_client
+ except aiosmtplib.errors.SMTPServerDisconnected:
+ if attempt == self.smtp_retry_limit:
+ raise AirflowException("Unable to connect to smtp server")
+ else:
+ if self.smtp_starttls:
+ await async_client.starttls()
+
+ if self.smtp_user and self.smtp_password:
+ await async_client.auth_login(self.smtp_user, self.smtp_password)
+ break
+
+ return self
+
+ def _build_client_kwargs(self, is_async: bool) -> dict[str, Any]:
+ """Build kwargs appropriate for sync or async SMTP client."""
+ valid_contexts = (None, "default", "none") # Values accepted for ssl_context configuration
+
+ kwargs: dict[str, Any] = {"timeout": self.timeout}
- smtp_kwargs: dict[str, Any] = {"host": self.host}
if self.port:
- smtp_kwargs["port"] = self.port
- smtp_kwargs["timeout"] = self.timeout
-
- if self.use_ssl:
- ssl_context_string = self.ssl_context
- if ssl_context_string is None or ssl_context_string == "default":
- ssl_context = ssl.create_default_context()
- elif ssl_context_string == "none":
- ssl_context = None
- else:
- raise RuntimeError(
- f"The connection extra field `ssl_context` must "
- f"be set to 'default' or 'none' but it is set to '{ssl_context_string}'."
- )
- smtp_kwargs["context"] = ssl_context
- return SMTP(**smtp_kwargs)
+ kwargs["port"] = self.port
+
+ if is_async:
+ kwargs["hostname"] = self.host
+ kwargs["use_tls"] = self.use_ssl
+ kwargs["start_tls"] = self.smtp_starttls if not self.use_ssl else None
+ else:
+ kwargs["host"] = self.host
+ if self.use_ssl:
+ if self.ssl_context not in valid_contexts:
+ raise RuntimeError(
+ f"The connection extra field `ssl_context` must "
+ f"be set to 'default' or 'none' but it is set to '{self.ssl_context}'."
+ )
+ kwargs["context"] = None if self.ssl_context == "none" else ssl.create_default_context()
+
+ return kwargs
+
+ def _build_client(self) -> smtplib.SMTP_SSL | smtplib.SMTP:
+ """Build a synchronous SMTP client."""
+ client: type[smtplib.SMTP_SSL] | type[smtplib.SMTP] = (
+ smtplib.SMTP_SSL if self.use_ssl else smtplib.SMTP
+ )
+ return client(**self._build_client_kwargs(is_async=False))
+
+ async def _abuild_client(self) -> aiosmtplib.SMTP:
+ """
+ Build an asynchronous SMTP client.
+
+ Unlike the synchronous client (which connects automatically when instantiated),
+ aiosmtplib requires explicit connect() and ehlo() calls. We handle those here
+ to keep the async implementation details contained and make aget_conn behavior
+ match get_conn.
+ """
+ async_client = aiosmtplib.SMTP(**self._build_client_kwargs(is_async=True))
+ await async_client.connect()
+ await async_client.ehlo()
+
+ return async_client
@classmethod
def get_connection_form_widgets(cls) -> dict[str, Any]:
@@ -198,15 +277,82 @@ def get_connection_form_widgets(cls) -> dict[str, Any]:
def test_connection(self) -> tuple[bool, str]:
"""Test SMTP connectivity from UI."""
try:
- smtp_client = self.get_conn().smtp_client
+ smtp_client = self.get_conn()._smtp_client
if smtp_client:
- status = smtp_client.noop()[0]
+ status = smtp_client.noop()
if status == 250:
return True, "Connection successfully tested"
except Exception as e:
return False, str(e)
return False, "Failed to establish connection"
+ async def atest_connection(self) -> tuple[bool, str]:
+ """Test SMTP connectivity (async)."""
+ try:
+ smtp_client = (await self.aget_conn())._smtp_client
+ if smtp_client is None:
+ return False, "SMTP client not initialized"
+
+ if isinstance(smtp_client, aiosmtplib.SMTP):
+ async_client: aiosmtplib.SMTP = smtp_client
+ response = await async_client.noop()
+ if response.code == 250:
+ return True, "Async connection successfully tested"
+ return False, f"Connection test failed with code: {response.code}"
+
+ except Exception as e:
+ return False, str(e)
+ return False, "Failed to establish connection"
+
+ def _build_message(
+ self,
+ to: str | Iterable[str],
+ subject: str | None = None,
+ html_content: str | None = None,
+ from_email: str | None = None,
+ files: list[str] | None = None,
+ cc: str | Iterable[str] | None = None,
+ bcc: str | Iterable[str] | None = None,
+ mime_subtype: str = "mixed",
+ mime_charset: str = "utf-8",
+ custom_headers: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+ if not self._smtp_client:
+ raise AirflowException("The 'smtp_client' should be initialized before!")
+
+ from_email = from_email or self.from_email
+ if not from_email:
+ raise AirflowException("You should provide `from_email` or define it in the connection.")
+
+ if not subject:
+ if self.subject_template is None:
+ raise AirflowException(
+ "You should provide `subject` or define `subject_template` in the connection."
+ )
+ subject = self._read_template(self.subject_template)
+
+ if not html_content:
+ if self.html_content_template is None:
+ raise AirflowException(
+ "You should provide `html_content` or define `html_content_template` in the connection."
+ )
+ html_content = self._read_template(self.html_content_template)
+
+ mime_msg, recipients = self._build_mime_message(
+ mail_from=from_email,
+ to=to,
+ subject=subject,
+ html_content=html_content,
+ files=files,
+ cc=cc,
+ bcc=bcc,
+ mime_subtype=mime_subtype,
+ mime_charset=mime_charset,
+ custom_headers=custom_headers,
+ )
+
+ return {"mime_msg": mime_msg, "recipients": recipients, "from_email": from_email}
+
def send_email_smtp(
self,
*,
@@ -246,27 +392,9 @@ def send_email_smtp(
'test@example.com', 'foo', 'Foo bar', ['/dev/null'], dryrun=True
)
"""
- if not self.smtp_client:
- raise AirflowException("The 'smtp_client' should be initialized before!")
- from_email = from_email or self.from_email
- if not from_email:
- raise AirflowException("You should provide `from_email` or define it in the connection.")
- if not subject:
- if self.subject_template is None:
- raise AirflowException(
- "You should provide `subject` or define `subject_template` in the connection."
- )
- subject = self._read_template(self.subject_template)
- if not html_content:
- if self.html_content_template is None:
- raise AirflowException(
- "You should provide `html_content` or define `html_content_template` in the connection."
- )
- html_content = self._read_template(self.html_content_template)
-
- mime_msg, recipients = self._build_mime_message(
- mail_from=from_email,
+ msg = self._build_message(
to=to,
+ from_email=from_email,
subject=subject,
html_content=html_content,
files=files,
@@ -277,16 +405,92 @@ def send_email_smtp(
custom_headers=custom_headers,
)
if not dryrun:
+ if self._smtp_client is None:
+ raise AirflowException("The SMTP client is not initialized")
+ # Casting here to make MyPy happy.
+ smtp_client = cast("smtplib.SMTP_SSL | smtplib.SMTP", self._smtp_client)
+
for attempt in range(1, self.smtp_retry_limit + 1):
try:
- self.smtp_client.sendmail(
- from_addr=from_email, to_addrs=recipients, msg=mime_msg.as_string()
+ smtp_client.sendmail(
+ from_addr=msg["from_email"],
+ to_addrs=msg["recipients"],
+ msg=msg["mime_msg"].as_string(),
)
- except smtplib.SMTPServerDisconnected as e:
+ break
+ except Exception as e:
if attempt == self.smtp_retry_limit:
raise e
- else:
+
+ async def asend_email_smtp(
+ self,
+ *,
+ to: str | Iterable[str],
+ subject: str | None = None,
+ html_content: str | None = None,
+ from_email: str | None = None,
+ files: list[str] | None = None,
+ dryrun: bool = False,
+ cc: str | Iterable[str] | None = None,
+ bcc: str | Iterable[str] | None = None,
+ mime_subtype: str = "mixed",
+ mime_charset: str = "utf-8",
+ custom_headers: dict[str, Any] | None = None,
+ **kwargs,
+ ) -> None:
+ """
+ Send an email with html content.
+
+ :param to: Recipient email address or list of addresses.
+ :param subject: Email subject. If it's None, the hook will check if there is a path to a subject
+ file provided in the connection, and raises an exception if not.
+ :param html_content: Email body in HTML format. If it's None, the hook will check if there is a path
+ to a html content file provided in the connection, and raises an exception if not.
+ :param from_email: Sender email address. If it's None, the hook will check if there is an email
+ provided in the connection, and raises an exception if not.
+ :param files: List of file paths to attach to the email.
+ :param dryrun: If True, the email will not be sent, but all other actions will be performed.
+ :param cc: Carbon copy recipient email address or list of addresses.
+ :param bcc: Blind carbon copy recipient email address or list of addresses.
+ :param mime_subtype: MIME subtype of the email.
+ :param mime_charset: MIME charset of the email.
+ :param custom_headers: Dictionary of custom headers to include in the email.
+ :param kwargs: Additional keyword arguments.
+
+ >>> send_email_smtp(
+ 'test@example.com', 'foo', 'Foo bar', ['/dev/null'], dryrun=True
+ )
+ """
+ msg = self._build_message(
+ to=to,
+ subject=subject,
+ html_content=html_content,
+ from_email=from_email,
+ files=files,
+ cc=cc,
+ bcc=bcc,
+ mime_subtype=mime_subtype,
+ mime_charset=mime_charset,
+ custom_headers=custom_headers,
+ )
+ if self._smtp_client is None:
+ raise AirflowException("The SMTP client is not initialized")
+ # Casting here to make MyPy happy.
+ smtp_client = cast("aiosmtplib.SMTP", self._smtp_client)
+
+ if not dryrun:
+ for attempt in range(1, self.smtp_retry_limit + 1):
+ try:
+ # The async version of sendmail only supports positional arguments for some reason.
+ await smtp_client.sendmail(
+ msg["from_email"],
+ msg["recipients"],
+ msg["mime_msg"].as_string(),
+ )
break
+ except Exception as e:
+ if attempt == self.smtp_retry_limit:
+ raise e
def _build_mime_message(
self,
@@ -420,7 +624,7 @@ def _get_oauth2_token(self) -> str:
"auth_type='oauth2' but neither 'access_token' nor client credentials supplied in connection extra."
)
- @property
+ @cached_property
def conn(self) -> Connection:
if not self.smtp_connection:
raise AirflowException("The smtp connection should be loaded before!")
diff --git a/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py b/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py
index 49d463f8819e5..c77bc25e731b1 100644
--- a/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py
+++ b/providers/smtp/src/airflow/providers/smtp/notifications/smtp.py
@@ -20,11 +20,16 @@
from collections.abc import Iterable
from functools import cached_property
from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any
from airflow.providers.common.compat.notifier import BaseNotifier
from airflow.providers.smtp.hooks.smtp import SmtpHook
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
+
+if TYPE_CHECKING:
+ from airflow.sdk import Context
+
class SmtpNotifier(BaseNotifier):
"""
@@ -80,8 +85,13 @@ def __init__(
auth_type: str = "basic",
*,
template: str | None = None,
+ **kwargs,
):
- super().__init__()
+ if AIRFLOW_V_3_1_PLUS:
+ # Support for passing context was added in 3.1.0
+ super().__init__(**kwargs)
+ else:
+ super().__init__()
self.smtp_conn_id = smtp_conn_id
self.from_email = from_email
self.to = to
@@ -106,40 +116,61 @@ def hook(self) -> SmtpHook:
"""Smtp Events Hook."""
return SmtpHook(smtp_conn_id=self.smtp_conn_id, auth_type=self.auth_type)
- def notify(self, context):
+ def _build_email_content(self, smtp: SmtpHook, context: Context):
+ fields_to_re_render = []
+ if self.from_email is None:
+ if smtp.from_email is not None:
+ self.from_email = smtp.from_email
+ else:
+ raise ValueError("You should provide `from_email` or define it in the connection")
+ fields_to_re_render.append("from_email")
+ if self.subject is None:
+ smtp_default_templated_subject_path: str
+ if smtp.subject_template:
+ smtp_default_templated_subject_path = smtp.subject_template
+ else:
+ smtp_default_templated_subject_path = (
+ Path(__file__).parent / "templates" / "email_subject.jinja2"
+ ).as_posix()
+ self.subject = self._read_template(smtp_default_templated_subject_path)
+ fields_to_re_render.append("subject")
+ if self.html_content is None:
+ smtp_default_templated_html_content_path: str
+ if smtp.html_content_template:
+ smtp_default_templated_html_content_path = smtp.html_content_template
+ else:
+ smtp_default_templated_html_content_path = (
+ Path(__file__).parent / "templates" / "email.html"
+ ).as_posix()
+ self.html_content = self._read_template(smtp_default_templated_html_content_path)
+ fields_to_re_render.append("html_content")
+ if fields_to_re_render:
+ jinja_env = self.get_template_env(dag=context.get("dag"))
+ self._do_render_template_fields(self, fields_to_re_render, context, jinja_env, set())
+
+ def notify(self, context: Context):
"""Send a email via smtp server."""
- with self.hook as smtp:
- fields_to_re_render = []
- if self.from_email is None:
- if smtp.from_email is not None:
- self.from_email = smtp.from_email
- else:
- raise ValueError("You should provide `from_email` or define it in the connection")
- fields_to_re_render.append("from_email")
- if self.subject is None:
- smtp_default_templated_subject_path: str
- if smtp.subject_template:
- smtp_default_templated_subject_path = smtp.subject_template
- else:
- smtp_default_templated_subject_path = (
- Path(__file__).parent / "templates" / "email_subject.jinja2"
- ).as_posix()
- self.subject = self._read_template(smtp_default_templated_subject_path)
- fields_to_re_render.append("subject")
- if self.html_content is None:
- smtp_default_templated_html_content_path: str
- if smtp.html_content_template:
- smtp_default_templated_html_content_path = smtp.html_content_template
- else:
- smtp_default_templated_html_content_path = (
- Path(__file__).parent / "templates" / "email.html"
- ).as_posix()
- self.html_content = self._read_template(smtp_default_templated_html_content_path)
- fields_to_re_render.append("html_content")
- if fields_to_re_render:
- jinja_env = self.get_template_env(dag=context["dag"])
- self._do_render_template_fields(self, fields_to_re_render, context, jinja_env, set())
- smtp.send_email_smtp(
+ with self.hook as smtp_hook:
+ self._build_email_content(smtp_hook, context)
+ smtp_hook.send_email_smtp(
+ smtp_conn_id=self.smtp_conn_id,
+ from_email=self.from_email,
+ to=self.to,
+ subject=self.subject,
+ html_content=self.html_content,
+ files=self.files,
+ cc=self.cc,
+ bcc=self.bcc,
+ mime_subtype=self.mime_subtype,
+ mime_charset=self.mime_charset,
+ custom_headers=self.custom_headers,
+ )
+
+ async def async_notify(self, context: Context):
+ """Send a email via smtp server (async)."""
+ async with self.hook as smtp_hook:
+ self._build_email_content(smtp_hook, context)
+ await smtp_hook.asend_email_smtp(
smtp_conn_id=self.smtp_conn_id,
from_email=self.from_email,
to=self.to,
diff --git a/providers/smtp/tests/unit/smtp/hooks/test_smtp.py b/providers/smtp/tests/unit/smtp/hooks/test_smtp.py
index fc09281a306fc..e3fad4d7ae3d0 100644
--- a/providers/smtp/tests/unit/smtp/hooks/test_smtp.py
+++ b/providers/smtp/tests/unit/smtp/hooks/test_smtp.py
@@ -22,17 +22,49 @@
import smtplib
import tempfile
from email.mime.application import MIMEApplication
-from unittest.mock import Mock, call, patch
+from unittest import mock
+from unittest.mock import AsyncMock, Mock, call, patch
+import aiosmtplib
import pytest
from airflow.exceptions import AirflowException
-from airflow.models import Connection
from airflow.providers.smtp.hooks.smtp import SmtpHook, build_xoauth2_string
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
+
+if AIRFLOW_V_3_1_PLUS:
+ from airflow.sdk import Connection
+else:
+ from airflow.models.connection import Connection # type: ignore[assignment]
+
smtplib_string = "airflow.providers.smtp.hooks.smtp.smtplib"
-TEST_EMAILS = ["test1@example.com", "test2@example.com"]
+FROM_EMAIL = "from@example.com"
+TO_EMAIL = "to@example.com"
+TEST_EMAILS = [FROM_EMAIL, TO_EMAIL]
+
+CONN_TYPE = "smtp"
+SMTP_HOST = "smtp.example.com"
+SMTP_LOGIN = "smtp_user"
+SMTP_PASSWORD = "smtp_password"
+ACCESS_TOKEN = "test-token"
+
+CONN_ID_DEFAULT = "smtp_default"
+CONN_ID_NONSSL = "smtp_nonssl"
+CONN_ID_SSL_EXTRA = "smtp_ssl_extra"
+CONN_ID_OAUTH = "smtp_oauth2"
+
+DEFAULT_PORT = 465
+NONSSL_PORT = 587
+
+DEFAULT_TIMEOUT = 30
+DEFAULT_RETRY_LIMIT = 5
+
+TEST_SUBJECT = "test subject"
+TEST_BODY = "Test"
+
+SERVER_DISCONNECTED_ERROR = aiosmtplib.errors.SMTPServerDisconnected("Server disconnected")
def _create_fake_smtp(mock_smtplib, use_ssl=True):
@@ -53,110 +85,123 @@ class TestSmtpHook:
def setup_connections(self, create_connection_without_db):
create_connection_without_db(
Connection(
- conn_id="smtp_default",
- conn_type="smtp",
- host="smtp_server_address",
- login="smtp_user",
- password="smtp_password",
- port=465,
- extra=json.dumps(dict(from_email="from")),
+ conn_id=CONN_ID_DEFAULT,
+ conn_type=CONN_TYPE,
+ host=SMTP_HOST,
+ login=SMTP_LOGIN,
+ password=SMTP_PASSWORD,
+ port=DEFAULT_PORT,
+ extra=json.dumps(dict(from_email=FROM_EMAIL)),
)
)
create_connection_without_db(
Connection(
- conn_id="smtp_nonssl",
- conn_type="smtp",
- host="smtp_server_address",
- login="smtp_user",
- password="smtp_password",
- port=587,
- extra=json.dumps(dict(disable_ssl=True, from_email="from")),
+ conn_id=CONN_ID_NONSSL,
+ conn_type=CONN_TYPE,
+ host=SMTP_HOST,
+ login=SMTP_LOGIN,
+ password=SMTP_PASSWORD,
+ port=NONSSL_PORT,
+ extra=json.dumps(dict(disable_ssl=True, from_email=FROM_EMAIL)),
)
)
create_connection_without_db(
Connection(
- conn_id="smtp_oauth2",
- conn_type="smtp",
- host="smtp_server_address",
- login="smtp_user",
- password="smtp_password",
- port=587,
- extra=json.dumps(dict(disable_ssl=True, from_email="from", access_token="test-token")),
+ conn_id=CONN_ID_OAUTH,
+ conn_type=CONN_TYPE,
+ host=SMTP_HOST,
+ login=SMTP_LOGIN,
+ password=SMTP_PASSWORD,
+ port=NONSSL_PORT,
+ extra=json.dumps(dict(disable_ssl=True, from_email=FROM_EMAIL, access_token=ACCESS_TOKEN)),
+ )
+ )
+ create_connection_without_db(
+ Connection(
+ conn_id=CONN_ID_SSL_EXTRA,
+ conn_type=CONN_TYPE,
+ host=SMTP_HOST,
+ login=None,
+ password="None",
+ port=DEFAULT_PORT,
+ extra=json.dumps(dict(use_ssl=True, ssl_context="none", from_email=FROM_EMAIL)),
)
)
+ @pytest.mark.parametrize(
+ "conn_id, use_ssl, expected_port, create_context",
+ [
+ pytest.param(CONN_ID_DEFAULT, True, DEFAULT_PORT, True, id="ssl-connection"),
+ pytest.param(CONN_ID_NONSSL, False, NONSSL_PORT, False, id="non-ssl-connection"),
+ ],
+ )
@patch(smtplib_string)
@patch("ssl.create_default_context")
- def test_connect_and_disconnect(self, create_default_context, mock_smtplib):
- mock_conn = _create_fake_smtp(mock_smtplib)
+ def test_connect_and_disconnect(
+ self, create_default_context, mock_smtplib, conn_id, use_ssl, expected_port, create_context
+ ):
+ """Test sync connection with different configurations."""
+ mock_conn = _create_fake_smtp(mock_smtplib, use_ssl=use_ssl)
- with SmtpHook():
+ with SmtpHook(smtp_conn_id=conn_id):
pass
- assert create_default_context.called
- mock_smtplib.SMTP_SSL.assert_called_once_with(
- host="smtp_server_address", port=465, timeout=30, context=create_default_context.return_value
- )
- mock_conn.login.assert_called_once_with("smtp_user", "smtp_password")
- assert mock_conn.close.call_count == 1
- @patch(smtplib_string)
- def test_connect_and_disconnect_via_nonssl(self, mock_smtplib):
- mock_conn = _create_fake_smtp(mock_smtplib, use_ssl=False)
-
- with SmtpHook(smtp_conn_id="smtp_nonssl"):
- pass
+ if create_context:
+ assert create_default_context.called
+ mock_smtplib.SMTP_SSL.assert_called_once_with(
+ host=SMTP_HOST,
+ port=expected_port,
+ timeout=DEFAULT_TIMEOUT,
+ context=create_default_context.return_value,
+ )
+ else:
+ mock_smtplib.SMTP.assert_called_once_with(
+ host=SMTP_HOST,
+ port=expected_port,
+ timeout=DEFAULT_TIMEOUT,
+ )
- mock_smtplib.SMTP.assert_called_once_with(host="smtp_server_address", port=587, timeout=30)
- mock_conn.login.assert_called_once_with("smtp_user", "smtp_password")
+ mock_conn.login.assert_called_once_with(SMTP_LOGIN, SMTP_PASSWORD)
assert mock_conn.close.call_count == 1
@patch(smtplib_string)
def test_get_email_address_single_email(self, mock_smtplib):
with SmtpHook() as smtp_hook:
- assert smtp_hook._get_email_address_list("test1@example.com") == ["test1@example.com"]
-
- @patch(smtplib_string)
- def test_get_email_address_comma_sep_string(self, mock_smtplib):
- with SmtpHook() as smtp_hook:
- assert smtp_hook._get_email_address_list("test1@example.com, test2@example.com") == TEST_EMAILS
-
- @patch(smtplib_string)
- def test_get_email_address_colon_sep_string(self, mock_smtplib):
- with SmtpHook() as smtp_hook:
- assert smtp_hook._get_email_address_list("test1@example.com; test2@example.com") == TEST_EMAILS
-
- @patch(smtplib_string)
- def test_get_email_address_list(self, mock_smtplib):
- with SmtpHook() as smtp_hook:
- assert (
- smtp_hook._get_email_address_list(["test1@example.com", "test2@example.com"]) == TEST_EMAILS
- )
-
+ assert smtp_hook._get_email_address_list(FROM_EMAIL) == [FROM_EMAIL]
+
+ @pytest.mark.parametrize(
+ "email_input",
+ [
+ pytest.param(f"{FROM_EMAIL}, {TO_EMAIL}", id="comma_separated"),
+ pytest.param(f"{FROM_EMAIL}; {TO_EMAIL}", id="semicolon_separated"),
+ pytest.param([FROM_EMAIL, TO_EMAIL], id="list_input"),
+ pytest.param((FROM_EMAIL, TO_EMAIL), id="tuple_input"),
+ ],
+ )
@patch(smtplib_string)
- def test_get_email_address_tuple(self, mock_smtplib):
+ def test_get_email_address_parsing(self, mock_smtplib, email_input):
with SmtpHook() as smtp_hook:
- assert (
- smtp_hook._get_email_address_list(("test1@example.com", "test2@example.com")) == TEST_EMAILS
- )
-
- @patch(smtplib_string)
- def test_get_email_address_invalid_type(self, mock_smtplib):
- with pytest.raises(TypeError):
- with SmtpHook() as smtp_hook:
- smtp_hook._get_email_address_list(1)
-
+ assert smtp_hook._get_email_address_list(email_input) == TEST_EMAILS
+
+ @pytest.mark.parametrize(
+ "invalid_input",
+ [
+ pytest.param(1, id="invalid_scalar_type"),
+ pytest.param([FROM_EMAIL, 2], id="invalid_type_in_list"),
+ ],
+ )
@patch(smtplib_string)
- def test_get_email_address_invalid_type_in_iterable(self, mock_smtplib):
+ def test_get_email_address_invalid_types(self, mock_smtplib, invalid_input):
with pytest.raises(TypeError):
with SmtpHook() as smtp_hook:
- smtp_hook._get_email_address_list(["test1@example.com", 2])
+ smtp_hook._get_email_address_list(invalid_input)
@patch(smtplib_string)
def test_build_mime_message(self, mock_smtplib):
- mail_from = "from@example.com"
- mail_to = "to@example.com"
- subject = "test subject"
- html_content = "Test"
+ mail_from = FROM_EMAIL
+ mail_to = TO_EMAIL
+ subject = TEST_SUBJECT
+ html_content = TEST_BODY
custom_headers = {"Reply-To": "reply_to@example.com"}
with SmtpHook() as smtp_hook:
msg, recipients = smtp_hook._build_mime_message(
@@ -181,15 +226,15 @@ def test_send_smtp(self, mock_smtplib):
attachment.write(b"attachment")
attachment.seek(0)
smtp_hook.send_email_smtp(
- to="to", subject="subject", html_content="content", files=[attachment.name]
+ to=TO_EMAIL, subject=TEST_SUBJECT, html_content=TEST_BODY, files=[attachment.name]
)
assert mock_send_mime.called
_, call_args = mock_send_mime.call_args
- assert call_args["from_addr"] == "from"
- assert call_args["to_addrs"] == ["to"]
+ assert call_args["from_addr"] == FROM_EMAIL
+ assert call_args["to_addrs"] == [TO_EMAIL]
msg = call_args["msg"]
- assert "Subject: subject" in msg
- assert "From: from" in msg
+ assert f"Subject: {TEST_SUBJECT}" in msg
+ assert f"From: {FROM_EMAIL}" in msg
filename = 'attachment; filename="' + os.path.basename(attachment.name) + '"'
assert filename in msg
mimeapp = MIMEApplication("attachment")
@@ -199,148 +244,150 @@ def test_send_smtp(self, mock_smtplib):
@patch(smtplib_string)
def test_hook_conn(self, mock_smtplib, mock_hook_conn):
mock_conn = Mock()
- mock_conn.login = "user"
- mock_conn.password = "password"
- mock_conn.extra_dejson = {
- "disable_ssl": False,
- }
+ mock_conn.login = SMTP_LOGIN
+ mock_conn.password = SMTP_PASSWORD
+ mock_conn.extra_dejson = {"disable_ssl": False}
mock_hook_conn.return_value = mock_conn
smtp_client_mock = mock_smtplib.SMTP_SSL()
with SmtpHook() as smtp_hook:
- smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from")
- mock_hook_conn.assert_called_with("smtp_default")
- smtp_client_mock.login.assert_called_once_with("user", "password")
+ smtp_hook.send_email_smtp(
+ to=TO_EMAIL,
+ subject=TEST_SUBJECT,
+ html_content=TEST_BODY,
+ from_email=FROM_EMAIL,
+ )
+
+ mock_hook_conn.assert_called_with(CONN_ID_DEFAULT)
+ smtp_client_mock.login.assert_called_once_with(SMTP_LOGIN, SMTP_PASSWORD)
smtp_client_mock.sendmail.assert_called_once()
assert smtp_client_mock.close.called
+ @pytest.mark.parametrize(
+ "conn_id, ssl_context, create_context_called, use_default_context",
+ [
+ pytest.param(CONN_ID_DEFAULT, "default", True, True, id="default_context"),
+ pytest.param(CONN_ID_SSL_EXTRA, "none", False, False, id="none_context"),
+ pytest.param(CONN_ID_DEFAULT, "default", True, True, id="explicit_default_context"),
+ ],
+ )
@patch("smtplib.SMTP_SSL")
@patch("smtplib.SMTP")
@patch("ssl.create_default_context")
- def test_send_mime_ssl(self, create_default_context, mock_smtp, mock_smtp_ssl):
- mock_smtp_ssl.return_value = Mock()
- with SmtpHook() as smtp_hook:
- smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from")
- assert not mock_smtp.called
- assert create_default_context.called
- mock_smtp_ssl.assert_called_once_with(
- host="smtp_server_address", port=465, timeout=30, context=create_default_context.return_value
- )
-
- @patch("smtplib.SMTP_SSL")
- @patch("smtplib.SMTP")
- @patch("ssl.create_default_context")
- def test_send_mime_ssl_extra_none_context(
- self, create_default_context, mock_smtp, mock_smtp_ssl, create_connection_without_db
+ def test_send_mime_ssl_context(
+ self,
+ create_default_context,
+ mock_smtp,
+ mock_smtp_ssl,
+ conn_id,
+ ssl_context,
+ create_context_called,
+ use_default_context,
):
mock_smtp_ssl.return_value = Mock()
- conn = Connection(
- conn_id="smtp_ssl_extra",
- conn_type="smtp",
- host="smtp_server_address",
- login=None,
- password="None",
- port=465,
- extra=json.dumps(dict(use_ssl=True, ssl_context="none", from_email="from")),
- )
- create_connection_without_db(conn)
- with SmtpHook(smtp_conn_id="smtp_ssl_extra") as smtp_hook:
- smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from")
- assert not mock_smtp.called
- create_default_context.assert_not_called()
- mock_smtp_ssl.assert_called_once_with(host="smtp_server_address", port=465, timeout=30, context=None)
- @patch("smtplib.SMTP_SSL")
- @patch("smtplib.SMTP")
- @patch("ssl.create_default_context")
- def test_send_mime_ssl_extra_default_context(
- self, create_default_context, mock_smtp, mock_smtp_ssl, create_connection_without_db
- ):
- mock_smtp_ssl.return_value = Mock()
- conn = Connection(
- conn_id="smtp_ssl_extra",
- conn_type="smtp",
- host="smtp_server_address",
- login=None,
- password="None",
- port=465,
- extra=json.dumps(dict(use_ssl=True, ssl_context="default", from_email="from")),
- )
- create_connection_without_db(conn)
- with SmtpHook() as smtp_hook:
- smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from")
+ with SmtpHook(conn_id) as smtp_hook:
+ smtp_hook.send_email_smtp(
+ to=TO_EMAIL, subject=TEST_SUBJECT, html_content=TEST_BODY, from_email=FROM_EMAIL
+ )
+
assert not mock_smtp.called
- assert create_default_context.called
+ if use_default_context:
+ assert create_default_context.called
+ expected_context = create_default_context.return_value
+ else:
+ create_default_context.assert_not_called()
+ expected_context = None
+
mock_smtp_ssl.assert_called_once_with(
- host="smtp_server_address", port=465, timeout=30, context=create_default_context.return_value
+ host=SMTP_HOST, port=DEFAULT_PORT, timeout=DEFAULT_TIMEOUT, context=expected_context
)
@patch("smtplib.SMTP_SSL")
@patch("smtplib.SMTP")
@patch("ssl.create_default_context")
- def test_send_mime_default_context(
- self, create_default_context, mock_smtp, mock_smtp_ssl, create_connection_without_db
- ):
+ def test_send_mime_ssl(self, create_default_context, mock_smtp, mock_smtp_ssl):
mock_smtp_ssl.return_value = Mock()
- conn = Connection(
- conn_id="smtp_ssl_extra",
- conn_type="smtp",
- host="smtp_server_address",
- login=None,
- password="None",
- port=465,
- extra=json.dumps(dict(use_ssl=True, from_email="from")),
- )
- create_connection_without_db(conn)
+
with SmtpHook() as smtp_hook:
- smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from")
+ smtp_hook.send_email_smtp(
+ to=TO_EMAIL,
+ subject=TEST_SUBJECT,
+ html_content=TEST_BODY,
+ from_email=FROM_EMAIL,
+ )
+
assert not mock_smtp.called
assert create_default_context.called
mock_smtp_ssl.assert_called_once_with(
- host="smtp_server_address", port=465, timeout=30, context=create_default_context.return_value
+ host=SMTP_HOST,
+ port=DEFAULT_PORT,
+ timeout=DEFAULT_TIMEOUT,
+ context=create_default_context.return_value,
)
@patch("smtplib.SMTP_SSL")
@patch("smtplib.SMTP")
def test_send_mime_nossl(self, mock_smtp, mock_smtp_ssl):
mock_smtp.return_value = Mock()
- with SmtpHook(smtp_conn_id="smtp_nonssl") as smtp_hook:
- smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from")
+
+ with SmtpHook(smtp_conn_id=CONN_ID_NONSSL) as smtp_hook:
+ smtp_hook.send_email_smtp(
+ to=TO_EMAIL, subject=TEST_SUBJECT, html_content=TEST_BODY, from_email=FROM_EMAIL
+ )
+
assert not mock_smtp_ssl.called
- mock_smtp.assert_called_once_with(host="smtp_server_address", port=587, timeout=30)
+ mock_smtp.assert_called_once_with(host=SMTP_HOST, port=NONSSL_PORT, timeout=DEFAULT_TIMEOUT)
@patch("smtplib.SMTP")
def test_send_mime_noauth(self, mock_smtp, create_connection_without_db):
mock_smtp.return_value = Mock()
conn = Connection(
conn_id="smtp_noauth",
- conn_type="smtp",
- host="smtp_server_address",
+ conn_type=CONN_TYPE,
+ host=SMTP_HOST,
login=None,
password="None",
- port=587,
- extra=json.dumps(dict(disable_ssl=True, from_email="from")),
+ port=NONSSL_PORT,
+ extra=json.dumps(dict(disable_ssl=True, from_email=FROM_EMAIL)),
)
create_connection_without_db(conn)
with SmtpHook(smtp_conn_id="smtp_noauth") as smtp_hook:
- smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", from_email="from")
- mock_smtp.assert_called_once_with(host="smtp_server_address", port=587, timeout=30)
+ smtp_hook.send_email_smtp(
+ to=TO_EMAIL,
+ subject=TEST_SUBJECT,
+ html_content=TEST_BODY,
+ from_email=FROM_EMAIL,
+ )
+ mock_smtp.assert_called_once_with(host=SMTP_HOST, port=NONSSL_PORT, timeout=DEFAULT_TIMEOUT)
assert not mock_smtp.login.called
@patch("smtplib.SMTP_SSL")
@patch("smtplib.SMTP")
def test_send_mime_dryrun(self, mock_smtp, mock_smtp_ssl):
with SmtpHook() as smtp_hook:
- smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content", dryrun=True)
+ smtp_hook.send_email_smtp(
+ to=TO_EMAIL,
+ subject=TEST_SUBJECT,
+ html_content=TEST_BODY,
+ dryrun=True,
+ )
+
assert not mock_smtp.sendmail.called
assert not mock_smtp_ssl.sendmail.called
@patch("smtplib.SMTP_SSL")
def test_send_mime_ssl_complete_failure(self, mock_smtp_ssl):
mock_smtp_ssl().sendmail.side_effect = smtplib.SMTPServerDisconnected()
+
with SmtpHook() as smtp_hook:
with pytest.raises(smtplib.SMTPServerDisconnected):
- smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content")
- assert mock_smtp_ssl().sendmail.call_count == 5
+ smtp_hook.send_email_smtp(
+ to=TO_EMAIL,
+ subject=TEST_SUBJECT,
+ html_content=TEST_BODY,
+ )
+
+ assert mock_smtp_ssl().sendmail.call_count == DEFAULT_RETRY_LIMIT
@patch("email.message.Message.as_string")
@patch("smtplib.SMTP_SSL")
@@ -350,10 +397,14 @@ def test_send_mime_partial_failure(self, mock_smtp_ssl, mime_message_mock):
side_effects = [smtplib.SMTPServerDisconnected(), smtplib.SMTPServerDisconnected(), final_mock]
mock_smtp_ssl.side_effect = side_effects
with SmtpHook() as smtp_hook:
- smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content")
+ smtp_hook.send_email_smtp(
+ to=TO_EMAIL,
+ subject=TEST_SUBJECT,
+ html_content=TEST_BODY,
+ )
assert mock_smtp_ssl.call_count == side_effects.index(final_mock) + 1
assert final_mock.starttls.called
- final_mock.sendmail.assert_called_once_with(from_addr="from", to_addrs=["to"], msg="msg")
+ final_mock.sendmail.assert_called_once_with(from_addr=FROM_EMAIL, to_addrs=[TO_EMAIL], msg="msg")
assert final_mock.close.called
@patch("smtplib.SMTP_SSL")
@@ -366,18 +417,20 @@ def test_send_mime_custom_timeout_retrylimit(
custom_timeout = 60
fake_conn = Connection(
conn_id="mock_conn",
- conn_type="smtp",
- host="smtp_server_address",
- login="smtp_user",
- password="smtp_password",
- port=465,
- extra=json.dumps(dict(from_email="from", timeout=custom_timeout, retry_limit=custom_retry_limit)),
+ conn_type=CONN_TYPE,
+ host=SMTP_HOST,
+ login=SMTP_LOGIN,
+ password=SMTP_PASSWORD,
+ port=DEFAULT_PORT,
+ extra=json.dumps(
+ dict(from_email=FROM_EMAIL, timeout=custom_timeout, retry_limit=custom_retry_limit)
+ ),
)
create_connection_without_db(fake_conn)
with SmtpHook(smtp_conn_id="mock_conn") as smtp_hook:
with pytest.raises(smtplib.SMTPServerDisconnected):
- smtp_hook.send_email_smtp(to="to", subject="subject", html_content="content")
+ smtp_hook.send_email_smtp(to=TO_EMAIL, subject=TEST_SUBJECT, html_content=TEST_BODY)
expected_call = call(
host=fake_conn.host,
@@ -393,18 +446,18 @@ def test_send_mime_custom_timeout_retrylimit(
def test_oauth2_auth_called(self, mock_smtplib):
mock_conn = _create_fake_smtp(mock_smtplib, use_ssl=False)
- with SmtpHook(smtp_conn_id="smtp_oauth2", auth_type="oauth2") as smtp_hook:
+ with SmtpHook(smtp_conn_id=CONN_ID_OAUTH, auth_type="oauth2") as smtp_hook:
smtp_hook.send_email_smtp(
- to="to@example.com",
- subject="subject",
- html_content="content",
- from_email="from",
+ to=TO_EMAIL,
+ subject=TEST_SUBJECT,
+ html_content=TEST_BODY,
+ from_email=FROM_EMAIL,
)
assert mock_conn.auth.called
args, _ = mock_conn.auth.call_args
assert args[0] == "XOAUTH2"
- assert build_xoauth2_string("smtp_user", "test-token") == args[1]()
+ assert build_xoauth2_string(SMTP_LOGIN, ACCESS_TOKEN) == args[1]()
@patch(smtplib_string)
def test_oauth2_missing_token_raises(self, mock_smtplib, create_connection_without_db):
@@ -413,22 +466,187 @@ def test_oauth2_missing_token_raises(self, mock_smtplib, create_connection_witho
create_connection_without_db(
Connection(
conn_id="smtp_oauth2_empty",
- conn_type="smtp",
- host="smtp_server_address",
- login="smtp_user",
- password="smtp_password",
- port=587,
- extra=json.dumps(dict(disable_ssl=True, from_email="from")),
+ conn_type=CONN_TYPE,
+ host=SMTP_HOST,
+ login=SMTP_LOGIN,
+ password=SMTP_PASSWORD,
+ port=NONSSL_PORT,
+ extra=json.dumps(dict(disable_ssl=True, from_email=FROM_EMAIL)),
)
)
with pytest.raises(AirflowException):
with SmtpHook(smtp_conn_id="smtp_oauth2_empty", auth_type="oauth2") as h:
h.send_email_smtp(
- to="to@example.com",
- subject="subject",
- html_content="content",
- from_email="from",
+ to=TO_EMAIL,
+ subject=TEST_SUBJECT,
+ html_content=TEST_BODY,
+ from_email=FROM_EMAIL,
)
assert not mock_conn.auth.called
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(not AIRFLOW_V_3_1_PLUS, reason="Async support was added to BaseNotifier in 3.1.0")
+class TestSmtpHookAsync:
+ """Tests for async functionality in SmtpHook."""
+
+ @pytest.fixture(autouse=True)
+ def setup_connections(self, create_connection_without_db):
+ create_connection_without_db(
+ Connection(
+ conn_id=CONN_ID_DEFAULT,
+ conn_type=CONN_TYPE,
+ host=SMTP_HOST,
+ login=SMTP_LOGIN,
+ password=SMTP_PASSWORD,
+ port=DEFAULT_PORT,
+ extra=json.dumps(dict(from_email=FROM_EMAIL, ssl_context="default")),
+ )
+ )
+ create_connection_without_db(
+ Connection(
+ conn_id=CONN_ID_NONSSL,
+ conn_type=CONN_TYPE,
+ host=SMTP_HOST,
+ login=SMTP_LOGIN,
+ password=SMTP_PASSWORD,
+ port=NONSSL_PORT,
+ extra=json.dumps(dict(disable_ssl=True, from_email=FROM_EMAIL)),
+ )
+ )
+
+ @pytest.fixture
+ def mock_smtp_client(self):
+ """Create a mock SMTP client with async capabilities."""
+ mock_client = AsyncMock(spec=aiosmtplib.SMTP)
+ mock_client.starttls = AsyncMock()
+ mock_client.auth_login = AsyncMock()
+ mock_client.sendmail = AsyncMock()
+ mock_client.quit = AsyncMock()
+ return mock_client
+
+ @pytest.fixture
+ def mock_smtp(self, mock_smtp_client):
+ """Set up the SMTP mock with context manager."""
+ with mock.patch("airflow.providers.smtp.hooks.smtp.aiosmtplib.SMTP") as mock_smtp:
+ mock_smtp.return_value = mock_smtp_client
+ yield mock_smtp
+
+ @pytest.fixture
+ def mock_get_connection(self):
+ """Mock the async connection retrieval."""
+ with mock.patch("airflow.sdk.bases.hook.BaseHook.aget_connection") as mock_conn:
+
+ async def async_get_connection(conn_id):
+ from airflow.sdk.definitions.connection import Connection
+
+ return Connection.from_json(os.environ[f"AIRFLOW_CONN_{conn_id.upper()}"])
+
+ mock_conn.side_effect = async_get_connection
+ yield mock_conn
+
+ @staticmethod
+ def _create_fake_async_smtp(mock_smtp):
+ mock_client = AsyncMock(spec=aiosmtplib.SMTP)
+ mock_client.starttls = AsyncMock()
+ mock_client.auth_login = AsyncMock()
+ mock_client.sendmail = AsyncMock()
+ mock_client.quit = AsyncMock()
+ mock_smtp.return_value = mock_client
+ return mock_client
+
+ @pytest.mark.parametrize(
+ "conn_id, expected_port, expected_ssl",
+ [
+ pytest.param(CONN_ID_NONSSL, NONSSL_PORT, False, id="non-ssl-connection"),
+ pytest.param(CONN_ID_DEFAULT, DEFAULT_PORT, True, id="ssl-connection"),
+ ],
+ )
+ async def test_async_connection(
+ self, mock_smtp, mock_smtp_client, mock_get_connection, conn_id, expected_port, expected_ssl
+ ):
+ """Test async connection with different configurations."""
+ async with SmtpHook(smtp_conn_id=conn_id) as hook:
+ assert hook is not None
+
+ mock_smtp.assert_called_once_with(
+ hostname=SMTP_HOST,
+ port=expected_port,
+ timeout=DEFAULT_TIMEOUT,
+ use_tls=expected_ssl,
+ start_tls=None if expected_ssl else True,
+ )
+
+ if expected_ssl:
+ assert mock_smtp_client.starttls.await_count == 1
+
+ assert mock_smtp_client.auth_login.await_count == 1
+ mock_smtp_client.auth_login.assert_awaited_once_with(SMTP_LOGIN, SMTP_PASSWORD)
+
+ @pytest.mark.asyncio
+ async def test_async_send_email(self, mock_smtp, mock_smtp_client, mock_get_connection):
+ """Test async email sending functionality."""
+ async with SmtpHook() as hook:
+ await hook.asend_email_smtp(
+ to=TO_EMAIL,
+ subject=TEST_SUBJECT,
+ html_content=TEST_BODY,
+ )
+
+ assert mock_smtp_client.sendmail.called
+ # The async version of sendmail only supports positional arguments
+ # for some reason, so we have to check these by positional args
+ call_args = mock_smtp_client.sendmail.await_args.args
+ assert call_args[0] == FROM_EMAIL # sender is first positional arg
+ assert call_args[1] == [TO_EMAIL] # recipients is the second positional arg
+ assert f"Subject: {TEST_SUBJECT}" in call_args[2] # message is the third positional arg
+
+ @pytest.mark.parametrize(
+ "side_effect, expected_calls, should_raise",
+ [
+ pytest.param(
+ [SERVER_DISCONNECTED_ERROR, SERVER_DISCONNECTED_ERROR, None],
+ 3,
+ False,
+ id="success_after_retries",
+ ),
+ pytest.param(SERVER_DISCONNECTED_ERROR, DEFAULT_RETRY_LIMIT, True, id="max_retries_exceeded"),
+ ],
+ )
+ @pytest.mark.asyncio
+ async def test_async_send_email_retries(
+ self, mock_smtp, mock_smtp_client, mock_get_connection, side_effect, expected_calls, should_raise
+ ):
+ mock_smtp_client.sendmail.side_effect = side_effect
+
+ if should_raise:
+ with pytest.raises(aiosmtplib.errors.SMTPServerDisconnected):
+ async with SmtpHook() as hook:
+ await hook.asend_email_smtp(
+ to=TO_EMAIL,
+ subject=TEST_SUBJECT,
+ html_content=TEST_BODY,
+ )
+ else:
+ async with SmtpHook() as hook:
+ await hook.asend_email_smtp(
+ to=TO_EMAIL,
+ subject=TEST_SUBJECT,
+ html_content=TEST_BODY,
+ )
+
+ assert mock_smtp_client.sendmail.call_count == expected_calls
+
+ async def test_async_send_email_dryrun(self, mock_smtp, mock_smtp_client, mock_get_connection):
+ """Test async email sending in dryrun mode."""
+ async with SmtpHook() as hook:
+ await hook.asend_email_smtp(
+ to=TO_EMAIL,
+ subject=TEST_SUBJECT,
+ html_content=TEST_BODY,
+ dryrun=True,
+ )
+
+ mock_smtp_client.sendmail.assert_not_awaited()
diff --git a/providers/smtp/tests/unit/smtp/notifications/test_smtp.py b/providers/smtp/tests/unit/smtp/notifications/test_smtp.py
index 4dfffab36e03e..6f8fcf60ac848 100644
--- a/providers/smtp/tests/unit/smtp/notifications/test_smtp.py
+++ b/providers/smtp/tests/unit/smtp/notifications/test_smtp.py
@@ -18,126 +18,157 @@
from __future__ import annotations
import tempfile
+from dataclasses import dataclass
from unittest import mock
+from unittest.mock import AsyncMock
-from airflow.providers.smtp.hooks.smtp import SmtpHook
-from airflow.providers.smtp.notifications.smtp import (
- SmtpNotifier,
- send_smtp_notification,
-)
+import pytest
-SMTP_API_DEFAULT_CONN_ID = SmtpHook.default_conn_name
+from airflow.providers.smtp.notifications.smtp import SmtpNotifier, send_smtp_notification
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS
-NUM_TRY = 0
+TRY_NUMBER = 0
+
+SMTP_CONN_ID = "smtp_default"
+SMTP_AUTH_TYPE = "oauth2"
+
+# Standard settings
+DEFAULT_EMAIL_PARAMS = {
+ "mime_subtype": "mixed",
+ "mime_charset": "utf-8",
+ "files": None,
+ "cc": None,
+ "bcc": None,
+ "custom_headers": None,
+}
+
+# DAG settings
+TEST_DAG_ID = "test_dag"
+TEST_TASK_ID = "test_task"
+TEST_TASK_STATE = None
+TEST_RUN_ID = "test_run"
+
+# Jinja template patterns
+DAG_ID_TEMPLATE_STRING = "{{dag.dag_id}}"
+TI_TEMPLATE_STRING = "{{ti.task_id}}"
+
+SENDER_EMAIL_SUFFIX = "sender@test.com"
+RECEIVER_EMAIL_SUFFIX = "receiver@test.com"
+# Base test values
+TEST_SENDER = f"test_{SENDER_EMAIL_SUFFIX}"
+TEST_RECEIVER = f"test_{RECEIVER_EMAIL_SUFFIX}"
+TEST_SUBJECT = "subject"
+TEST_BODY = "body"
+
+NOTIFIER_DEFAULT_PARAMS = {
+ "from_email": TEST_SENDER,
+ "to": TEST_RECEIVER,
+ "subject": TEST_SUBJECT,
+ "html_content": TEST_BODY,
+}
+
+
+# Templated versions
+@dataclass(frozen=True)
+class TemplatedString:
+ template: str
+
+ @property
+ def rendered(self) -> str:
+ return self.template.replace(DAG_ID_TEMPLATE_STRING, TEST_DAG_ID).replace(
+ TI_TEMPLATE_STRING, TEST_TASK_ID
+ )
+
+
+# DAG-based templates
+TEMPLATED_SENDER = TemplatedString(f"{DAG_ID_TEMPLATE_STRING}_{SENDER_EMAIL_SUFFIX}")
+TEMPLATED_RECEIVER = TemplatedString(f"{DAG_ID_TEMPLATE_STRING}_{RECEIVER_EMAIL_SUFFIX}")
+TEMPLATED_SUBJECT = TemplatedString(f"{TEST_SUBJECT} {DAG_ID_TEMPLATE_STRING}")
+TEMPLATED_BODY = TemplatedString(f"{TEST_BODY} {DAG_ID_TEMPLATE_STRING}")
+
+# Task-based templates
+TEMPLATED_TI_SUBJECT = TemplatedString(f"{TEST_SUBJECT} {TI_TEMPLATE_STRING}")
+TEMPLATED_TI_SENDER = TemplatedString(f"{TI_TEMPLATE_STRING}_{SENDER_EMAIL_SUFFIX}")
class TestSmtpNotifier:
@mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook")
def test_notifier(_self, mock_smtphook_hook, create_dag_without_db):
- notifier = send_smtp_notification(
- from_email="test_sender@test.com",
- to="test_reciver@test.com",
- subject="subject",
- html_content="body",
- )
- notifier({"dag": create_dag_without_db("test_notifier")})
+ notifier = send_smtp_notification(**NOTIFIER_DEFAULT_PARAMS)
+ notifier({"dag": create_dag_without_db(TEST_DAG_ID)})
mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with(
- from_email="test_sender@test.com",
- to="test_reciver@test.com",
- subject="subject",
- html_content="body",
- smtp_conn_id="smtp_default",
- files=None,
- cc=None,
- bcc=None,
- mime_subtype="mixed",
- mime_charset="utf-8",
- custom_headers=None,
+ **NOTIFIER_DEFAULT_PARAMS,
+ smtp_conn_id=SMTP_CONN_ID,
+ **DEFAULT_EMAIL_PARAMS,
)
@mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook")
def test_notifier_with_notifier_class(self, mock_smtphook_hook, create_dag_without_db):
- notifier = SmtpNotifier(
- from_email="test_sender@test.com",
- to="test_reciver@test.com",
- subject="subject",
- html_content="body",
- )
- notifier({"dag": create_dag_without_db("test_notifier")})
+ notifier = SmtpNotifier(**NOTIFIER_DEFAULT_PARAMS)
+ notifier({"dag": create_dag_without_db(TEST_DAG_ID)})
mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with(
- from_email="test_sender@test.com",
- to="test_reciver@test.com",
- subject="subject",
- html_content="body",
- smtp_conn_id="smtp_default",
- files=None,
- cc=None,
- bcc=None,
- mime_subtype="mixed",
- mime_charset="utf-8",
- custom_headers=None,
+ from_email=TEST_SENDER,
+ to=TEST_RECEIVER,
+ subject=TEST_SUBJECT,
+ html_content=TEST_BODY,
+ smtp_conn_id=SMTP_CONN_ID,
+ **DEFAULT_EMAIL_PARAMS,
)
@mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook")
def test_notifier_templated(self, mock_smtphook_hook, create_dag_without_db):
notifier = SmtpNotifier(
- from_email="test_sender@test.com {{dag.dag_id}}",
- to="test_reciver@test.com {{dag.dag_id}}",
- subject="subject {{dag.dag_id}}",
- html_content="body {{dag.dag_id}}",
+ from_email=TEMPLATED_SENDER.template,
+ to=TEMPLATED_RECEIVER.template,
+ subject=TEMPLATED_SUBJECT.template,
+ html_content=TEMPLATED_BODY.template,
)
- context = {"dag": create_dag_without_db("test_notifier")}
+ context = {"dag": create_dag_without_db(TEST_DAG_ID)}
+
notifier(context)
+
mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with(
- from_email="test_sender@test.com test_notifier",
- to="test_reciver@test.com test_notifier",
- subject="subject test_notifier",
- html_content="body test_notifier",
- smtp_conn_id="smtp_default",
- files=None,
- cc=None,
- bcc=None,
- mime_subtype="mixed",
- mime_charset="utf-8",
- custom_headers=None,
+ from_email=TEMPLATED_SENDER.rendered,
+ to=TEMPLATED_RECEIVER.rendered,
+ subject=TEMPLATED_SUBJECT.rendered,
+ html_content=TEMPLATED_BODY.rendered,
+ smtp_conn_id=SMTP_CONN_ID,
+ **DEFAULT_EMAIL_PARAMS,
)
@mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook")
def test_notifier_with_defaults(self, mock_smtphook_hook, create_dag_without_db, mock_task_instance):
# TODO: we can use create_runtime_ti fixture in place of mock_task_instance once provider has minimum AF to Airflow 3.0+
mock_ti = mock_task_instance(
- dag_id="test_dag",
- task_id="op",
- run_id="test",
- try_number=NUM_TRY,
+ dag_id=TEST_DAG_ID,
+ task_id=TEST_TASK_ID,
+ run_id=TEST_RUN_ID,
+ try_number=TRY_NUMBER,
max_tries=0,
- state=None,
+ state=TEST_TASK_STATE,
)
- context = {"dag": create_dag_without_db("test_dag"), "ti": mock_ti}
+ context = {"dag": create_dag_without_db(TEST_DAG_ID), "ti": mock_ti}
notifier = SmtpNotifier(
- from_email="any email",
- to="test_reciver@test.com",
+ from_email=TEST_SENDER,
+ to=TEST_RECEIVER,
)
mock_smtphook_hook.return_value.__enter__.return_value.subject_template = None
mock_smtphook_hook.return_value.__enter__.return_value.html_content_template = None
+
notifier(context)
+
mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with(
- from_email="any email",
- to="test_reciver@test.com",
- subject="DAG test_dag - Task op - Run ID test in State None",
+ from_email=TEST_SENDER,
+ to=TEST_RECEIVER,
+ subject=f"DAG {TEST_DAG_ID} - Task {TEST_TASK_ID} - Run ID {TEST_RUN_ID} in State {TEST_TASK_STATE}",
html_content=mock.ANY,
- smtp_conn_id="smtp_default",
- files=None,
- cc=None,
- bcc=None,
- mime_subtype="mixed",
- mime_charset="utf-8",
- custom_headers=None,
+ smtp_conn_id=SMTP_CONN_ID,
+ **DEFAULT_EMAIL_PARAMS,
)
content = mock_smtphook_hook.return_value.__enter__().send_email_smtp.call_args.kwargs["html_content"]
- assert f"{NUM_TRY} of 1" in content
+ assert f"{TRY_NUMBER} of 1" in content
@mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook")
def test_notifier_with_nondefault_connection_extra(
@@ -145,58 +176,159 @@ def test_notifier_with_nondefault_connection_extra(
):
# TODO: we can use create_runtime_ti fixture in place of mock_task_instance once provider has minimum AF to Airflow 3.0+
ti = mock_task_instance(
- dag_id="test_dag",
- task_id="op",
- run_id="test_run",
- try_number=NUM_TRY,
+ dag_id=TEST_DAG_ID,
+ task_id=TEST_TASK_ID,
+ run_id=TEST_RUN_ID,
+ try_number=TRY_NUMBER,
max_tries=0,
- state=None,
+ state=TEST_TASK_STATE,
)
- context = {"dag": create_dag_without_db("test_dag"), "ti": ti}
+ context = {"dag": create_dag_without_db(TEST_DAG_ID), "ti": ti}
with (
tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_subject,
tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_content,
):
- f_subject.write("Task {{ ti.task_id }} failed")
+ f_subject.write(TEMPLATED_TI_SUBJECT.template)
f_subject.flush()
- f_content.write("Mock content goes here")
+ f_content.write(TEST_BODY)
f_content.flush()
- mock_smtphook_hook.return_value.__enter__.return_value.from_email = "{{ ti.task_id }}@test.com"
+ mock_smtphook_hook.return_value.__enter__.return_value.from_email = TEMPLATED_TI_SENDER.template
mock_smtphook_hook.return_value.__enter__.return_value.subject_template = f_subject.name
mock_smtphook_hook.return_value.__enter__.return_value.html_content_template = f_content.name
- notifier = SmtpNotifier(
- to="test_reciver@test.com",
- )
+
+ notifier = SmtpNotifier(to=TEST_RECEIVER)
notifier(context)
+
mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with(
- from_email="op@test.com",
- to="test_reciver@test.com",
- subject="Task op failed",
- html_content="Mock content goes here",
- smtp_conn_id="smtp_default",
- files=None,
- cc=None,
- bcc=None,
- mime_subtype="mixed",
- mime_charset="utf-8",
- custom_headers=None,
+ from_email=TEMPLATED_TI_SENDER.rendered,
+ to=TEST_RECEIVER,
+ subject=TEMPLATED_TI_SUBJECT.rendered,
+ html_content=TEST_BODY,
+ smtp_conn_id=SMTP_CONN_ID,
+ **DEFAULT_EMAIL_PARAMS,
)
@mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook")
def test_notifier_oauth2_passes_auth_type(self, mock_smtphook_hook, create_dag_without_db):
+ notifier = SmtpNotifier(**NOTIFIER_DEFAULT_PARAMS, auth_type=SMTP_AUTH_TYPE)
+
+ notifier({"dag": create_dag_without_db(TEST_DAG_ID)})
+
+ mock_smtphook_hook.assert_called_once_with(smtp_conn_id=SMTP_CONN_ID, auth_type=SMTP_AUTH_TYPE)
+
+
+@pytest.mark.asyncio
+@pytest.mark.skipif(not AIRFLOW_V_3_1_PLUS, reason="Async support was added to BaseNotifier in 3.1.0")
+class TestSmtpNotifierAsync:
+ @pytest.fixture
+ def mock_smtp_client(self):
+ """Create a mock SMTP object with async capabilities."""
+ mock_smtp = AsyncMock()
+ mock_smtp.asend_email_smtp = AsyncMock()
+ return mock_smtp
+
+ @pytest.fixture
+ def mock_smtp_hook(self, mock_smtp_client):
+ """Set up the SMTP hook with async context manager."""
+ with mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook") as mock_hook:
+ mock_hook.return_value.__aenter__ = AsyncMock(return_value=mock_smtp_client)
+ yield mock_hook
+
+ @pytest.mark.asyncio
+ async def test_async_notifier(self, mock_smtp_hook, mock_smtp_client, create_dag_without_db):
notifier = SmtpNotifier(
- from_email="test_sender@test.com",
- to="test_reciver@test.com",
- auth_type="oauth2",
- subject="subject",
- html_content="body",
+ **NOTIFIER_DEFAULT_PARAMS, context={"dag": create_dag_without_db(TEST_DAG_ID)}
)
+ await notifier.async_notify({"dag": create_dag_without_db(TEST_DAG_ID)})
- notifier({"dag": create_dag_without_db("test_notifier")})
+ mock_smtp_client.asend_email_smtp.assert_called_once_with(
+ smtp_conn_id=SMTP_CONN_ID,
+ **NOTIFIER_DEFAULT_PARAMS,
+ **DEFAULT_EMAIL_PARAMS,
+ )
- mock_smtphook_hook.assert_called_once_with(
- smtp_conn_id="smtp_default",
- auth_type="oauth2",
+ async def test_async_notifier_with_notifier_class(
+ self, mock_smtp_hook, mock_smtp_client, create_dag_without_db
+ ):
+ notifier = SmtpNotifier(
+ **NOTIFIER_DEFAULT_PARAMS, context={"dag": create_dag_without_db(TEST_DAG_ID)}
)
+
+ await notifier
+
+ mock_smtp_client.asend_email_smtp.assert_called_once_with(
+ smtp_conn_id=SMTP_CONN_ID,
+ **NOTIFIER_DEFAULT_PARAMS,
+ **DEFAULT_EMAIL_PARAMS,
+ )
+
+ async def test_async_notifier_templated(self, mock_smtp_hook, mock_smtp_client, create_dag_without_db):
+ notifier = SmtpNotifier(
+ from_email=TEMPLATED_SENDER.template,
+ to=TEMPLATED_RECEIVER.template,
+ subject=TEMPLATED_SUBJECT.template,
+ html_content=TEMPLATED_BODY.template,
+ context={"dag": create_dag_without_db(TEST_DAG_ID)},
+ )
+
+ await notifier
+
+ mock_smtp_client.asend_email_smtp.assert_called_once_with(
+ smtp_conn_id=SMTP_CONN_ID,
+ from_email=TEMPLATED_SENDER.rendered,
+ to=TEMPLATED_RECEIVER.rendered,
+ subject=TEMPLATED_SUBJECT.rendered,
+ html_content=TEMPLATED_BODY.rendered,
+ **DEFAULT_EMAIL_PARAMS,
+ )
+
+ async def test_async_notifier_with_defaults(
+ self, mock_smtp_hook, mock_smtp_client, create_dag_without_db, mock_task_instance
+ ):
+ mock_smtp_client.subject_template = None
+ mock_smtp_client.html_content_template = None
+ mock_smtp_client.from_email = None
+
+ notifier = SmtpNotifier(
+ **NOTIFIER_DEFAULT_PARAMS, context={"dag": create_dag_without_db(TEST_DAG_ID)}
+ )
+
+ await notifier
+
+ mock_smtp_client.asend_email_smtp.assert_called_once_with(
+ smtp_conn_id=SMTP_CONN_ID,
+ **NOTIFIER_DEFAULT_PARAMS,
+ **DEFAULT_EMAIL_PARAMS,
+ )
+
+ async def test_async_notifier_with_nondefault_connection_extra(
+ self, mock_smtp_hook, mock_smtp_client, create_dag_without_db, mock_task_instance
+ ):
+ with (
+ tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_subject,
+ tempfile.NamedTemporaryFile(mode="wt", suffix=".txt") as f_content,
+ ):
+ f_subject.write(TEST_SUBJECT)
+ f_subject.flush()
+
+ f_content.write(TEST_BODY)
+ f_content.flush()
+
+ mock_smtp_client.from_email = TEST_SENDER
+ mock_smtp_client.subject_template = f_subject.name
+ mock_smtp_client.html_content_template = f_content.name
+
+ notifier = SmtpNotifier(to=TEST_RECEIVER, context={"dag": create_dag_without_db(TEST_DAG_ID)})
+
+ await notifier
+
+ mock_smtp_client.asend_email_smtp.assert_called_once_with(
+ smtp_conn_id=SMTP_CONN_ID,
+ from_email=TEST_SENDER,
+ to=TEST_RECEIVER,
+ subject=TEST_SUBJECT,
+ html_content=TEST_BODY,
+ **DEFAULT_EMAIL_PARAMS,
+ )