Skip to content

Commit

Permalink
Add MRAP support to CRT transfer manager (#319)
Browse files Browse the repository at this point in the history
  • Loading branch information
nateprewitt authored Nov 20, 2024
1 parent f250895 commit 7b8a5cd
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 5 deletions.
5 changes: 5 additions & 0 deletions .changes/next-release/enhancement-s3-25519.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "enhancement",
"category": "``s3``",
"description": "Added Multi-Region Access Points support to CRT transfers."
}
53 changes: 52 additions & 1 deletion s3transfer/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import logging
import re
import threading
from io import BytesIO

Expand All @@ -36,6 +37,7 @@
from botocore.compat import urlsplit
from botocore.config import Config
from botocore.exceptions import NoCredentialsError
from botocore.utils import ArnParser, InvalidArnException

from s3transfer.constants import MB
from s3transfer.exceptions import TransferNotDoneError
Expand Down Expand Up @@ -874,7 +876,18 @@ def _default_get_make_request_args(
x.title() for x in request_type.split('_')
)

if is_s3express_bucket(call_args.bucket):
arn_handler = _S3ArnParamHandler()
if (
accesspoint_arn_details := arn_handler.handle_arn(call_args.bucket)
) and accesspoint_arn_details['region'] == "":
# Configure our region to `*` to propogate in `x-amz-region-set`
# for multi-region support in MRAP accesspoints.
make_request_args['signing_config'] = AwsSigningConfig(
algorithm=AwsSigningAlgorithm.V4_ASYMMETRIC,
region="*",
)
call_args.bucket = accesspoint_arn_details['resource_name']
elif is_s3express_bucket(call_args.bucket):
make_request_args['signing_config'] = AwsSigningConfig(
algorithm=AwsSigningAlgorithm.V4_S3EXPRESS
)
Expand Down Expand Up @@ -917,3 +930,41 @@ def __init__(self, fileobj):

def __call__(self, chunk, **kwargs):
self._fileobj.write(chunk)


class _S3ArnParamHandler:
"""Partial port of S3ArnParamHandler from botocore.
This is used to make a determination on MRAP accesspoints for signing
purposes. This should be safe to remove once we properly integrate auth
resolution from Botocore into the CRT transfer integration.
"""

_RESOURCE_REGEX = re.compile(
r'^(?P<resource_type>accesspoint|outpost)[/:](?P<resource_name>.+)$'
)

def __init__(self):
self._arn_parser = ArnParser()

def handle_arn(self, bucket):
arn_details = self._get_arn_details_from_bucket(bucket)
if arn_details is None:
return
if arn_details['resource_type'] == 'accesspoint':
return arn_details

def _get_arn_details_from_bucket(self, bucket):
try:
arn_details = self._arn_parser.parse_arn(bucket)
self._add_resource_type_and_name(arn_details)
return arn_details
except InvalidArnException:
pass
return None

def _add_resource_type_and_name(self, arn_details):
match = self._RESOURCE_REGEX.match(arn_details['resource'])
if match:
arn_details['resource_type'] = match.group('resource_type')
arn_details['resource_name'] = match.group('resource_name')
60 changes: 56 additions & 4 deletions tests/functional/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ def setUp(self):
self.region = 'us-west-2'
self.bucket = "test_bucket"
self.s3express_bucket = 's3expressbucket--usw2-az5--x-s3'
self.mrap_accesspoint = (
'arn:aws:s3::123456789012:accesspoint/mfzwi23gnjvgw.mrap'
)
self.mrap_bucket = 'mfzwi23gnjvgw.mrap'
self.key = "test_key"
self.expected_content = b'my content'
self.expected_download_content = b'new content'
Expand All @@ -80,6 +84,10 @@ def setUp(self):
self.expected_host = f"s3.{self.region}.amazonaws.com"
self.expected_s3express_host = f'{self.s3express_bucket}.s3express-usw2-az5.us-west-2.amazonaws.com'
self.expected_s3express_path = f'/{self.key}'
self.expected_mrap_host = (
f'{self.mrap_bucket}.accesspoint.s3-global.amazonaws.com'
)
self.expected_mrap_path = f"/{self.key}"
self.s3_request = mock.Mock(awscrt.s3.S3Request)
self.s3_crt_client = mock.Mock(awscrt.s3.S3Client)
self.s3_crt_client.make_request.side_effect = (
Expand Down Expand Up @@ -137,7 +145,7 @@ def _assert_expected_crt_http_request(
for expected_missing_header in expected_missing_headers:
self.assertNotIn(expected_missing_header.lower(), header_names)

def _assert_exected_s3express_request(
def _assert_expected_s3express_request(
self, make_request_kwargs, expected_http_method='GET'
):
self._assert_expected_crt_http_request(
Expand All @@ -152,6 +160,22 @@ def _assert_exected_s3express_request(
awscrt.auth.AwsSigningAlgorithm.V4_S3EXPRESS,
)

def _assert_expected_mrap_request(
self, make_request_kwargs, expected_http_method='GET'
):
self._assert_expected_crt_http_request(
make_request_kwargs["request"],
expected_host=self.expected_mrap_host,
expected_path=self.expected_mrap_path,
expected_http_method=expected_http_method,
)
self.assertIn('signing_config', make_request_kwargs)
self.assertEqual(
make_request_kwargs['signing_config'].algorithm,
awscrt.auth.AwsSigningAlgorithm.V4_ASYMMETRIC,
)
self.assertEqual(make_request_kwargs['signing_config'].region, "*")

def _assert_subscribers_called(self, expected_future=None):
self.assertTrue(self.record_subscriber.on_queued_called)
self.assertTrue(self.record_subscriber.on_done_called)
Expand Down Expand Up @@ -404,7 +428,21 @@ def test_upload_with_s3express(self):
[self.record_subscriber],
)
future.result()
self._assert_exected_s3express_request(
self._assert_expected_s3express_request(
self.s3_crt_client.make_request.call_args[1],
expected_http_method='PUT',
)

def test_upload_with_mrap(self):
future = self.transfer_manager.upload(
self.filename,
self.mrap_accesspoint,
self.key,
{},
[self.record_subscriber],
)
future.result()
self._assert_expected_mrap_request(
self.s3_crt_client.make_request.call_args[1],
expected_http_method='PUT',
)
Expand Down Expand Up @@ -532,7 +570,21 @@ def test_download_with_s3express(self):
[self.record_subscriber],
)
future.result()
self._assert_exected_s3express_request(
self._assert_expected_s3express_request(
self.s3_crt_client.make_request.call_args[1],
expected_http_method='GET',
)

def test_download_with_mrap(self):
future = self.transfer_manager.download(
self.mrap_accesspoint,
self.key,
self.filename,
{},
[self.record_subscriber],
)
future.result()
self._assert_expected_mrap_request(
self.s3_crt_client.make_request.call_args[1],
expected_http_method='GET',
)
Expand Down Expand Up @@ -577,7 +629,7 @@ def test_delete_with_s3express(self):
self.s3express_bucket, self.key, {}, [self.record_subscriber]
)
future.result()
self._assert_exected_s3express_request(
self._assert_expected_s3express_request(
self.s3_crt_client.make_request.call_args[1],
expected_http_method='DELETE',
)
Expand Down

0 comments on commit 7b8a5cd

Please sign in to comment.