Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ ds
dsl
Dsn
dsn
dsse
dst
dts
dttm
Expand Down
3 changes: 3 additions & 0 deletions providers/amazon/docs/operators/s3/s3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ Copy an Amazon S3 object
To copy an Amazon S3 object from one bucket to another you can use
:class:`~airflow.providers.amazon.aws.operators.s3.S3CopyObjectOperator`.
The Amazon S3 connection used here needs to have access to both source and destination bucket/key.
You can also specify server-side encryption using AWS KMS if you do not want to use destination buckets default key.
When using KMS, you must provide both the ``kms_key_id`` and ``kms_encryption_type`` parameters.
Ensure the role or user has the necessary permissions to use the key.

.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_s3.py
:language: python
Expand Down
13 changes: 13 additions & 0 deletions providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1386,6 +1386,8 @@ def copy_object(
source_version_id: str | None = None,
acl_policy: str | None = None,
meta_data_directive: str | None = None,
kms_key_id: str | None = None,
kms_encryption_type: str | None = None,
**kwargs,
) -> None:
"""
Expand Down Expand Up @@ -1417,6 +1419,10 @@ def copy_object(
object to be copied which is private by default.
:param meta_data_directive: Whether to `COPY` the metadata from the source object or `REPLACE` it
with metadata that's provided in the request.
:param kms_key_id: The ARN, id or alias of the AWS KMS key to use for encrypting the destination object.
Required if using KMS-based server-side encryption with a non-default key.
:param kms_encryption_type: Type of KMS encryption to use for the object.
Can be either "aws:kms" (standard KMS) or "aws:kms:dsse" (double-shielded KMS).
"""
acl_policy = acl_policy or "private"
if acl_policy != NO_ACL:
Expand All @@ -1426,6 +1432,13 @@ def copy_object(
if self._requester_pays:
kwargs["RequestPayer"] = "requester"

if bool(kms_key_id) != bool(kms_encryption_type):
message = "kms_key_id and kms_encryption_type must both be specified. Only one was provided."
raise ValueError(message)
if kms_key_id and kms_encryption_type:
kwargs["SSEKMSKeyId"] = kms_key_id
kwargs["ServerSideEncryption"] = kms_encryption_type

dest_bucket_name, dest_bucket_key = self.get_s3_bucket_key(
dest_bucket_name, dest_bucket_key, "dest_bucket_name", "dest_bucket_key"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,10 @@ class S3CopyObjectOperator(AwsBaseOperator[S3Hook]):
uploaded to the S3 bucket.
:param meta_data_directive: Whether to `COPY` the metadata from the source object or `REPLACE` it with
metadata that's provided in the request.
:param kms_key_id: The ARN, id or alias of the AWS KMS key to use for encrypting the destination object.
Required if using KMS-based server-side encryption with a non-default key.
:param kms_encryption_type: Type of KMS encryption to use for the object.
Can be either "aws:kms" (standard KMS) or "aws:kms:dsse" (double-shielded KMS).
"""

template_fields: Sequence[str] = aws_template_fields(
Expand All @@ -316,6 +320,8 @@ def __init__(
source_version_id: str | None = None,
acl_policy: str | None = None,
meta_data_directive: str | None = None,
kms_key_id: str | None = None,
kms_encryption_type: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -327,6 +333,8 @@ def __init__(
self.source_version_id = source_version_id
self.acl_policy = acl_policy
self.meta_data_directive = meta_data_directive
self.kms_key_id = kms_key_id
self.kms_encryption_type = kms_encryption_type

def execute(self, context: Context):
self.hook.copy_object(
Expand All @@ -337,6 +345,8 @@ def execute(self, context: Context):
self.source_version_id,
self.acl_policy,
self.meta_data_directive,
self.kms_key_id,
self.kms_encryption_type,
)

def get_openlineage_facets_on_start(self):
Expand Down
42 changes: 42 additions & 0 deletions providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,6 +1311,48 @@ def test_copy_object_ol_instrumentation(self, s3_bucket, hook_lineage_collector)
RequestPayer="requester",
)

@mock_aws
def test_copy_object_with_kms_encryption(self, s3_bucket):
mock_hook = S3Hook()
with mock.patch.object(S3Hook, "get_conn") as get_conn:
mock_hook.copy_object(
"my_key",
"my_key_encrypted",
s3_bucket,
s3_bucket,
kms_key_id="arn:aws:kms:us-east-1:123456789012:key/abcd1234",
kms_encryption_type="aws:kms",
)
get_conn.return_value.copy_object.assert_called_once_with(
Bucket=s3_bucket,
Key="my_key_encrypted",
CopySource={"Bucket": s3_bucket, "Key": "my_key", "VersionId": None},
ACL="private",
SSEKMSKeyId="arn:aws:kms:us-east-1:123456789012:key/abcd1234",
ServerSideEncryption="aws:kms",
)

@mock_aws
def test_copy_object_with_kms_one_missing_raises(self, s3_bucket):
hook = S3Hook()

with pytest.raises(ValueError, match="kms_key_id and kms_encryption_type must both be specified"):
hook.copy_object(
source_bucket_key="my_key",
dest_bucket_key="my_key_copy",
source_bucket_name=s3_bucket,
dest_bucket_name=s3_bucket,
kms_key_id="arn:aws:kms:us-east-1:123456789012:key/abcd1234",
)
with pytest.raises(ValueError, match="kms_key_id and kms_encryption_type must both be specified"):
hook.copy_object(
source_bucket_key="my_key",
dest_bucket_key="my_key_copy",
source_bucket_name=s3_bucket,
dest_bucket_name=s3_bucket,
kms_encryption_type="aws:kms",
)

@mock_aws
def test_delete_bucket_if_bucket_exist(self, s3_bucket):
# assert if the bucket is created
Expand Down
47 changes: 47 additions & 0 deletions providers/amazon/tests/unit/amazon/aws/operators/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,53 @@ def test_template_fields(self):
)
validate_template_fields(operator)

@mock_aws
def test_s3_copy_object_with_kms(self, monkeypatch):
conn = boto3.client("s3")
conn.create_bucket(Bucket=self.source_bucket)
conn.create_bucket(Bucket=self.dest_bucket)
conn.upload_fileobj(Bucket=self.source_bucket, Key=self.source_key, Fileobj=BytesIO(b"input"))
kms_key_id = "arn:aws:kms:us-east-1:123456789012:key/abcd1234"

def fake_copy_object(
self_hook,
source_bucket_key,
dest_bucket_key,
source_bucket_name=None,
dest_bucket_name=None,
source_version_id=None,
acl_policy=None,
meta_data_directive=None,
kms_key_id=None,
kms_encryption_type=None,
**kwargs,
):
copy_source = {"Bucket": source_bucket_name, "Key": source_bucket_key}
self_hook.get_conn().copy_object(
Bucket=dest_bucket_name,
Key=dest_bucket_key,
CopySource=copy_source,
SSEKMSKeyId=kms_key_id,
ServerSideEncryption=kms_encryption_type,
**kwargs,
)

monkeypatch.setattr(S3Hook, "copy_object", fake_copy_object)
op = S3CopyObjectOperator(
task_id="test_task_s3_copy_object_kms",
source_bucket_key=self.source_key,
source_bucket_name=self.source_bucket,
dest_bucket_key=self.dest_key,
dest_bucket_name=self.dest_bucket,
kms_key_id=kms_key_id,
kms_encryption_type="aws:kms",
)
op.execute(None)

objects_in_dest_bucket = conn.list_objects(Bucket=self.dest_bucket, Prefix=self.dest_key)
assert len(objects_in_dest_bucket["Contents"]) == 1
assert objects_in_dest_bucket["Contents"][0]["Key"] == self.dest_key


@mock_aws
class TestS3DeleteObjectsOperator:
Expand Down