Be able to remove ACL in S3 hook's copy_object (#40518)
* Be able to remove ACL in S3 hook's copy_object

* Add default test as well

* Update s3.py
anteverse authored Jul 2, 2024
1 parent 015ac89 commit 8e04ef0
Showing 2 changed files with 46 additions and 1 deletion.
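
For context, a minimal usage sketch of the change (connection id, bucket, and key names are illustrative, not taken from this commit): passing the new NO_ACL sentinel to copy_object keeps the underlying request free of any ACL parameter, while leaving acl_policy unset still falls back to the "private" canned ACL.

from airflow.providers.amazon.aws.hooks.s3 import NO_ACL, S3Hook

hook = S3Hook(aws_conn_id="aws_default")  # illustrative connection id

# Default behavior: the copy request is sent with ACL="private".
hook.copy_object(
    source_bucket_key="data/report.csv",
    dest_bucket_key="archive/report.csv",
    source_bucket_name="my-source-bucket",
    dest_bucket_name="my-dest-bucket",
)

# New option: NO_ACL omits the ACL parameter entirely (useful, for example,
# for destination buckets where ACLs are disabled).
hook.copy_object(
    source_bucket_key="data/report.csv",
    dest_bucket_key="archive/report.csv",
    source_bucket_name="my-source-bucket",
    dest_bucket_name="my-dest-bucket",
    acl_policy=NO_ACL,
)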
10 changes: 9 additions & 1 deletion airflow/providers/amazon/aws/hooks/s3.py
@@ -62,6 +62,12 @@
logger = logging.getLogger(__name__)


# Explicit value that would remove ACLs from a copy
# No conflicts with Canned ACLs:
# https://docs.aws.amazon.com/AmazonS3/latest/userguide/acl-overview.html#canned-acl
NO_ACL = "no-acl"


def provide_bucket_name(func: Callable) -> Callable:
"""Provide a bucket name taken from the connection if no bucket name has been passed to the function."""
if hasattr(func, "_unify_bucket_name_and_key_wrapped"):
@@ -1285,6 +1291,8 @@ def copy_object(
object to be copied which is private by default.
"""
acl_policy = acl_policy or "private"
if acl_policy != NO_ACL:
kwargs["ACL"] = acl_policy

dest_bucket_name, dest_bucket_key = self.get_s3_bucket_key(
dest_bucket_name, dest_bucket_key, "dest_bucket_name", "dest_bucket_key"
@@ -1296,7 +1304,7 @@

copy_source = {"Bucket": source_bucket_name, "Key": source_bucket_key, "VersionId": source_version_id}
response = self.get_conn().copy_object(
Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, ACL=acl_policy, **kwargs
Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, **kwargs
)
return response

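At the boto3 level, the effect of the change amounts to the following (a rough sketch with illustrative bucket and key names; VersionId omitted for brevity):

import boto3

client = boto3.client("s3")
copy_source = {"Bucket": "src-bucket", "Key": "data/report.csv"}

# acl_policy omitted or set to a canned ACL: an explicit ACL is still passed.
client.copy_object(
    Bucket="dst-bucket", Key="archive/report.csv", CopySource=copy_source, ACL="private"
)

# acl_policy=NO_ACL: the ACL argument is left out, so the destination bucket's
# default ownership/ACL settings apply to the copied object.
client.copy_object(
    Bucket="dst-bucket", Key="archive/report.csv", CopySource=copy_source
)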
37 changes: 37 additions & 0 deletions tests/providers/amazon/aws/hooks/test_s3.py
@@ -35,6 +35,7 @@
from airflow.models import Connection
from airflow.providers.amazon.aws.exceptions import S3HookUriParseFailure
from airflow.providers.amazon.aws.hooks.s3 import (
NO_ACL,
S3Hook,
provide_bucket_name,
unify_bucket_name_and_key,
@@ -990,6 +991,42 @@ def test_copy_object_acl(self, s3_bucket, tmp_path):
assert response["Grants"][0]["Permission"] == "FULL_CONTROL"
assert len(response["Grants"]) == 1

@mock_aws
def test_copy_object_no_acl(
self,
s3_bucket,
):
mock_hook = S3Hook()

with mock.patch.object(
S3Hook,
"get_conn",
) as patched_get_conn:
mock_hook.copy_object("my_key", "my_key3", s3_bucket, s3_bucket, acl_policy=NO_ACL)

# Check we're not passing ACLs
patched_get_conn.return_value.copy_object.assert_called_once_with(
Bucket="airflow-test-s3-bucket",
Key="my_key3",
CopySource={"Bucket": "airflow-test-s3-bucket", "Key": "my_key", "VersionId": None},
)
patched_get_conn.reset_mock()

mock_hook.copy_object(
"my_key",
"my_key3",
s3_bucket,
s3_bucket,
)

# Check the default is "private"
patched_get_conn.return_value.copy_object.assert_called_once_with(
Bucket="airflow-test-s3-bucket",
Key="my_key3",
CopySource={"Bucket": "airflow-test-s3-bucket", "Key": "my_key", "VersionId": None},
ACL="private",
)

@mock_aws
def test_delete_bucket_if_bucket_exist(self, s3_bucket):
# assert if the bucket is created
