Skip to content

Commit

Permalink
Fix avoidable S3 race condition (piskvorky#693)
Browse files Browse the repository at this point in the history
  • Loading branch information
RachitSharma2001 committed Nov 5, 2022
1 parent 4268a1a commit 940153e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 20 deletions.
35 changes: 24 additions & 11 deletions smart_open/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,18 +1189,31 @@ def iter_bucket(

with smart_open.concurrency.create_pool(processes=workers) as pool:
result_iterator = pool.imap_unordered(download_key, key_iterator)
for key_no, (key, content) in enumerate(result_iterator):
if True or key_no % 1000 == 0:
logger.info(
"yielding key #%i: %s, size %i (total %.1fMB)",
key_no, key, len(content), total_size / 1024.0 ** 2
)
yield key, content
total_size += len(content)

if key_limit is not None and key_no + 1 >= key_limit:
# we were asked to output only a limited number of keys => we're done
key_no = 0
while True:
try:
(key, content) = result_iterator.__next__()
if True or key_no % 1000 == 0:
logger.info(
"yielding key #%i: %s, size %i (total %.1fMB)",
key_no, key, len(content), total_size / 1024.0 ** 2
)
yield key, content
total_size += len(content)
if key_limit is not None and key_no + 1 >= key_limit:
# we were asked to output only a limited number of keys => we're done
break
except botocore.exceptions.ClientError as err:
if 'Error' in err.response and err.response['Error'].get('Code') == '404':
logger.warning(
"Encountered '404 Not Found' error for key #%i",
key_no
)
else:
raise err
except StopIteration:
break
key_no += 1
logger.info("processed %i keys, total size %i" % (key_no + 1, total_size))


Expand Down
31 changes: 22 additions & 9 deletions smart_open/tests/test_s3.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Radim Rehurek <me@radimrehurek.com>
#
# This code is distributed under the terms and conditions
# from the MIT License (MIT).
#
from collections import defaultdict
import functools
import gzip
Expand Down Expand Up @@ -404,7 +397,6 @@ def test_read_empty_file(self):
class MultipartWriterTest(unittest.TestCase):
"""
Test writing into s3 files.
"""
def setUp(self):
ignore_resource_warnings()
Expand Down Expand Up @@ -559,7 +551,6 @@ def test_writebuffer(self):
class SinglepartWriterTest(unittest.TestCase):
"""
Test writing into s3 files using single part upload.
"""
def setUp(self):
ignore_resource_warnings()
Expand Down Expand Up @@ -681,6 +672,28 @@ def test_iter_bucket(self):
results = list(smart_open.s3.iter_bucket(BUCKET_NAME))
self.assertEqual(len(results), 10)

def test_iter_bucket_404(self):
populate_bucket()

def throw_404_error_for_key_4(*args):
if args[1] == "key_4":
raise botocore.exceptions.ClientError(
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
operation_name="HeadObject",
)
else:
return [0]

with mock.patch("smart_open.s3._download_fileobj", side_effect=throw_404_error_for_key_4):
results = list(smart_open.s3.iter_bucket(BUCKET_NAME))
self.assertEqual(len(results), 9)

def test_iter_bucket_non_404(self):
populate_bucket()
with mock.patch("smart_open.s3._download_fileobj", side_effect=ARBITRARY_CLIENT_ERROR):
with pytest.raises(botocore.exceptions.ClientError):
list(smart_open.s3.iter_bucket(BUCKET_NAME))

def test_deprecated_top_level_s3_iter_bucket(self):
populate_bucket()
with self.assertLogs(smart_open.logger.name, level='WARN') as cm:
Expand Down

0 comments on commit 940153e

Please sign in to comment.