diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 93c6f2645f467..cc9767243a79a 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -580,6 +580,7 @@ ds dsl Dsn dsn +dsse dst dts dttm diff --git a/providers/amazon/docs/operators/s3/s3.rst b/providers/amazon/docs/operators/s3/s3.rst index 375c8a35d6fa8..2602eda66d8d7 100644 --- a/providers/amazon/docs/operators/s3/s3.rst +++ b/providers/amazon/docs/operators/s3/s3.rst @@ -122,6 +122,9 @@ Copy an Amazon S3 object To copy an Amazon S3 object from one bucket to another you can use :class:`~airflow.providers.amazon.aws.operators.s3.S3CopyObjectOperator`. The Amazon S3 connection used here needs to have access to both source and destination bucket/key. +You can also specify server-side encryption using AWS KMS if you do not want to use the destination bucket's default key. +When using KMS, you must provide both the ``kms_key_id`` and ``kms_encryption_type`` parameters. +Ensure the role or user has the necessary permissions to use the key. .. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_s3.py :language: python diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py index db5ab1fc8d8d2..24e0e012e6dc0 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/s3.py @@ -1386,6 +1386,8 @@ def copy_object( source_version_id: str | None = None, acl_policy: str | None = None, meta_data_directive: str | None = None, + kms_key_id: str | None = None, + kms_encryption_type: str | None = None, **kwargs, ) -> None: """ @@ -1417,6 +1419,10 @@ def copy_object( object to be copied which is private by default. :param meta_data_directive: Whether to `COPY` the metadata from the source object or `REPLACE` it with metadata that's provided in the request. 
+ :param kms_key_id: The ARN, id or alias of the AWS KMS key to use for encrypting the destination object. + Required if using KMS-based server-side encryption with a non-default key. + :param kms_encryption_type: Type of KMS encryption to use for the object. + Can be either "aws:kms" (standard KMS) or "aws:kms:dsse" (dual-layer server-side encryption, DSSE-KMS). """ acl_policy = acl_policy or "private" if acl_policy != NO_ACL: @@ -1426,6 +1432,13 @@ if self._requester_pays: kwargs["RequestPayer"] = "requester" + if bool(kms_key_id) != bool(kms_encryption_type): + message = "kms_key_id and kms_encryption_type must both be specified. Only one was provided." + raise ValueError(message) + if kms_key_id and kms_encryption_type: + kwargs["SSEKMSKeyId"] = kms_key_id + kwargs["ServerSideEncryption"] = kms_encryption_type + dest_bucket_name, dest_bucket_key = self.get_s3_bucket_key( dest_bucket_name, dest_bucket_key, "dest_bucket_name", "dest_bucket_key" ) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py index b75dfd40b0f58..d1e4f7848bbae 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py @@ -296,6 +296,10 @@ class S3CopyObjectOperator(AwsBaseOperator[S3Hook]): uploaded to the S3 bucket. :param meta_data_directive: Whether to `COPY` the metadata from the source object or `REPLACE` it with metadata that's provided in the request. + :param kms_key_id: The ARN, id or alias of the AWS KMS key to use for encrypting the destination object. + Required if using KMS-based server-side encryption with a non-default key. + :param kms_encryption_type: Type of KMS encryption to use for the object. + Can be either "aws:kms" (standard KMS) or "aws:kms:dsse" (dual-layer server-side encryption, DSSE-KMS). 
""" template_fields: Sequence[str] = aws_template_fields( @@ -316,6 +320,8 @@ def __init__( source_version_id: str | None = None, acl_policy: str | None = None, meta_data_directive: str | None = None, + kms_key_id: str | None = None, + kms_encryption_type: str | None = None, **kwargs, ): super().__init__(**kwargs) @@ -327,6 +333,8 @@ def __init__( self.source_version_id = source_version_id self.acl_policy = acl_policy self.meta_data_directive = meta_data_directive + self.kms_key_id = kms_key_id + self.kms_encryption_type = kms_encryption_type def execute(self, context: Context): self.hook.copy_object( @@ -337,6 +345,8 @@ def execute(self, context: Context): self.source_version_id, self.acl_policy, self.meta_data_directive, + self.kms_key_id, + self.kms_encryption_type, ) def get_openlineage_facets_on_start(self): diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py index ef3c9fddb8a84..74c8eb3ab3b06 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_s3.py @@ -1311,6 +1311,48 @@ def test_copy_object_ol_instrumentation(self, s3_bucket, hook_lineage_collector) RequestPayer="requester", ) + @mock_aws + def test_copy_object_with_kms_encryption(self, s3_bucket): + mock_hook = S3Hook() + with mock.patch.object(S3Hook, "get_conn") as get_conn: + mock_hook.copy_object( + "my_key", + "my_key_encrypted", + s3_bucket, + s3_bucket, + kms_key_id="arn:aws:kms:us-east-1:123456789012:key/abcd1234", + kms_encryption_type="aws:kms", + ) + get_conn.return_value.copy_object.assert_called_once_with( + Bucket=s3_bucket, + Key="my_key_encrypted", + CopySource={"Bucket": s3_bucket, "Key": "my_key", "VersionId": None}, + ACL="private", + SSEKMSKeyId="arn:aws:kms:us-east-1:123456789012:key/abcd1234", + ServerSideEncryption="aws:kms", + ) + + @mock_aws + def test_copy_object_with_kms_one_missing_raises(self, s3_bucket): + hook = S3Hook() + + with 
pytest.raises(ValueError, match="kms_key_id and kms_encryption_type must both be specified"): + hook.copy_object( + source_bucket_key="my_key", + dest_bucket_key="my_key_copy", + source_bucket_name=s3_bucket, + dest_bucket_name=s3_bucket, + kms_key_id="arn:aws:kms:us-east-1:123456789012:key/abcd1234", + ) + with pytest.raises(ValueError, match="kms_key_id and kms_encryption_type must both be specified"): + hook.copy_object( + source_bucket_key="my_key", + dest_bucket_key="my_key_copy", + source_bucket_name=s3_bucket, + dest_bucket_name=s3_bucket, + kms_encryption_type="aws:kms", + ) + @mock_aws def test_delete_bucket_if_bucket_exist(self, s3_bucket): # assert if the bucket is created diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py b/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py index ffd43a59c0c89..e9f7204b71b08 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py @@ -591,6 +591,53 @@ def test_template_fields(self): ) validate_template_fields(operator) + @mock_aws + def test_s3_copy_object_with_kms(self, monkeypatch): + conn = boto3.client("s3") + conn.create_bucket(Bucket=self.source_bucket) + conn.create_bucket(Bucket=self.dest_bucket) + conn.upload_fileobj(Bucket=self.source_bucket, Key=self.source_key, Fileobj=BytesIO(b"input")) + kms_key_id = "arn:aws:kms:us-east-1:123456789012:key/abcd1234" + + def fake_copy_object( + self_hook, + source_bucket_key, + dest_bucket_key, + source_bucket_name=None, + dest_bucket_name=None, + source_version_id=None, + acl_policy=None, + meta_data_directive=None, + kms_key_id=None, + kms_encryption_type=None, + **kwargs, + ): + copy_source = {"Bucket": source_bucket_name, "Key": source_bucket_key} + self_hook.get_conn().copy_object( + Bucket=dest_bucket_name, + Key=dest_bucket_key, + CopySource=copy_source, + SSEKMSKeyId=kms_key_id, + ServerSideEncryption=kms_encryption_type, + **kwargs, + ) + + 
monkeypatch.setattr(S3Hook, "copy_object", fake_copy_object) + op = S3CopyObjectOperator( + task_id="test_task_s3_copy_object_kms", + source_bucket_key=self.source_key, + source_bucket_name=self.source_bucket, + dest_bucket_key=self.dest_key, + dest_bucket_name=self.dest_bucket, + kms_key_id=kms_key_id, + kms_encryption_type="aws:kms", + ) + op.execute(None) + + objects_in_dest_bucket = conn.list_objects(Bucket=self.dest_bucket, Prefix=self.dest_key) + assert len(objects_in_dest_bucket["Contents"]) == 1 + assert objects_in_dest_bucket["Contents"][0]["Key"] == self.dest_key + @mock_aws class TestS3DeleteObjectsOperator: