Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix reading empty file or seeking past end of file for s3 backend #549

Merged
merged 3 commits into from
Oct 30, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 3 additions & 12 deletions smart_open/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
_SLEEP_SECONDS = 10

# Returned by AWS when we try to seek beyond EOF.
_OUT_OF_RANGE = 'Requested Range Not Satisfiable'
_OUT_OF_RANGE = 'InvalidRange'


def parse_uri(uri_as_string):
Expand Down Expand Up @@ -385,18 +385,9 @@ def _open_body(self, start=None, stop=None):
except IOError as ioe:
# Handle requested content range exceeding content size.
error_response = _unwrap_ioerror(ioe)
if error_response is None or error_response.get('Message') != _OUT_OF_RANGE:
if error_response is None or error_response.get('Code') != _OUT_OF_RANGE:
raise
try:
self._position = self._content_length = int(error_response['ActualObjectSize'])
except KeyError:
# This shouldn't happen with real S3, but moto lacks ActualObjectSize.
# Reported at https://github.com/spulec/moto/issues/2981
self._position = self._content_length = _get(
self._object,
version=self._version_id,
**self._object_kwargs,
)['ContentLength']
self._position = self._content_length = int(error_response['ActualObjectSize'])
self._body = io.BytesIO()
else:
units, start, stop, length = smart_open.utils.parse_content_range(response['ContentRange'])
Expand Down
45 changes: 43 additions & 2 deletions smart_open/tests/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,27 @@ def ignore_resource_warnings():
warnings.filterwarnings("ignore", category=ResourceWarning, message="unclosed.*<ssl.SSLSocket.*>") # noqa


@contextmanager
def patch_invalid_range_response(actual_size):
    """Temporarily wrap ``smart_open.s3._get`` so moto's range error resembles S3's.

    Works around https://github.com/spulec/moto/issues/2981, where the API
    response doesn't match when requesting an invalid range of bytes from an
    S3 GetObject.

    While the context is active, ``smart_open.s3._get`` is replaced by a
    wrapper around the original function.  When the wrapped call raises an
    ``IOError`` whose unwrapped error response carries moto's message, the
    response is amended with ``ActualObjectSize``, ``Code`` and ``Message``
    entries before the same exception is re-raised.  Every other outcome
    (normal return, unrelated errors) passes through untouched.

    :param actual_size: value stored under ``ActualObjectSize`` in the amended
        error response.
    """
    moto_message = 'Requested Range Not Satisfiable'
    original_get = smart_open.s3._get

    def get_with_patched_error(*args, **kwargs):
        try:
            return original_get(*args, **kwargs)
        except IOError as err:
            details = smart_open.s3._unwrap_ioerror(err)
            if not details or details.get('Message') != moto_message:
                # Not the moto range error: propagate unchanged.
                raise
            # NOTE(review): presumably ``details`` is the same mapping the
            # caller later unwraps from this exception -- confirm against
            # ``_unwrap_ioerror``.
            details['ActualObjectSize'] = actual_size
            details['Code'] = 'InvalidRange'
            details['Message'] = 'The requested range is not satisfiable'
            raise

    with patch('smart_open.s3._get', new=get_with_patched_error):
        yield


class BaseTest(unittest.TestCase):
@contextmanager
def assertApiCalls(self, **expected_api_calls):
Expand Down Expand Up @@ -236,6 +257,15 @@ def test_seek_end(self):
self.assertEqual(seek, len(content) - 4)
self.assertEqual(fin.read(), b'you?')

def test_seek_past_end(self):
content = u"hello wořld\nhow are you?".encode('utf8')
put_to_bucket(contents=content)

with self.assertApiCalls(GetObject=1), patch_invalid_range_response(str(len(content))):
fin = smart_open.s3.SeekableBufferedInputBase(BUCKET_NAME, KEY_NAME, defer_seek=True)
mpenkov marked this conversation as resolved.
Show resolved Hide resolved
seek = fin.seek(60)
self.assertEqual(seek, len(content))

def test_detect_eof(self):
content = u"hello wořld\nhow are you?".encode('utf8')
put_to_bucket(contents=content)
Expand Down Expand Up @@ -352,6 +382,15 @@ def test_defer_seek(self):
fin.seek(10)
self.assertEqual(fin.read(), content[10:])

def test_read_empty_file(self):
put_to_bucket(contents=b'')

with self.assertApiCalls(GetObject=1), patch_invalid_range_response('0'):
with smart_open.s3.SeekableBufferedInputBase(BUCKET_NAME, KEY_NAME) as fin:
mpenkov marked this conversation as resolved.
Show resolved Hide resolved
data = fin.read()

self.assertEqual(data, b'')


@moto.mock_s3
class MultipartWriterTest(unittest.TestCase):
Expand Down Expand Up @@ -426,7 +465,8 @@ def test_write_04(self):
pass

# read back the same key and check its content
output = list(smart_open.s3.open(BUCKET_NAME, WRITE_KEY_NAME, 'rb'))
with patch_invalid_range_response('0'):
output = list(smart_open.s3.open(BUCKET_NAME, WRITE_KEY_NAME, 'rb'))

self.assertEqual(output, [])

Expand Down Expand Up @@ -548,7 +588,8 @@ def test_write_04(self):
pass

# read back the same key and check its content
output = list(smart_open.s3.open(BUCKET_NAME, WRITE_KEY_NAME, 'rb'))
with patch_invalid_range_response('0'):
output = list(smart_open.s3.open(BUCKET_NAME, WRITE_KEY_NAME, 'rb'))
self.assertEqual(output, [])

def test_buffered_writer_wrapper_works(self):
Expand Down
10 changes: 6 additions & 4 deletions smart_open/tests/test_smart_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from smart_open import smart_open_lib
from smart_open import webhdfs
from smart_open.smart_open_lib import patch_pathlib, _patch_pathlib
from smart_open.tests.test_s3 import patch_invalid_range_response

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -758,11 +759,12 @@ def test_readline_eof(self):
with smart_open.open("s3://mybucket/mykey", "wb"):
pass

reader = smart_open.open("s3://mybucket/mykey", "rb")
with patch_invalid_range_response('0'):
reader = smart_open.open("s3://mybucket/mykey", "rb")

self.assertEqual(reader.readline(), b"")
self.assertEqual(reader.readline(), b"")
self.assertEqual(reader.readline(), b"")
self.assertEqual(reader.readline(), b"")
self.assertEqual(reader.readline(), b"")
self.assertEqual(reader.readline(), b"")

@mock_s3
def test_s3_iter_lines(self):
Expand Down