⚡️ Speed up method S3DataSource.put_object_retention by 2,035%
#632
📄 2,035% (20.35x) speedup for S3DataSource.put_object_retention in backend/python/app/sources/external/s3/s3.py

⏱️ Runtime: 18.4 milliseconds → 860 microseconds (best of 240 runs)

📝 Explanation and details
The optimized code achieves a 2,035% speedup by eliminating the expensive async with session.client('s3') context-manager overhead, which was creating and tearing down an S3 client for every operation.

Key optimization: the original code called session.client('s3') inside an async context manager for each put_object_retention call, incurring significant per-call client setup and teardown overhead. The optimized version replaces this with a direct call to await self._get_s3_client(), which maintains a persistent S3 client connection that can be reused across multiple operations.
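A minimal sketch of the pattern change is below. Hedged: the actual _get_s3_client implementation is not included in this PR excerpt; the session attribute name and the _client cache field are assumptions for illustration only.

# Hedged sketch of the before/after pattern, not the actual diff.

# Original pattern: a fresh S3 client is created and torn down on every call.
async def put_object_retention(self, **kwargs):
    async with self.session.client('s3') as s3:  # expensive on every call
        return await s3.put_object_retention(**kwargs)

# Optimized pattern: one persistent client, reused across calls. The real
# _get_s3_client() body is not shown in this excerpt; assume it lazily
# enters and caches a single aioboto3 client roughly like this.
async def _get_s3_client(self):
    if self._client is None:  # hypothetical cache attribute
        self._client = await self.session.client('s3').__aenter__()
    return self._client

async def put_object_retention(self, **kwargs):
    s3 = await self._get_s3_client()  # cheap after the first call
    return await s3.put_object_retention(**kwargs)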
Performance impact: the line profiler shows the bottleneck was the async with session.client('s3') line, which consumed 86% of execution time (65.3 ms of the 75.9 ms total). The optimized version eliminates this entirely, reducing total runtime from 18.4 ms to 860 μs.

Throughput benefits: operations per second improved from 96,525 to 162,000 (a 67.8% improvement), which is particularly valuable for workloads that invoke this method repeatedly.

The optimization excels in all test scenarios, with particularly strong gains in the throughput tests (small/medium/high load) and in concurrent execution patterns, where connection reuse provides compounding benefits. This is especially important for S3 operations that are typically called repeatedly in data processing pipelines.
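For instance (illustrative usage only, mirroring the generated tests below), a concurrent batch like this previously paid client setup 100 times and now pays it once:

# Illustrative: all 100 concurrent calls share one cached S3 client.
results = await asyncio.gather(*[
    s3ds.put_object_retention(Bucket=f"bucket{i}", Key=f"key{i}")
    for i in range(100)
])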
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
import asyncio # used to run async functions
from typing import Any, Dict, Optional
import pytest # used for our unit tests
from app.sources.external.s3.s3 import S3DataSource
class ClientError(Exception):
    """Mock ClientError exception for testing."""
    def __init__(self, response, operation_name):
        super().__init__(response.get('Error', {}).get('Message', 'ClientError'))
        self.response = response
        self.operation_name = operation_name
# Mock aioboto3 Session and S3Client
class MockS3Client:
    """Mock S3 client with put_object_retention method."""
    def __init__(self, behavior=None):
        # behavior: dict of (Bucket, Key) -> response or exception
        self.behavior = behavior if behavior is not None else {}

    async def put_object_retention(self, **kwargs):
        # Reconstructed (body truncated in this excerpt): raise or return the
        # configured behavior for (Bucket, Key); default to a success dict.
        result = self.behavior.get((kwargs.get('Bucket'), kwargs.get('Key')),
                                   {'ResponseMetadata': {'HTTPStatusCode': 200}})
        if isinstance(result, Exception):
            raise result
        return result

    async def __aenter__(self):
        return self

    async def __aexit__(self, *exc):
        return False

class MockSession:
    """Mock aioboto3 session returning a mock S3 client."""
    def __init__(self, behavior=None):
        self.behavior = behavior

    def client(self, service_name):
        # Reconstructed: mirror aioboto3's session.client('s3').
        return MockS3Client(self.behavior)

class MockS3ClientBuilder:
    """Mock S3Client for S3DataSource."""
    def __init__(self, behavior=None):
        self.session = MockSession(behavior)
# ---- Tests ----
@pytest.mark.asyncio
async def test_put_object_retention_basic_success():
"""Basic: Test successful put_object_retention with required params only."""
s3_client = MockS3ClientBuilder()
s3ds = S3DataSource(s3_client)
# Await the function with minimal required params
resp = await s3ds.put_object_retention(Bucket="mybucket", Key="mykey")
@pytest.mark.asyncio
async def test_put_object_retention_with_all_params():
"""Basic: Test with all optional parameters set."""
s3_client = MockS3ClientBuilder()
s3ds = S3DataSource(s3_client)
retention = {"Mode": "GOVERNANCE", "RetainUntilDate": "2024-12-31T23:59:59Z"}
resp = await s3ds.put_object_retention(
Bucket="bucket",
Key="key",
Retention=retention,
RequestPayer="requester",
VersionId="ver123",
BypassGovernanceRetention=True,
ChecksumAlgorithm="SHA256",
ExpectedBucketOwner="ownerid"
)
@pytest.mark.asyncio
async def test_put_object_retention_empty_response():
"""Edge: Simulate S3 returning None (empty response)."""
behavior = {("emptybucket", "emptykey"): None}
s3_client = MockS3ClientBuilder(behavior)
s3ds = S3DataSource(s3_client)
resp = await s3ds.put_object_retention(Bucket="emptybucket", Key="emptykey")
@pytest.mark.asyncio
async def test_put_object_retention_error_response_dict():
"""Edge: S3 returns a dict with Error field."""
error_dict = {'Error': {'Code': 'AccessDenied', 'Message': 'You do not have permission'}}
behavior = {("errbucket", "errkey"): error_dict}
s3_client = MockS3ClientBuilder(behavior)
s3ds = S3DataSource(s3_client)
resp = await s3ds.put_object_retention(Bucket="errbucket", Key="errkey")
@pytest.mark.asyncio
async def test_put_object_retention_clienterror_exception():
"""Edge: S3 client raises ClientError exception."""
error_response = {'Error': {'Code': 'NoSuchBucket', 'Message': 'Bucket does not exist'}}
behavior = {("badbucket", "badkey"): ClientError(error_response, "PutObjectRetention")}
s3_client = MockS3ClientBuilder(behavior)
s3ds = S3DataSource(s3_client)
resp = await s3ds.put_object_retention(Bucket="badbucket", Key="badkey")
@pytest.mark.asyncio
async def test_put_object_retention_generic_exception():
"""Edge: S3 client raises a generic exception."""
class Boom(Exception): pass
behavior = {("boom", "key"): Boom("Something went wrong")}
s3_client = MockS3ClientBuilder(behavior)
s3ds = S3DataSource(s3_client)
resp = await s3ds.put_object_retention(Bucket="boom", Key="key")
@pytest.mark.asyncio
async def test_put_object_retention_concurrent_success():
"""Edge: Test concurrent execution of multiple successful calls."""
s3_client = MockS3ClientBuilder()
s3ds = S3DataSource(s3_client)
# Launch several calls concurrently
tasks = [
s3ds.put_object_retention(Bucket=f"bucket{i}", Key=f"key{i}")
for i in range(10)
]
results = await asyncio.gather(*tasks)
for i, resp in enumerate(results):
pass
@pytest.mark.asyncio
async def test_put_object_retention_concurrent_mixed():
"""Edge: Concurrent calls with mixed success/error."""
behavior = {
("bucket0", "key0"): ClientError({'Error': {'Code': 'NoSuchBucket', 'Message': 'Bucket missing'}}, "PutObjectRetention"),
("bucket1", "key1"): None,
("bucket2", "key2"): {'Error': {'Code': 'AccessDenied', 'Message': 'Denied'}},
# others succeed
}
s3_client = MockS3ClientBuilder(behavior)
s3ds = S3DataSource(s3_client)
tasks = [
s3ds.put_object_retention(Bucket=f"bucket{i}", Key=f"key{i}")
for i in range(5)
]
results = await asyncio.gather(*tasks)
# bucket3, bucket4: success
for i in [3, 4]:
pass
@pytest.mark.asyncio
async def test_put_object_retention_large_scale_concurrent():
"""Large Scale: Test 100 concurrent successful requests."""
s3_client = MockS3ClientBuilder()
s3ds = S3DataSource(s3_client)
tasks = [
s3ds.put_object_retention(Bucket=f"bucket{i}", Key=f"key{i}")
for i in range(100)
]
results = await asyncio.gather(*tasks)
for i, resp in enumerate(results):
pass
@pytest.mark.asyncio
async def test_put_object_retention_large_scale_mixed():
"""Large Scale: Test 50 concurrent requests with some errors."""
behavior = {}
# 10 error responses, 10 None, 10 ClientError, 20 success
for i in range(10):
behavior[(f"bucket{i}", f"key{i}")] = {'Error': {'Code': 'AccessDenied', 'Message': 'Denied'}}
for i in range(10, 20):
behavior[(f"bucket{i}", f"key{i}")] = None
for i in range(20, 30):
behavior[(f"bucket{i}", f"key{i}")] = ClientError({'Error': {'Code': 'NoSuchBucket', 'Message': 'Missing'}}, "PutObjectRetention")
s3_client = MockS3ClientBuilder(behavior)
s3ds = S3DataSource(s3_client)
tasks = [
s3ds.put_object_retention(Bucket=f"bucket{i}", Key=f"key{i}")
for i in range(50)
]
results = await asyncio.gather(*tasks)
# First 10: error dict
for i in range(10):
pass
# Next 10: None
for i in range(10, 20):
pass
# Next 10: ClientError
for i in range(20, 30):
pass
# Last 20: success
for i in range(30, 50):
pass
@pytest.mark.asyncio
async def test_put_object_retention_throughput_small_load():
"""Throughput: Test performance under small load (10 requests)."""
s3_client = MockS3ClientBuilder()
s3ds = S3DataSource(s3_client)
tasks = [
s3ds.put_object_retention(Bucket=f"bucket{i}", Key=f"key{i}")
for i in range(10)
]
results = await asyncio.gather(*tasks)
for resp in results:
pass
@pytest.mark.asyncio
async def test_put_object_retention_throughput_medium_load():
"""Throughput: Test performance under medium load (50 requests)."""
s3_client = MockS3ClientBuilder()
s3ds = S3DataSource(s3_client)
tasks = [
s3ds.put_object_retention(Bucket=f"bucket{i}", Key=f"key{i}")
for i in range(50)
]
results = await asyncio.gather(*tasks)
for resp in results:
pass
@pytest.mark.asyncio
async def test_put_object_retention_throughput_high_load():
"""Throughput: Test performance under high load (200 requests)."""
s3_client = MockS3ClientBuilder()
s3ds = S3DataSource(s3_client)
tasks = [
s3ds.put_object_retention(Bucket=f"bucket{i}", Key=f"key{i}")
for i in range(200)
]
results = await asyncio.gather(*tasks)
for resp in results:
pass
@pytest.mark.asyncio
async def test_put_object_retention_throughput_mixed_load():
"""Throughput: Test performance under mixed load (some errors, some success)."""
behavior = {}
for i in range(10):
behavior[(f"bucket{i}", f"key{i}")] = {'Error': {'Code': 'AccessDenied', 'Message': 'Denied'}}
for i in range(10, 20):
behavior[(f"bucket{i}", f"key{i}")] = None
for i in range(20, 30):
behavior[(f"bucket{i}", f"key{i}")] = ClientError({'Error': {'Code': 'NoSuchBucket', 'Message': 'Missing'}}, "PutObjectRetention")
s3_client = MockS3ClientBuilder(behavior)
s3ds = S3DataSource(s3_client)
tasks = [
s3ds.put_object_retention(Bucket=f"bucket{i}", Key=f"key{i}")
for i in range(40)
]
results = await asyncio.gather(*tasks)
# First 10: error dict
for i in range(10):
pass
# Next 10: None
for i in range(10, 20):
pass
# Next 10: ClientError
for i in range(20, 30):
pass
# Last 10: success
for i in range(30, 40):
pass
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
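A sketch of the assumed harness pattern (an assumption about how the comparison works, not documented codeflash internals):

# Assumed pattern: the same call runs against original and optimized code,
# its return value is captured as codeflash_output, and the two must match.
codeflash_output = await s3ds.put_object_retention(Bucket="mybucket", Key="mykey")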
#------------------------------------------------
import asyncio # used to run async functions
from unittest.mock import AsyncMock, MagicMock, patch
import pytest # used for our unit tests
from app.sources.external.s3.s3 import S3DataSource
# --- Begin: S3Response and S3Client stubs for testing ---
class S3Response:
    """Simple S3Response stub for testing."""
    def __init__(self, success: bool, data=None, error=None):
        self.success = success
        self.data = data
        self.error = error

class S3Client:
    """Stub for S3Client for testing."""
    def __init__(self, session):
        self._session = session
# --- End: S3Response and S3Client stubs for testing ---
# --- Begin: S3DataSource function under test (EXACT COPY) ---
try:
import aioboto3 # type: ignore
from botocore.exceptions import ClientError # type: ignore
except ImportError:
    class ClientError(Exception):
        def __init__(self, response, operation_name):
            super().__init__(response.get('Error', {}).get('Message', 'Unknown'))
            self.response = response
# --- End: S3DataSource function under test ---
# --- Begin: Mock aioboto3 session and client for async context ---
class DummyS3ClientAsyncContext:
    """Dummy async context manager for s3_client."""
    def __init__(self, put_object_retention_result=None, raise_client_error=False, raise_general_error=False):
        self.put_object_retention_result = put_object_retention_result
        self.raise_client_error = raise_client_error
        self.raise_general_error = raise_general_error

    async def __aenter__(self):
        return self

    async def __aexit__(self, *exc):
        return False

    async def put_object_retention(self, **kwargs):
        # Reconstructed (body truncated in this excerpt): raise or return per flags.
        if self.raise_client_error:
            raise ClientError({'Error': {'Code': 'TestError', 'Message': 'Simulated client error'}}, 'PutObjectRetention')
        if self.raise_general_error:
            raise RuntimeError('Simulated unexpected error')
        return self.put_object_retention_result

class DummySession:
    """Dummy aioboto3 session for async context."""
    def __init__(self, put_object_retention_result=None, raise_client_error=False, raise_general_error=False):
        self.put_object_retention_result = put_object_retention_result
        self.raise_client_error = raise_client_error
        self.raise_general_error = raise_general_error

    def client(self, service_name):
        # Reconstructed: mirror aioboto3's session.client('s3').
        return DummyS3ClientAsyncContext(self.put_object_retention_result,
                                         self.raise_client_error,
                                         self.raise_general_error)
# --- End: Mock aioboto3 session and client ---

# =========================
# BASIC TEST CASES
# =========================
@pytest.mark.asyncio
async def test_put_object_retention_basic_success():
"""Basic: Should return success S3Response for valid input and successful S3 call."""
dummy_response = {"ResponseMetadata": {"HTTPStatusCode": 200}, "Success": True}
session = DummySession(put_object_retention_result=dummy_response)
s3_client = S3Client(session)
datasource = S3DataSource(s3_client)
result = await datasource.put_object_retention(Bucket="test-bucket", Key="test-key")
@pytest.mark.asyncio
async def test_put_object_retention_basic_error_in_response():
"""Basic: Should handle S3 error response dict and return error S3Response."""
error_response = {"Error": {"Code": "AccessDenied", "Message": "Permission denied"}}
session = DummySession(put_object_retention_result=error_response)
s3_client = S3Client(session)
datasource = S3DataSource(s3_client)
result = await datasource.put_object_retention(Bucket="bucket", Key="key")
@pytest.mark.asyncio
async def test_put_object_retention_basic_none_response():
"""Basic: Should handle None response and return error S3Response."""
session = DummySession(put_object_retention_result=None)
s3_client = S3Client(session)
datasource = S3DataSource(s3_client)
result = await datasource.put_object_retention(Bucket="bucket", Key="key")
@pytest.mark.asyncio
async def test_put_object_retention_basic_with_all_parameters():
"""Basic: Should accept all parameters and pass them to S3 client."""
dummy_response = {"ResponseMetadata": {"HTTPStatusCode": 200}, "Success": True}
session = DummySession(put_object_retention_result=dummy_response)
s3_client = S3Client(session)
datasource = S3DataSource(s3_client)
result = await datasource.put_object_retention(
Bucket="bucket",
Key="key",
Retention={"Mode": "GOVERNANCE", "RetainUntilDate": "2024-12-31T23:59:59Z"},
RequestPayer="requester",
VersionId="12345",
BypassGovernanceRetention=True,
ChecksumAlgorithm="SHA256",
ExpectedBucketOwner="ownerid"
)
# =========================
# EDGE TEST CASES
# =========================
@pytest.mark.asyncio
async def test_put_object_retention_edge_client_error():
"""Edge: Should handle ClientError exception and return error S3Response."""
session = DummySession(raise_client_error=True)
s3_client = S3Client(session)
datasource = S3DataSource(s3_client)
result = await datasource.put_object_retention(Bucket="bucket", Key="key")
@pytest.mark.asyncio
async def test_put_object_retention_edge_general_exception():
"""Edge: Should handle unexpected exception and return error S3Response."""
session = DummySession(raise_general_error=True)
s3_client = S3Client(session)
datasource = S3DataSource(s3_client)
result = await datasource.put_object_retention(Bucket="bucket", Key="key")
@pytest.mark.asyncio
async def test_put_object_retention_edge_concurrent_calls():
"""Edge: Should handle multiple concurrent calls with different responses."""
# Two sessions with different results
session1 = DummySession(put_object_retention_result={"ResponseMetadata": {"HTTPStatusCode": 200}})
session2 = DummySession(put_object_retention_result={"Error": {"Code": "NotFound", "Message": "Object not found"}})
s3_client1 = S3Client(session1)
s3_client2 = S3Client(session2)
datasource1 = S3DataSource(s3_client1)
datasource2 = S3DataSource(s3_client2)
results = await asyncio.gather(
datasource1.put_object_retention(Bucket="bucket1", Key="key1"),
datasource2.put_object_retention(Bucket="bucket2", Key="key2")
)
@pytest.mark.asyncio
async def test_put_object_retention_edge_invalid_bucket_key_types():
"""Edge: Should raise error if Bucket or Key is not str (handled by S3 client, not by this wrapper)."""
# The function does not type-check, but S3 client may error; simulate error
session = DummySession(raise_general_error=True)
s3_client = S3Client(session)
datasource = S3DataSource(s3_client)
# Pass int instead of str for Bucket
result = await datasource.put_object_retention(Bucket=123, Key="key")
# =========================
# LARGE SCALE TEST CASES
# =========================
@pytest.mark.asyncio
async def test_put_object_retention_large_scale_concurrent_success():
"""Large Scale: Should handle many concurrent successful calls."""
dummy_response = {"ResponseMetadata": {"HTTPStatusCode": 200}, "Success": True}
sessions = [DummySession(put_object_retention_result=dummy_response) for _ in range(20)]
datasources = [S3DataSource(S3Client(session)) for session in sessions]
tasks = [
datasource.put_object_retention(Bucket=f"bucket{i}", Key=f"key{i}")
for i, datasource in enumerate(datasources)
]
results = await asyncio.gather(*tasks)
for result in results:
pass
@pytest.mark.asyncio
async def test_put_object_retention_large_scale_concurrent_mixed():
"""Large Scale: Should handle many concurrent calls with mixed success/error."""
dummy_success = {"ResponseMetadata": {"HTTPStatusCode": 200}, "Success": True}
dummy_error = {"Error": {"Code": "AccessDenied", "Message": "Permission denied"}}
sessions = []
for i in range(20):
if i % 2 == 0:
sessions.append(DummySession(put_object_retention_result=dummy_success))
else:
sessions.append(DummySession(put_object_retention_result=dummy_error))
datasources = [S3DataSource(S3Client(session)) for session in sessions]
tasks = [
datasource.put_object_retention(Bucket=f"bucket{i}", Key=f"key{i}")
for i, datasource in enumerate(datasources)
]
results = await asyncio.gather(*tasks)
for i, result in enumerate(results):
if i % 2 == 0:
pass
else:
pass
# =========================
# THROUGHPUT TEST CASES
# =========================
@pytest.mark.asyncio
async def test_put_object_retention_throughput_small_load():
"""Throughput: Small load (5 concurrent calls), all succeed."""
dummy_response = {"ResponseMetadata": {"HTTPStatusCode": 200}, "Success": True}
sessions = [DummySession(put_object_retention_result=dummy_response) for _ in range(5)]
datasources = [S3DataSource(S3Client(session)) for session in sessions]
tasks = [
datasource.put_object_retention(Bucket=f"bucket{i}", Key=f"key{i}")
for i, datasource in enumerate(datasources)
]
results = await asyncio.gather(*tasks)
for result in results:
pass
@pytest.mark.asyncio
async def test_put_object_retention_throughput_medium_load():
"""Throughput: Medium load (50 concurrent calls), all succeed."""
dummy_response = {"ResponseMetadata": {"HTTPStatusCode": 200}, "Success": True}
sessions = [DummySession(put_object_retention_result=dummy_response) for _ in range(50)]
datasources = [S3DataSource(S3Client(session)) for session in sessions]
tasks = [
datasource.put_object_retention(Bucket=f"bucket{i}", Key=f"key{i}")
for i, datasource in enumerate(datasources)
]
results = await asyncio.gather(*tasks)
for result in results:
pass
@pytest.mark.asyncio
async def test_put_object_retention_throughput_high_volume():
"""Throughput: High volume (100 concurrent calls), mixed success/error."""
dummy_success = {"ResponseMetadata": {"HTTPStatusCode": 200}, "Success": True}
dummy_error = {"Error": {"Code": "AccessDenied", "Message": "Permission denied"}}
sessions = []
for i in range(100):
if i % 10 == 0:
sessions.append(DummySession(put_object_retention_result=dummy_error))
else:
sessions.append(DummySession(put_object_retention_result=dummy_success))
datasources = [S3DataSource(S3Client(session)) for session in sessions]
tasks = [
datasource.put_object_retention(Bucket=f"bucket{i}", Key=f"key{i}")
for i, datasource in enumerate(datasources)
]
results = await asyncio.gather(*tasks)
for i, result in enumerate(results):
if i % 10 == 0:
pass
else:
pass
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
To edit these changes, run git checkout codeflash/optimize-S3DataSource.put_object_retention-mhxbh6wk and push.