diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/ses.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/ses.py index 7c0568c4d03c8..d418e13749e9a 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/ses.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/ses.py @@ -42,6 +42,20 @@ def __init__(self, *args, **kwargs) -> None: kwargs["client_type"] = "ses" super().__init__(*args, **kwargs) + @staticmethod + def _build_headers( + custom_headers: dict[str, Any] | None, + reply_to: str | None, + return_path: str | None, + ) -> dict[str, Any]: + _custom_headers = custom_headers or {} + if reply_to: + _custom_headers["Reply-To"] = reply_to + if return_path: + _custom_headers["Return-Path"] = return_path + + return _custom_headers + def send_email( self, mail_from: str, @@ -70,23 +84,17 @@ def send_email( :param files: List of paths of files to be attached :param cc: List of email addresses to set as email's CC :param bcc: List of email addresses to set as email's BCC - :param mime_subtype: Can be used to specify the sub-type of the message. Default = mixed + :param mime_subtype: Can be used to specify the subtype of the message. Default = mixed :param mime_charset: Email's charset. Default = UTF-8. :param return_path: The email address to which replies will be sent. By default, replies are sent to the original sender's email address. :param reply_to: The email address to which message bounces and complaints should be sent. "Return-Path" is sometimes called "envelope from", "envelope sender", or "MAIL FROM". :param custom_headers: Additional headers to add to the MIME message. - No validations are run on these values and they should be able to be encoded. + No validations are run on these values, and they should be able to be encoded. :return: Response from Amazon SES service with unique message identifier. """ - ses_client = self.get_conn() - - custom_headers = custom_headers or {} - if reply_to: - custom_headers["Reply-To"] = reply_to - if return_path: - custom_headers["Return-Path"] = return_path + custom_headers = self._build_headers(custom_headers, reply_to, return_path) message, recipients = build_mime_message( mail_from=mail_from, @@ -101,6 +109,64 @@ def send_email( custom_headers=custom_headers, ) - return ses_client.send_raw_email( + return self.conn.send_raw_email( Source=mail_from, Destinations=recipients, RawMessage={"Data": message.as_string()} ) + + async def asend_email( + self, + mail_from: str, + to: str | Iterable[str], + subject: str, + html_content: str, + files: list[str] | None = None, + cc: str | Iterable[str] | None = None, + bcc: str | Iterable[str] | None = None, + mime_subtype: str = "mixed", + mime_charset: str = "utf-8", + reply_to: str | None = None, + return_path: str | None = None, + custom_headers: dict[str, Any] | None = None, + ) -> dict: + """ + Send email using Amazon Simple Email Service (async). + + .. seealso:: + - :external+boto3:py:meth:`SES.Client.send_raw_email` + + :param mail_from: Email address to set as email's from + :param to: List of email addresses to set as email's to + :param subject: Email's subject + :param html_content: Content of email in HTML format + :param files: List of paths of files to be attached + :param cc: List of email addresses to set as email's CC + :param bcc: List of email addresses to set as email's BCC + :param mime_subtype: Can be used to specify the subtype of the message. Default = mixed + :param mime_charset: Email's charset. Default = UTF-8. + :param return_path: The email address to which replies will be sent. By default, replies + are sent to the original sender's email address. + :param reply_to: The email address to which message bounces and complaints should be sent. + "Return-Path" is sometimes called "envelope from", "envelope sender", or "MAIL FROM". + :param custom_headers: Additional headers to add to the MIME message. + No validations are run on these values, and they should be able to be encoded. + :return: Response from Amazon SES service with unique message identifier. + """ + custom_headers = self._build_headers(custom_headers, reply_to, return_path) + + message, recipients = build_mime_message( + mail_from=mail_from, + to=to, + subject=subject, + html_content=html_content, + files=files, + cc=cc, + bcc=bcc, + mime_subtype=mime_subtype, + mime_charset=mime_charset, + custom_headers=custom_headers, + ) + + async with await self.get_async_conn() as async_client: + return await async_client.send_raw_email( + Source=mail_from, Destinations=recipients, RawMessage={"Data": message.as_string()} + ) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/notifications/ses.py b/providers/amazon/src/airflow/providers/amazon/aws/notifications/ses.py new file mode 100644 index 0000000000000..9e45a83be56d3 --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/notifications/ses.py @@ -0,0 +1,139 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from collections.abc import Iterable, Sequence +from functools import cached_property +from typing import Any + +from airflow.providers.amazon.aws.hooks.ses import SesHook +from airflow.providers.amazon.version_compat import AIRFLOW_V_3_1_PLUS +from airflow.providers.common.compat.notifier import BaseNotifier +from airflow.utils.helpers import prune_dict + + +class SesNotifier(BaseNotifier): + """ + Amazon Simple Email Service (SES) Notifier. + + :param mail_from: Email address to set as email's from + :param to: List of email addresses to set as email's to + :param subject: Email's subject + :param html_content: Content of email in HTML format + :param files: List of paths of files to be attached + :param cc: List of email addresses to set as email's CC + :param bcc: List of email addresses to set as email's BCC + :param mime_subtype: Can be used to specify the subtype of the message. Default = mixed + :param mime_charset: Email's charset. Default = UTF-8. + :param return_path: The email address to which replies will be sent. By default, replies + are sent to the original sender's email address. + :param reply_to: The email address to which message bounces and complaints should be sent. + "Return-Path" is sometimes called "envelope from", "envelope sender", or "MAIL FROM". + :param custom_headers: Additional headers to add to the MIME message. + No validations are run on these values, and they should be able to be encoded. + """ + + template_fields: Sequence[str] = ( + "aws_conn_id", + "region_name", + "mail_from", + "to", + "subject", + "html_content", + "files", + "cc", + "bcc", + "mime_subtype", + "mime_charset", + "reply_to", + "return_path", + "custom_headers", + ) + + def __init__( + self, + *, + aws_conn_id: str | None = SesHook.default_conn_name, + region_name: str | None = None, + mail_from: str, + to: str | Iterable[str], + subject: str, + html_content: str, + files: list[str] | None = None, + cc: str | Iterable[str] | None = None, + bcc: str | Iterable[str] | None = None, + mime_subtype: str = "mixed", + mime_charset: str = "utf-8", + reply_to: str | None = None, + return_path: str | None = None, + custom_headers: dict[str, Any] | None = None, + **kwargs, + ): + if AIRFLOW_V_3_1_PLUS: + # Support for passing context was added in 3.1.0 + super().__init__(**kwargs) + else: + super().__init__() + self.aws_conn_id = aws_conn_id + self.region_name = region_name + + self.mail_from = mail_from + self.to = to + self.subject = subject + self.html_content = html_content + self.files = files + self.cc = cc + self.bcc = bcc + self.mime_subtype = mime_subtype + self.mime_charset = mime_charset + self.reply_to = reply_to + self.return_path = return_path + self.custom_headers = custom_headers + + def _build_send_kwargs(self): + return prune_dict( + { + "mail_from": self.mail_from, + "to": self.to, + "subject": self.subject, + "html_content": self.html_content, + "files": self.files, + "cc": self.cc, + "bcc": self.bcc, + "mime_subtype": self.mime_subtype, + "mime_charset": self.mime_charset, + "reply_to": self.reply_to, + "return_path": self.return_path, + "custom_headers": self.custom_headers, + } + ) + + @cached_property + def hook(self) -> SesHook: + """Amazon Simple Email Service (SES) Hook (cached).""" + return SesHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) + + def notify(self, context): + """Send email using Amazon Simple Email Service (SES).""" + self.hook.send_email(**self._build_send_kwargs()) + + async def async_notify(self, context): + """Send email using Amazon Simple Email Service (SES) (async).""" + await self.hook.asend_email(**self._build_send_kwargs()) + + +send_ses_notification = SesNotifier diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_ses.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_ses.py index 023b1f58fa2df..bd56b679e55bd 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_ses.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_ses.py @@ -16,6 +16,8 @@ # under the License. from __future__ import annotations +from unittest import mock + import boto3 import pytest from moto import mock_aws @@ -25,46 +27,111 @@ boto3.setup_default_session() +TEST_TO_ADDRESSES = [ + pytest.param("to@domain.com", id="to=single_string"), + pytest.param("to1@domain.com,to2@domain.com", id="to=comma_string"), + pytest.param(["to1@domain.com", "to2@domain.com"], id="to=list"), +] + +TEST_CC_ADDRESSES = [ + pytest.param("cc@domain.com", id="cc=single_string"), + pytest.param("cc1@domain.com,cc2@domain.com", id="cc=comma_string"), + pytest.param(["cc1@domain.com", "cc2@domain.com"], id="cc=list"), +] + +TEST_BCC_ADDRESSES = [ + pytest.param("bcc@domain.com", id="bcc=single_string"), + pytest.param("bcc1@domain.com,bcc2@domain.com", id="bcc=comma_string"), + pytest.param(["bcc1@domain.com", "bcc2@domain.com"], id="bcc=list"), +] + +TEST_FROM_ADDRESS = "test_from@domain.com" +TEST_SUBJECT = "subject" +TEST_HTML_CONTENT = "Test" +TEST_REPLY_TO = "reply_to@domain.com" +TEST_RETURN_PATH = "return_path@domain.com" + + @mock_aws -def test_get_conn(): - hook = SesHook(aws_conn_id="aws_default") - assert hook.get_conn() is not None +def _verify_address(address: str) -> None: + """ + Amazon only allows emails from verified addresses. If the address is not verified, + this test will raise `botocore.errorfactory.MessageRejected`. + """ + SesHook().get_conn().verify_email_identity(EmailAddress=address) @mock_aws -@pytest.mark.parametrize( - "to", ["to@domain.com", ["to1@domain.com", "to2@domain.com"], "to1@domain.com,to2@domain.com"] -) -@pytest.mark.parametrize( - "cc", ["cc@domain.com", ["cc1@domain.com", "cc2@domain.com"], "cc1@domain.com,cc2@domain.com"] -) -@pytest.mark.parametrize( - "bcc", ["bcc@domain.com", ["bcc1@domain.com", "bcc2@domain.com"], "bcc1@domain.com,bcc2@domain.com"] -) -def test_send_email(to, cc, bcc): - # Given - hook = SesHook() - ses_client = hook.get_conn() - mail_from = "test_from@domain.com" - - # Amazon only allows to send emails from verified addresses, - # then we need to validate the from address before sending the email, - # otherwise this test would raise a `botocore.errorfactory.MessageRejected` exception - ses_client.verify_email_identity(EmailAddress=mail_from) - - # When - response = hook.send_email( - mail_from=mail_from, - to=to, - subject="subject", - html_content="Test", - cc=cc, - bcc=bcc, - reply_to="reply_to@domain.com", - return_path="return_path@domain.com", - ) - - # Then - assert response is not None - assert isinstance(response, dict) - assert "MessageId" in response +class TestSesHook: + def test_get_conn(self): + hook = SesHook(aws_conn_id="aws_default") + assert hook.get_conn() is not None + + @pytest.mark.parametrize("to", TEST_TO_ADDRESSES) + @pytest.mark.parametrize("cc", TEST_CC_ADDRESSES) + @pytest.mark.parametrize("bcc", TEST_BCC_ADDRESSES) + def test_send_email(self, to, cc, bcc): + _verify_address(TEST_FROM_ADDRESS) + hook = SesHook() + + response = hook.send_email( + mail_from=TEST_FROM_ADDRESS, + to=to, + subject=TEST_SUBJECT, + html_content=TEST_HTML_CONTENT, + cc=cc, + bcc=bcc, + reply_to=TEST_REPLY_TO, + return_path=TEST_RETURN_PATH, + ) + + assert response is not None + assert isinstance(response, dict) + assert "MessageId" in response + + +@pytest.mark.asyncio +class TestAsyncSesHook: + """The mock_aws decorator uses `moto` which does not currently support async SES so we mock it manually.""" + + @pytest.fixture + def mock_async_client(self): + return mock.AsyncMock() + + @pytest.fixture + def mock_get_async_conn(self, mock_async_client): + with mock.patch.object(SesHook, "get_async_conn") as mocked_conn: + mocked_conn.return_value.__aenter__.return_value = mock_async_client + yield mocked_conn + + async def test_get_async_conn(self, mock_get_async_conn, mock_async_client): + hook = SesHook() + async with await hook.get_async_conn() as async_conn: + assert async_conn is mock_async_client + + @pytest.mark.parametrize("to", TEST_TO_ADDRESSES) + @pytest.mark.parametrize("cc", TEST_CC_ADDRESSES) + @pytest.mark.parametrize("bcc", TEST_BCC_ADDRESSES) + async def test_asend_email(self, mock_get_async_conn, mock_async_client, to, cc, bcc): + _verify_address(TEST_FROM_ADDRESS) + hook = SesHook() + + mock_async_client.send_raw_email.return_value = {"MessageId": "test_message_id"} + + response = await hook.asend_email( + mail_from=TEST_FROM_ADDRESS, + to=to, + subject=TEST_SUBJECT, + html_content=TEST_HTML_CONTENT, + cc=cc, + bcc=bcc, + reply_to=TEST_REPLY_TO, + return_path=TEST_RETURN_PATH, + ) + + assert response is not None + assert isinstance(response, dict) + assert "MessageId" in response + assert response["MessageId"] == "test_message_id" + + mock_async_client.send_raw_email.assert_called_once() diff --git a/providers/amazon/tests/unit/amazon/aws/notifications/test_ses.py b/providers/amazon/tests/unit/amazon/aws/notifications/test_ses.py new file mode 100644 index 0000000000000..b848f38c5b3e0 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/notifications/test_ses.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.providers.amazon.aws.notifications.ses import SesNotifier, send_ses_notification +from airflow.utils.types import NOTSET + +TEST_EMAIL_PARAMS = { + "mail_from": "from@test.com", + "to": "to@test.com", + "subject": "Test Subject", + "html_content": "

Test Content

", +} + +# The hook sets these default values if they are not provided +HOOK_DEFAULTS = { + "mime_charset": "utf-8", + "mime_subtype": "mixed", +} + + +class TestSesNotifier: + def test_class_and_notifier_are_same(self): + assert send_ses_notification is SesNotifier + + @pytest.mark.parametrize( + "aws_conn_id", + [ + pytest.param("aws_test_conn_id", id="custom-conn"), + pytest.param(None, id="none-conn"), + pytest.param(NOTSET, id="default-value"), + ], + ) + @pytest.mark.parametrize( + "region_name", + [ + pytest.param("eu-west-2", id="custom-region"), + pytest.param(None, id="no-region"), + pytest.param(NOTSET, id="default-value"), + ], + ) + def test_parameters_propagate_to_hook(self, aws_conn_id, region_name): + """Test notifier attributes propagate to SesHook.""" + notifier_kwargs = {} + if aws_conn_id is not NOTSET: + notifier_kwargs["aws_conn_id"] = aws_conn_id + if region_name is not NOTSET: + notifier_kwargs["region_name"] = region_name + + notifier = SesNotifier(**notifier_kwargs, **TEST_EMAIL_PARAMS) + with mock.patch("airflow.providers.amazon.aws.notifications.ses.SesHook") as mock_hook: + hook = notifier.hook + assert hook is notifier.hook, "Hook property not cached" + mock_hook.assert_called_once_with( + aws_conn_id=(aws_conn_id if aws_conn_id is not NOTSET else "aws_default"), + region_name=(region_name if region_name is not NOTSET else None), + ) + + # Basic check for notifier + notifier.notify({}) + mock_hook.return_value.send_email.assert_called_once_with(**TEST_EMAIL_PARAMS, **HOOK_DEFAULTS) + + @pytest.mark.asyncio + async def test_async_notify(self): + """Test async notification sends correctly.""" + notifier = SesNotifier(**TEST_EMAIL_PARAMS) + with mock.patch("airflow.providers.amazon.aws.notifications.ses.SesHook") as mock_hook: + mock_hook.return_value.asend_email = mock.AsyncMock() + + await notifier.async_notify({}) + + mock_hook.return_value.asend_email.assert_called_once_with(**TEST_EMAIL_PARAMS, **HOOK_DEFAULTS) + + def test_ses_notifier_with_optional_params(self): + """Test notifier handles all optional parameters correctly.""" + email_params = { + **TEST_EMAIL_PARAMS, + "files": ["test.txt"], + "cc": ["cc@test.com"], + "bcc": ["bcc@test.com"], + "mime_subtype": "alternative", + "mime_charset": "ascii", + "reply_to": "reply@test.com", + "return_path": "bounce@test.com", + "custom_headers": {"X-Custom": "value"}, + } + + notifier = SesNotifier(**email_params) + with mock.patch("airflow.providers.amazon.aws.notifications.ses.SesHook") as mock_hook: + notifier.notify({}) + + mock_hook.return_value.send_email.assert_called_once_with(**email_params) + + def test_ses_notifier_templated(self, create_dag_without_db): + """Test template fields are properly rendered.""" + templated_params = { + "aws_conn_id": "{{ dag.dag_id }}", + "region_name": "{{ var_region }}", + "mail_from": "{{ var_from }}", + "to": "{{ var_to }}", + "subject": "{{ var_subject }}", + "html_content": "Hello {{ var_name }}", + "cc": ["cc@{{ var_domain }}"], + "bcc": ["bcc@{{ var_domain }}"], + "reply_to": "reply@{{ var_domain }}", + } + + notifier = SesNotifier(**templated_params) + with mock.patch("airflow.providers.amazon.aws.notifications.ses.SesHook") as mock_hook: + context = { + "dag": create_dag_without_db("test_ses_notifier_templated"), + "var_region": "us-west-1", + "var_from": "from@example.com", + "var_to": "to@example.com", + "var_subject": "Test Email", + "var_name": "John", + "var_domain": "example.com", + } + notifier(context) + + mock_hook.assert_called_once_with( + aws_conn_id="test_ses_notifier_templated", + region_name="us-west-1", + ) + mock_hook.return_value.send_email.assert_called_once_with( + mail_from="from@example.com", + to="to@example.com", + subject="Test Email", + html_content="Hello John", + cc=["cc@example.com"], + bcc=["bcc@example.com"], + mime_subtype="mixed", + mime_charset="utf-8", + reply_to="reply@example.com", + )