diff --git a/smart_open/s3.py b/smart_open/s3.py index da6506b9..85ecf523 100644 --- a/smart_open/s3.py +++ b/smart_open/s3.py @@ -75,6 +75,21 @@ def open(bucket_id, key_id, mode, **kwargs): class RawReader(object): """Read an S3 object.""" + def __init__(self, s3_object): + self.position = 0 + self._object = s3_object + self._body = s3_object.get()['Body'] + + def read(self, size=-1): + if size == -1: + return self._body.read() + return self._body.read(size) + + +class SeekableRawReader(object): + """Read an S3 object. + + Support seeking around, but is slower than RawReader.""" def __init__(self, s3_object): self.position = 0 self._object = s3_object @@ -95,10 +110,6 @@ def read(self, size=-1): class BufferedInputBase(io.BufferedIOBase): - """Reads bytes from S3. - - Implements the io.BufferedIOBase interface of the standard library.""" - def __init__(self, bucket, key, buffer_size=DEFAULT_BUFFER_SIZE, line_terminator=BINARY_NEWLINE, **kwargs): session = boto3.Session(profile_name=kwargs.pop('profile_name', None)) @@ -130,43 +141,7 @@ def readable(self): return True def seekable(self): - """If False, seek(), tell() and truncate() will raise IOError. - - We offer only seek support, and no truncate support.""" - return True - - def seek(self, offset, whence=START): - """Seek to the specified position. - - :param int offset: The offset in bytes. - :param int whence: Where the offset is from. - - Returns the position after seeking.""" - logger.debug('seeking to offset: %r whence: %r', offset, whence) - if whence not in WHENCE_CHOICES: - raise ValueError('invalid whence, expected one of %r' % WHENCE_CHOICES) - - if whence == START: - new_position = offset - elif whence == CURRENT: - new_position = self._current_pos + offset - else: - new_position = self._content_length + offset - new_position = _clamp(new_position, 0, self._content_length) - - logger.debug('new_position: %r', new_position) - self._current_pos = self._raw_reader.position = new_position - self._buffer = b"" - self._eof = self._current_pos == self._content_length - return self._current_pos - - def tell(self): - """Return the current position within the file.""" - return self._current_pos - - def truncate(self, size=None): - """Unsupported.""" - raise io.UnsupportedOperation + return False # # io.BufferedIOBase methods. @@ -268,6 +243,69 @@ def _fill_buffer(self, size): self._eof = True +class SeekableBufferedInputBase(BufferedInputBase): + """Reads bytes from S3. + + Implements the io.BufferedIOBase interface of the standard library.""" + + def __init__(self, bucket, key, buffer_size=DEFAULT_BUFFER_SIZE, + line_terminator=BINARY_NEWLINE, **kwargs): + session = boto3.Session(profile_name=kwargs.pop('profile_name', None)) + s3 = session.resource('s3', **kwargs) + self._object = s3.Object(bucket, key) + self._raw_reader = SeekableRawReader(self._object) + self._content_length = self._object.content_length + self._current_pos = 0 + self._buffer = b'' + self._eof = False + self._buffer_size = buffer_size + self._line_terminator = line_terminator + + # + # This member is part of the io.BufferedIOBase interface. + # + self.raw = None + + def seekable(self): + """If False, seek(), tell() and truncate() will raise IOError. + + We offer only seek support, and no truncate support.""" + return True + + def seek(self, offset, whence=START): + """Seek to the specified position. + + :param int offset: The offset in bytes. + :param int whence: Where the offset is from. + + Returns the position after seeking.""" + logger.debug('seeking to offset: %r whence: %r', offset, whence) + if whence not in WHENCE_CHOICES: + raise ValueError('invalid whence, expected one of %r' % WHENCE_CHOICES) + + if whence == START: + new_position = offset + elif whence == CURRENT: + new_position = self._current_pos + offset + else: + new_position = self._content_length + offset + new_position = _clamp(new_position, 0, self._content_length) + + logger.debug('new_position: %r', new_position) + self._current_pos = self._raw_reader.position = new_position + self._buffer = b"" + self._eof = self._current_pos == self._content_length + return self._current_pos + + def tell(self): + """Return the current position within the file.""" + return self._current_pos + + def truncate(self, size=None): + """Unsupported.""" + raise io.UnsupportedOperation + + class BufferedOutputBase(io.BufferedIOBase): """Writes bytes to S3. diff --git a/smart_open/tests/test_s3.py b/smart_open/tests/test_s3.py index 1bab5e56..9a9ad89a 100644 --- a/smart_open/tests/test_s3.py +++ b/smart_open/tests/test_s3.py @@ -73,7 +73,7 @@ def test_seek_beginning(self): content = u"hello wořld\nhow are you?".encode('utf8') create_bucket_and_key(contents=content) - fin = smart_open.s3.BufferedInputBase('mybucket', 'mykey') + fin = smart_open.s3.SeekableBufferedInputBase('mybucket', 'mykey') self.assertEqual(content[:6], fin.read(6)) self.assertEqual(content[6:14], fin.read(8)) # ř is 2 bytes @@ -88,7 +88,7 @@ def test_seek_start(self): content = u"hello wořld\nhow are you?".encode('utf8') create_bucket_and_key(contents=content) - fin = smart_open.s3.BufferedInputBase('mybucket', 'mykey') + fin = smart_open.s3.SeekableBufferedInputBase('mybucket', 'mykey') seek = fin.seek(6) self.assertEqual(seek, 6) self.assertEqual(fin.tell(), 6) @@ -99,7 +99,7 @@ def test_seek_current(self): content = u"hello wořld\nhow are you?".encode('utf8') create_bucket_and_key(contents=content) - fin = smart_open.s3.BufferedInputBase('mybucket', 'mykey') + fin = smart_open.s3.SeekableBufferedInputBase('mybucket', 'mykey') self.assertEqual(fin.read(5), b'hello') seek = fin.seek(1, whence=smart_open.s3.CURRENT) self.assertEqual(seek, 6) @@ -110,7 +110,7 @@ def test_seek_end(self): content = u"hello wořld\nhow are you?".encode('utf8') create_bucket_and_key(contents=content) - fin = smart_open.s3.BufferedInputBase('mybucket', 'mykey') + fin = smart_open.s3.SeekableBufferedInputBase('mybucket', 'mykey') seek = fin.seek(-4, whence=smart_open.s3.END) self.assertEqual(seek, len(content) - 4) self.assertEqual(fin.read(), b'you?') @@ -119,7 +119,7 @@ def test_detect_eof(self): content = u"hello wořld\nhow are you?".encode('utf8') create_bucket_and_key(contents=content) - fin = smart_open.s3.BufferedInputBase('mybucket', 'mykey') + fin = smart_open.s3.SeekableBufferedInputBase('mybucket', 'mykey') fin.read() eof = fin.tell() self.assertEqual(eof, len(content)) diff --git a/smart_open/tests/test_smart_open.py b/smart_open/tests/test_smart_open.py index 142eee93..c20f3bc6 100644 --- a/smart_open/tests/test_smart_open.py +++ b/smart_open/tests/test_smart_open.py @@ -420,6 +420,7 @@ def test_s3_read_moto(self): self.assertEqual(content[14:], smart_open_object.read()) # read the rest + @unittest.skip('seek functionality for S3 currently disabled because of Issue #152') @mock_s3 def test_s3_seek_moto(self): """Does seeking in S3 files work correctly?"""