diff --git a/integration-tests/README.md b/integration-tests/README.md
new file mode 100644
index 00000000..c760f2b1
--- /dev/null
+++ b/integration-tests/README.md
@@ -0,0 +1,47 @@
+This directory contains integration tests for smart_open.
+To run the tests, you need read/write access to an S3 bucket.
+You also need to install pytest and its benchmark plugin:
+
+    pip install pytest pytest-benchmark
+
+Then run the tests:
+
+    SMART_OPEN_S3_URL=s3://bucket/smart_open_test py.test integration-tests/test_s3.py
+
+You may use any key name instead of "smart_open_test".
+It does not have to be an existing key.
+**The tests will remove the key prior to each test, so be sure the key doesn't contain anything important.**
+
+The tests will take several minutes to complete.
+Each test will run several times to obtain summary statistics such as min, max, mean and median.
+This allows us to detect performance regressions.
+Here is some example output (best viewed on a wide screen):
+
+```
+(smartopen)sergeyich:smart_open misha$ SMART_OPEN_S3_URL=s3://bucket/smart_open_test py.test integration-tests/test_s3.py
+=============================================== test session starts ================================================
+platform darwin -- Python 3.6.3, pytest-3.3.0, py-1.5.2, pluggy-0.6.0
+benchmark: 3.1.1 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
+rootdir: /Users/misha/git/smart_open, inifile:
+plugins: benchmark-3.1.1
+collected 6 items
+
+integration-tests/test_s3.py ......                                                                          [100%]
+
+
+--------------------------------------------------------------------------------------- benchmark: 6 tests --------------------------------------------------------------------------------------
+Name (time in s)                    Min                Max               Mean            StdDev             Median               IQR            Outliers     OPS            Rounds  Iterations
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+test_s3_readwrite_text           2.7593 (1.0)       3.4935 (1.0)      3.2203 (1.0)      0.3064 (1.0)      3.3202 (1.04)     0.4730 (1.0)          1;0  0.3105 (1.0)           5           1
+test_s3_readwrite_text_gzip      3.0242 (1.10)      4.6782 (1.34)     3.7079 (1.15)     0.8531 (2.78)     3.2001 (1.0)      1.5850 (3.35)         2;0  0.2697 (0.87)          5           1
+test_s3_readwrite_binary         3.0549 (1.11)      3.9062 (1.12)     3.5399 (1.10)     0.3516 (1.15)     3.4721 (1.09)     0.5532 (1.17)         2;0  0.2825 (0.91)          5           1
+test_s3_performance_gz           3.1885 (1.16)      5.2845 (1.51)     3.9298 (1.22)     0.8197 (2.68)     3.6974 (1.16)     0.9693 (2.05)         1;0  0.2545 (0.82)          5           1
+test_s3_readwrite_binary_gzip    3.3756 (1.22)      5.0423 (1.44)     4.1763 (1.30)     0.6381 (2.08)     4.0722 (1.27)     0.9209 (1.95)         2;0  0.2394 (0.77)          5           1
+test_s3_performance              7.6758 (2.78)     29.5266 (8.45)    18.8346 (5.85)    10.3003 (33.62)   21.1854 (6.62)    19.6234 (41.49)        3;0  0.0531 (0.17)          5           1
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+Legend:
+  Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
+ OPS: Operations Per Second, computed as 1 / Mean +============================================ 6 passed in 285.14 seconds ============================================ +``` diff --git a/integration-tests/test_s3.py b/integration-tests/test_s3.py new file mode 100644 index 00000000..9d134e0e --- /dev/null +++ b/integration-tests/test_s3.py @@ -0,0 +1,83 @@ +from __future__ import unicode_literals +import io +import os +import subprocess + +import smart_open + +_S3_URL = os.environ.get('SMART_OPEN_S3_URL') +assert _S3_URL is not None, 'please set the SMART_OPEN_S3_URL environment variable' + + +def initialize_bucket(): + subprocess.check_call(['aws', 's3', 'rm', '--recursive', _S3_URL]) + + +def write_read(key, content, write_mode, read_mode): + with smart_open.smart_open(key, write_mode) as fout: + fout.write(content) + with smart_open.smart_open(key, read_mode) as fin: + actual = fin.read() + return actual + + +def test_s3_readwrite_text(benchmark): + initialize_bucket() + + key = _S3_URL + '/sanity.txt' + text = 'с гранатою в кармане, с чекою в руке' + actual = benchmark(write_read, key, text, 'w', 'r') + assert actual == text + + +def test_s3_readwrite_text_gzip(benchmark): + initialize_bucket() + + key = _S3_URL + '/sanity.txt.gz' + text = 'не чайки здесь запели на знакомом языке' + actual = benchmark(write_read, key, text, 'w', 'r') + assert actual == text + + +def test_s3_readwrite_binary(benchmark): + initialize_bucket() + + key = _S3_URL + '/sanity.txt' + binary = b'this is a test' + actual = benchmark(write_read, key, binary, 'wb', 'rb') + assert actual == binary + + +def test_s3_readwrite_binary_gzip(benchmark): + initialize_bucket() + + key = _S3_URL + '/sanity.txt.gz' + binary = b'this is a test' + actual = benchmark(write_read, key, binary, 'wb', 'rb') + assert actual == binary + + +def test_s3_performance(benchmark): + initialize_bucket() + + one_megabyte = io.BytesIO() + for _ in range(1024*128): + one_megabyte.write(b'01234567') + one_megabyte = one_megabyte.getvalue() + + key = _S3_URL + '/performance.txt' + actual = benchmark(write_read, key, one_megabyte, 'wb', 'rb') + assert actual == one_megabyte + + +def test_s3_performance_gz(benchmark): + initialize_bucket() + + one_megabyte = io.BytesIO() + for _ in range(1024*128): + one_megabyte.write(b'01234567') + one_megabyte = one_megabyte.getvalue() + + key = _S3_URL + '/performance.txt.gz' + actual = benchmark(write_read, key, one_megabyte, 'wb', 'rb') + assert actual == one_megabyte diff --git a/smart_open/s3.py b/smart_open/s3.py index 420afbdf..a76dab50 100644 --- a/smart_open/s3.py +++ b/smart_open/s3.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- """Implements file-like objects for reading and writing from/to S3.""" import boto3 +import botocore.client import io import logging @@ -28,6 +29,9 @@ MODES = (READ, READ_BINARY, WRITE, WRITE_BINARY) """Allowed I/O modes for working with S3.""" +BINARY_NEWLINE = b'\n' +DEFAULT_BUFFER_SIZE = 256 * 1024 + def _range_string(start, stop=None): # @@ -47,7 +51,6 @@ def open(bucket_id, key_id, mode, **kwargs): if mode not in MODES: raise NotImplementedError('bad mode: %r expected one of %r' % (mode, MODES)) - buffer_size = kwargs.pop("buffer_size", io.DEFAULT_BUFFER_SIZE) encoding = kwargs.pop("encoding", "utf-8") errors = kwargs.pop("errors", None) newline = kwargs.pop("newline", None) @@ -55,7 +58,7 @@ def open(bucket_id, key_id, mode, **kwargs): s3_min_part_size = kwargs.pop("s3_min_part_size", DEFAULT_MIN_PART_SIZE) if mode in (READ, READ_BINARY): - fileobj = 
BufferedInputBase(bucket_id, key_id, **kwargs) + fileobj = SeekableBufferedInputBase(bucket_id, key_id, **kwargs) elif mode in (WRITE, WRITE_BINARY): fileobj = BufferedOutputBase(bucket_id, key_id, min_part_size=s3_min_part_size, **kwargs) else: @@ -72,6 +75,21 @@ def open(bucket_id, key_id, mode, **kwargs): class RawReader(object): """Read an S3 object.""" + def __init__(self, s3_object): + self.position = 0 + self._object = s3_object + self._body = s3_object.get()['Body'] + + def read(self, size=-1): + if size == -1: + return self._body.read() + return self._body.read(size) + + +class SeekableRawReader(object): + """Read an S3 object. + + Support seeking around, but is slower than RawReader.""" def __init__(self, s3_object): self.position = 0 self._object = s3_object @@ -92,11 +110,8 @@ def read(self, size=-1): class BufferedInputBase(io.BufferedIOBase): - """Reads bytes from S3. - - Implements the io.BufferedIOBase interface of the standard library.""" - - def __init__(self, bucket, key, **kwargs): + def __init__(self, bucket, key, buffer_size=DEFAULT_BUFFER_SIZE, + line_terminator=BINARY_NEWLINE, **kwargs): session = boto3.Session(profile_name=kwargs.pop('profile_name', None)) s3 = session.resource('s3', **kwargs) self._object = s3.Object(bucket, key) @@ -105,6 +120,8 @@ def __init__(self, bucket, key, **kwargs): self._current_pos = 0 self._buffer = b'' self._eof = False + self._buffer_size = buffer_size + self._line_terminator = line_terminator # # This member is part of the io.BufferedIOBase interface. @@ -124,43 +141,7 @@ def readable(self): return True def seekable(self): - """If False, seek(), tell() and truncate() will raise IOError. - - We offer only seek support, and no truncate support.""" - return True - - def seek(self, offset, whence=START): - """Seek to the specified position. - - :param int offset: The offset in bytes. - :param int whence: Where the offset is from. - - Returns the position after seeking.""" - logger.debug('seeking to offset: %r whence: %r', offset, whence) - if whence not in WHENCE_CHOICES: - raise ValueError('invalid whence, expected one of %r' % WHENCE_CHOICES) - - if whence == START: - new_position = offset - elif whence == CURRENT: - new_position = self._current_pos + offset - else: - new_position = self._content_length + offset - new_position = _clamp(new_position, 0, self._content_length) - - logger.debug('new_position: %r', new_position) - self._current_pos = self._raw_reader.position = new_position - self._buffer = b"" - self._eof = self._current_pos == self._content_length - return self._current_pos - - def tell(self): - """Return the current position within the file.""" - return self._current_pos - - def truncate(self, size=None): - """Unsupported.""" - raise io.UnsupportedOperation + return False # # io.BufferedIOBase methods. @@ -195,14 +176,7 @@ def read(self, size=-1): # Fill our buffer to the required size. # # logger.debug('filling %r byte-long buffer up to %r bytes', len(self._buffer), size) - while len(self._buffer) < size and not self._eof: - raw = self._raw_reader.read(size=io.DEFAULT_BUFFER_SIZE) - if len(raw): - self._buffer += raw - else: - logger.debug('reached EOF while filling buffer') - self._eof = True - + self._fill_buffer(size) return self._read_from_buffer(size) def read1(self, size=-1): @@ -218,6 +192,30 @@ def readinto(self, b): b[:len(data)] = data return len(data) + def readline(self, limit=-1): + """Read up to and including the next newline. 
Returns the bytes read.""" + if limit != -1: + raise NotImplementedError('limits other than -1 not implemented yet') + the_line = io.BytesIO() + while not (self._eof and len(self._buffer) == 0): + # + # In the worst case, we're reading self._buffer twice here, once in + # the if condition, and once when calling index. + # + # This is sub-optimal, but better than the alternative: wrapping + # .index in a try..except, because that is slower. + # + if self._line_terminator in self._buffer: + next_newline = self._buffer.index(self._line_terminator) + the_line.write(self._buffer[:next_newline + 1]) + self._buffer = self._buffer[next_newline + 1:] + break + else: + the_line.write(self._buffer) + self._buffer = b'' + self._fill_buffer(self._buffer_size) + return the_line.getvalue() + def terminate(self): """Do nothing.""" pass @@ -235,6 +233,78 @@ def _read_from_buffer(self, size): # logger.debug('part: %r', part) return part + def _fill_buffer(self, size): + while len(self._buffer) < size and not self._eof: + raw = self._raw_reader.read(size=self._buffer_size) + if len(raw): + self._buffer += raw + else: + logger.debug('reached EOF while filling buffer') + self._eof = True + + +class SeekableBufferedInputBase(BufferedInputBase): + """Reads bytes from S3. + + Implements the io.BufferedIOBase interface of the standard library.""" + + def __init__(self, bucket, key, buffer_size=DEFAULT_BUFFER_SIZE, + line_terminator=BINARY_NEWLINE, **kwargs): + session = boto3.Session(profile_name=kwargs.pop('profile_name', None)) + s3 = session.resource('s3', **kwargs) + self._object = s3.Object(bucket, key) + self._raw_reader = SeekableRawReader(self._object) + self._content_length = self._object.content_length + self._current_pos = 0 + self._buffer = b'' + self._eof = False + self._buffer_size = buffer_size + self._line_terminator = line_terminator + + # + # This member is part of the io.BufferedIOBase interface. + # + self.raw = None + + def seekable(self): + """If False, seek(), tell() and truncate() will raise IOError. + + We offer only seek support, and no truncate support.""" + return True + + def seek(self, offset, whence=START): + """Seek to the specified position. + + :param int offset: The offset in bytes. + :param int whence: Where the offset is from. + + Returns the position after seeking.""" + logger.debug('seeking to offset: %r whence: %r', offset, whence) + if whence not in WHENCE_CHOICES: + raise ValueError('invalid whence, expected one of %r' % WHENCE_CHOICES) + + if whence == START: + new_position = offset + elif whence == CURRENT: + new_position = self._current_pos + offset + else: + new_position = self._content_length + offset + new_position = _clamp(new_position, 0, self._content_length) + + logger.debug('new_position: %r', new_position) + self._current_pos = self._raw_reader.position = new_position + self._buffer = b"" + self._eof = self._current_pos == self._content_length + return self._current_pos + + def tell(self): + """Return the current position within the file.""" + return self._current_pos + + def truncate(self, size=None): + """Unsupported.""" + raise io.UnsupportedOperation + class BufferedOutputBase(io.BufferedIOBase): """Writes bytes to S3. 
@@ -252,7 +322,10 @@ def __init__(self, bucket, key, min_part_size=DEFAULT_MIN_PART_SIZE, **kwargs): # # https://stackoverflow.com/questions/26871884/how-can-i-easily-determine-if-a-boto-3-s3-bucket-resource-exists # - s3.create_bucket(Bucket=bucket) + try: + s3.meta.client.head_bucket(Bucket=bucket) + except botocore.client.ClientError: + raise ValueError('the bucket %r does not exist, or is forbidden for access' % bucket) self._object = s3.Object(bucket, key) self._min_part_size = min_part_size self._mp = self._object.initiate_multipart_upload() diff --git a/smart_open/smart_open_lib.py b/smart_open/smart_open_lib.py index fa631471..f0f638e7 100644 --- a/smart_open/smart_open_lib.py +++ b/smart_open/smart_open_lib.py @@ -41,7 +41,6 @@ if IS_PY2: import cStringIO as StringIO -import contextlib if sys.version_info[0] == 2: import httplib @@ -83,6 +82,8 @@ "Re-open the file without specifying an encoding to suppress this warning." ) +DEFAULT_ERRORS = 'strict' + def smart_open(uri, mode="rb", **kw): """ @@ -170,7 +171,9 @@ def smart_open(uri, mode="rb", **kw): if parsed_uri.scheme in ("file", ): # local files -- both read & write supported # compression, if any, is determined by the filename extension (.gz, .bz2) - return file_smart_open(parsed_uri.uri_path, mode, encoding=kw.pop('encoding', None)) + encoding = kw.pop('encoding', None) + errors = kw.pop('errors', DEFAULT_ERRORS) + return file_smart_open(parsed_uri.uri_path, mode, encoding=encoding, errors=errors) elif parsed_uri.scheme in ("s3", "s3n", 's3u'): return s3_open_uri(parsed_uri, mode, **kw) elif parsed_uri.scheme in ("hdfs", ): @@ -238,12 +241,12 @@ def s3_open_uri(parsed_uri, mode, **kwargs): # Codecs work on a byte-level, so the underlying S3 object should # always be reading bytes. # - if codec and mode in (smart_open_s3.READ, smart_open_s3.READ_BINARY): + if mode in (smart_open_s3.READ, smart_open_s3.READ_BINARY): s3_mode = smart_open_s3.READ_BINARY - elif codec and mode in (smart_open_s3.WRITE, smart_open_s3.WRITE_BINARY): + elif mode in (smart_open_s3.WRITE, smart_open_s3.WRITE_BINARY): s3_mode = smart_open_s3.WRITE_BINARY else: - s3_mode = mode + raise NotImplementedError('mode %r not implemented for S3' % mode) # # TODO: I'm not sure how to handle this with boto3. Any ideas? @@ -252,8 +255,12 @@ def s3_open_uri(parsed_uri, mode, **kwargs): # # _setup_unsecured_mode() + encoding = kwargs.get('encoding') + errors = kwargs.get('errors', DEFAULT_ERRORS) fobj = smart_open_s3.open(parsed_uri.bucket_id, parsed_uri.key_id, s3_mode, **kwargs) - return _CODECS[codec](fobj, mode) + decompressed_fobj = _CODECS[codec](fobj, mode) + decoded_fobj = encoding_wrapper(decompressed_fobj, mode, encoding=encoding, errors=errors) + return decoded_fobj def _setup_unsecured_mode(parsed_uri, kwargs): @@ -285,16 +292,20 @@ def s3_open_key(key, mode, **kwargs): # Codecs work on a byte-level, so the underlying S3 object should # always be reading bytes. 
# - if codec and mode in (smart_open_s3.READ, smart_open_s3.READ_BINARY): + if mode in (smart_open_s3.READ, smart_open_s3.READ_BINARY): s3_mode = smart_open_s3.READ_BINARY - elif codec and mode in (smart_open_s3.WRITE, smart_open_s3.WRITE_BINARY): + elif mode in (smart_open_s3.WRITE, smart_open_s3.WRITE_BINARY): s3_mode = smart_open_s3.WRITE_BINARY else: - s3_mode = mode + raise NotImplementedError('mode %r not implemented for S3' % mode) logging.debug('codec: %r mode: %r s3_mode: %r', codec, mode, s3_mode) + encoding = kwargs.get('encoding') + errors = kwargs.get('errors', DEFAULT_ERRORS) fobj = smart_open_s3.open(key.bucket.name, key.name, s3_mode, **kwargs) - return _CODECS[codec](fobj, mode) + decompressed_fobj = _CODECS[codec](fobj, mode) + decoded_fobj = encoding_wrapper(decompressed_fobj, mode, encoding=encoding, errors=errors) + return decoded_fobj def _detect_codec(filename): @@ -304,7 +315,7 @@ def _detect_codec(filename): def _wrap_gzip(fileobj, mode): - return contextlib.closing(gzip.GzipFile(fileobj=fileobj, mode=mode)) + return gzip.GzipFile(fileobj=fileobj, mode=mode) def _wrap_none(fileobj, mode): @@ -595,7 +606,7 @@ def compression_wrapper(file_obj, filename, mode): return file_obj -def encoding_wrapper(fileobj, mode, encoding=None): +def encoding_wrapper(fileobj, mode, encoding=None, errors=DEFAULT_ERRORS): """Decode bytes into text, if necessary. If mode specifies binary access, does nothing, unless the encoding is @@ -603,7 +614,8 @@ def encoding_wrapper(fileobj, mode, encoding=None): :arg fileobj: must quack like a filehandle object. :arg str mode: is the mode which was originally requested by the user. - :arg encoding: The text encoding to use. If mode is binary, overrides mode. + :arg str encoding: The text encoding to use. If mode is binary, overrides mode. + :arg str errors: The method to use when handling encoding/decoding errors. :returns: a file object """ logger.debug('encoding_wrapper: %r', locals()) @@ -627,10 +639,10 @@ def encoding_wrapper(fileobj, mode, encoding=None): decoder = codecs.getreader(encoding) else: decoder = codecs.getwriter(encoding) - return decoder(fileobj) + return decoder(fileobj, errors=errors) -def file_smart_open(fname, mode='rb', encoding=None): +def file_smart_open(fname, mode='rb', encoding=None, errors=DEFAULT_ERRORS): """ Stream from/to local filesystem, transparently (de)compressing gzip and bz2 files if necessary. @@ -638,6 +650,7 @@ def file_smart_open(fname, mode='rb', encoding=None): :arg str fname: The path to the file to open. :arg str mode: The mode in which to open the file. :arg str encoding: The text encoding to use. + :arg str errors: The method to use when handling encoding/decoding errors. 
:returns: A file object """ # @@ -657,7 +670,7 @@ def file_smart_open(fname, mode='rb', encoding=None): raw_mode = mode raw_fobj = open(fname, raw_mode) decompressed_fobj = compression_wrapper(raw_fobj, fname, raw_mode) - decoded_fobj = encoding_wrapper(decompressed_fobj, mode, encoding=encoding) + decoded_fobj = encoding_wrapper(decompressed_fobj, mode, encoding=encoding, errors=errors) return decoded_fobj diff --git a/smart_open/tests/test_s3.py b/smart_open/tests/test_s3.py index 3a9c12c3..95a96feb 100644 --- a/smart_open/tests/test_s3.py +++ b/smart_open/tests/test_s3.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -import contextlib import logging import gzip import io @@ -9,7 +8,7 @@ else: import unittest -import boto +import boto3 import moto import smart_open @@ -21,20 +20,16 @@ def create_bucket_and_key(bucket_name='mybucket', key_name='mykey', contents=None): # fake connection, bucket and key _LOGGER.debug('%r', locals()) - conn = boto.connect_s3() - conn.create_bucket(bucket_name) - mybucket = conn.get_bucket(bucket_name) - mykey = boto.s3.key.Key() - mykey.name = key_name - mykey.bucket = mybucket + s3 = boto3.resource('s3') + mybucket = s3.create_bucket(Bucket=bucket_name) + mykey = s3.Object(bucket_name, key_name) if contents is not None: - _LOGGER.debug('len(contents): %r', len(contents)) - mykey.set_contents_from_string(contents) + mykey.put(Body=contents) return mybucket, mykey @moto.mock_s3 -class BufferedInputBaseTest(unittest.TestCase): +class SeekableBufferedInputBaseTest(unittest.TestCase): def setUp(self): # lower the multipart upload size, to speed up these tests self.old_min_part_size = smart_open.s3.DEFAULT_MIN_PART_SIZE @@ -47,28 +42,28 @@ def test_iter(self): """Are S3 files iterated over correctly?""" # a list of strings to test with expected = u"hello wořld\nhow are you?".encode('utf8') - bucket, key = create_bucket_and_key(contents=expected) + create_bucket_and_key(contents=expected) # connect to fake s3 and read from the fake key we filled above - fin = smart_open.s3.BufferedInputBase('mybucket', 'mykey') + fin = smart_open.s3.SeekableBufferedInputBase('mybucket', 'mykey') output = [line.rstrip(b'\n') for line in fin] self.assertEqual(output, expected.split(b'\n')) def test_iter_context_manager(self): # same thing but using a context manager expected = u"hello wořld\nhow are you?".encode('utf8') - bucket, key = create_bucket_and_key(contents=expected) - with smart_open.s3.BufferedInputBase('mybucket', 'mykey') as fin: + create_bucket_and_key(contents=expected) + with smart_open.s3.SeekableBufferedInputBase('mybucket', 'mykey') as fin: output = [line.rstrip(b'\n') for line in fin] self.assertEqual(output, expected.split(b'\n')) def test_read(self): """Are S3 files read correctly?""" content = u"hello wořld\nhow are you?".encode('utf8') - bucket, key = create_bucket_and_key(contents=content) + create_bucket_and_key(contents=content) _LOGGER.debug('content: %r len: %r', content, len(content)) - fin = smart_open.s3.BufferedInputBase('mybucket', 'mykey') + fin = smart_open.s3.SeekableBufferedInputBase('mybucket', 'mykey') self.assertEqual(content[:6], fin.read(6)) self.assertEqual(content[6:14], fin.read(8)) # ř is 2 bytes self.assertEqual(content[14:], fin.read()) # read the rest @@ -76,9 +71,9 @@ def test_read(self): def test_seek_beginning(self): """Does seeking to the beginning of S3 files work correctly?""" content = u"hello wořld\nhow are you?".encode('utf8') - bucket, key = create_bucket_and_key(contents=content) + create_bucket_and_key(contents=content) - fin 
= smart_open.s3.BufferedInputBase('mybucket', 'mykey') + fin = smart_open.s3.SeekableBufferedInputBase('mybucket', 'mykey') self.assertEqual(content[:6], fin.read(6)) self.assertEqual(content[6:14], fin.read(8)) # ř is 2 bytes @@ -91,9 +86,9 @@ def test_seek_beginning(self): def test_seek_start(self): """Does seeking from the start of S3 files work correctly?""" content = u"hello wořld\nhow are you?".encode('utf8') - bucket, key = create_bucket_and_key(contents=content) + create_bucket_and_key(contents=content) - fin = smart_open.s3.BufferedInputBase('mybucket', 'mykey') + fin = smart_open.s3.SeekableBufferedInputBase('mybucket', 'mykey') seek = fin.seek(6) self.assertEqual(seek, 6) self.assertEqual(fin.tell(), 6) @@ -102,9 +97,9 @@ def test_seek_start(self): def test_seek_current(self): """Does seeking from the middle of S3 files work correctly?""" content = u"hello wořld\nhow are you?".encode('utf8') - bucket, key = create_bucket_and_key(contents=content) + create_bucket_and_key(contents=content) - fin = smart_open.s3.BufferedInputBase('mybucket', 'mykey') + fin = smart_open.s3.SeekableBufferedInputBase('mybucket', 'mykey') self.assertEqual(fin.read(5), b'hello') seek = fin.seek(1, whence=smart_open.s3.CURRENT) self.assertEqual(seek, 6) @@ -113,18 +108,18 @@ def test_seek_current(self): def test_seek_end(self): """Does seeking from the end of S3 files work correctly?""" content = u"hello wořld\nhow are you?".encode('utf8') - bucket, key = create_bucket_and_key(contents=content) + create_bucket_and_key(contents=content) - fin = smart_open.s3.BufferedInputBase('mybucket', 'mykey') + fin = smart_open.s3.SeekableBufferedInputBase('mybucket', 'mykey') seek = fin.seek(-4, whence=smart_open.s3.END) self.assertEqual(seek, len(content) - 4) self.assertEqual(fin.read(), b'you?') def test_detect_eof(self): content = u"hello wořld\nhow are you?".encode('utf8') - bucket, key = create_bucket_and_key(contents=content) + create_bucket_and_key(contents=content) - fin = smart_open.s3.BufferedInputBase('mybucket', 'mykey') + fin = smart_open.s3.SeekableBufferedInputBase('mybucket', 'mykey') fin.read() eof = fin.tell() self.assertEqual(eof, len(content)) @@ -135,30 +130,50 @@ def test_read_gzip(self): expected = u'раcцветали яблони и груши, поплыли туманы над рекой...'.encode('utf-8') buf = io.BytesIO() buf.close = lambda: None # keep buffer open so that we can .getvalue() - with contextlib.closing(gzip.GzipFile(fileobj=buf, mode='w')) as zipfile: + with gzip.GzipFile(fileobj=buf, mode='w') as zipfile: zipfile.write(expected) - bucket, key = create_bucket_and_key(contents=buf.getvalue()) + create_bucket_and_key(contents=buf.getvalue()) # # Make sure we're reading things correctly. # - with smart_open.s3.BufferedInputBase('mybucket', 'mykey') as fin: + with smart_open.s3.SeekableBufferedInputBase('mybucket', 'mykey') as fin: self.assertEqual(fin.read(), buf.getvalue()) # # Make sure the buffer we wrote is legitimate gzip. 
# sanity_buf = io.BytesIO(buf.getvalue()) - with contextlib.closing(gzip.GzipFile(fileobj=sanity_buf)) as zipfile: + with gzip.GzipFile(fileobj=sanity_buf) as zipfile: self.assertEqual(zipfile.read(), expected) _LOGGER.debug('starting actual test') - with smart_open.s3.BufferedInputBase('mybucket', 'mykey') as fin: - with contextlib.closing(gzip.GzipFile(fileobj=fin)) as zipfile: + with smart_open.s3.SeekableBufferedInputBase('mybucket', 'mykey') as fin: + with gzip.GzipFile(fileobj=fin) as zipfile: actual = zipfile.read() self.assertEqual(expected, actual) + def test_readline(self): + content = b'englishman\nin\nnew\nyork\n' + create_bucket_and_key(contents=content) + + with smart_open.s3.BufferedInputBase('mybucket', 'mykey') as fin: + actual = list(fin) + + expected = [b'englishman\n', b'in\n', b'new\n', b'york\n'] + self.assertEqual(expected, actual) + + def test_readline_tiny_buffer(self): + content = b'englishman\nin\nnew\nyork\n' + create_bucket_and_key(contents=content) + + with smart_open.s3.BufferedInputBase('mybucket', 'mykey', buffer_size=8) as fin: + actual = list(fin) + + expected = [b'englishman\n', b'in\n', b'new\n', b'york\n'] + self.assertEqual(expected, actual) + @moto.mock_s3 class BufferedOutputBaseTest(unittest.TestCase): @@ -168,7 +183,7 @@ class BufferedOutputBaseTest(unittest.TestCase): """ def test_write_01(self): """Does writing into s3 work correctly?""" - mybucket, mykey = create_bucket_and_key() + create_bucket_and_key() test_string = u"žluťoučký koníček".encode('utf8') # write into key @@ -182,7 +197,7 @@ def test_write_01(self): def test_write_01a(self): """Does s3 write fail on incorrect input?""" - mybucket, mykey = create_bucket_and_key() + create_bucket_and_key() try: with smart_open.s3.BufferedOutputBase('mybucket', 'writekey') as fin: @@ -194,7 +209,7 @@ def test_write_01a(self): def test_write_02(self): """Does s3 write unicode-utf8 conversion work?""" - mybucket, mykey = create_bucket_and_key() + create_bucket_and_key() smart_open_write = smart_open.s3.BufferedOutputBase('mybucket', 'writekey') smart_open_write.tell() @@ -205,7 +220,7 @@ def test_write_02(self): def test_write_03(self): """Does s3 multipart chunking work correctly?""" - mybucket, mykey = create_bucket_and_key() + create_bucket_and_key() # write smart_open_write = smart_open.s3.BufferedOutputBase( @@ -245,11 +260,11 @@ def test_gzip(self): expected = u'а не спеть ли мне песню... 
о любви'.encode('utf-8') with smart_open.s3.BufferedOutputBase('mybucket', 'writekey') as fout: - with contextlib.closing(gzip.GzipFile(fileobj=fout, mode='w')) as zipfile: + with gzip.GzipFile(fileobj=fout, mode='w') as zipfile: zipfile.write(expected) - with smart_open.s3.BufferedInputBase('mybucket', 'writekey') as fin: - with contextlib.closing(gzip.GzipFile(fileobj=fin)) as zipfile: + with smart_open.s3.SeekableBufferedInputBase('mybucket', 'writekey') as fin: + with gzip.GzipFile(fileobj=fin) as zipfile: actual = zipfile.read() self.assertEqual(expected, actual) @@ -268,6 +283,12 @@ def test_binary_iterator(self): actual = [line.rstrip() for line in fin] self.assertEqual(expected, actual) + def test_nonexisting_bucket(self): + expected = u"выйду ночью в поле с конём".encode('utf-8') + with self.assertRaises(ValueError): + with smart_open.s3.open('thisbucketdoesntexist', 'mykey', 'wb') as fout: + fout.write(expected) + class ClampTest(unittest.TestCase): def test(self): diff --git a/smart_open/tests/test_smart_open.py b/smart_open/tests/test_smart_open.py index 8f866cb0..c20f3bc6 100644 --- a/smart_open/tests/test_smart_open.py +++ b/smart_open/tests/test_smart_open.py @@ -293,29 +293,29 @@ def test_file(self, mock_smart_open): smart_open_object = smart_open.smart_open(prefix+full_path, read_mode) smart_open_object.__iter__() # called with the correct path? - mock_smart_open.assert_called_with(full_path, read_mode, encoding=None) + mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict') full_path = '/tmp/test#hash##more.txt' read_mode = "rb" smart_open_object = smart_open.smart_open(prefix+full_path, read_mode) smart_open_object.__iter__() # called with the correct path? - mock_smart_open.assert_called_with(full_path, read_mode, encoding=None) + mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict') full_path = 'aa#aa' read_mode = "rb" smart_open_object = smart_open.smart_open(full_path, read_mode) smart_open_object.__iter__() # called with the correct path? - mock_smart_open.assert_called_with(full_path, read_mode, encoding=None) + mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict') short_path = "~/tmp/test.txt" full_path = os.path.expanduser(short_path) - smart_open_object = smart_open.smart_open(prefix+short_path, read_mode) + smart_open_object = smart_open.smart_open(prefix+short_path, read_mode, errors='strict') smart_open_object.__iter__() # called with the correct expanded path? 
- mock_smart_open.assert_called_with(full_path, read_mode, encoding=None) + mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict') # couldn't find any project for mocking up HDFS data # TODO: we want to test also a content of the files, not just fnc call params @@ -420,6 +420,7 @@ def test_s3_read_moto(self): self.assertEqual(content[14:], smart_open_object.read()) # read the rest + @unittest.skip('seek functionality for S3 currently disabled because of Issue #152') @mock_s3 def test_s3_seek_moto(self): """Does seeking in S3 files work correctly?""" @@ -485,15 +486,15 @@ def test_file_mode_mock(self, mock_file, mock_boto): # correct read modes smart_open.smart_open("blah", "r") - mock_file.assert_called_with("blah", "r", encoding=None) + mock_file.assert_called_with("blah", "r", encoding=None, errors='strict') smart_open.smart_open("blah", "rb") - mock_file.assert_called_with("blah", "rb", encoding=None) + mock_file.assert_called_with("blah", "rb", encoding=None, errors='strict') short_path = "~/blah" full_path = os.path.expanduser(short_path) smart_open.smart_open(short_path, "rb") - mock_file.assert_called_with(full_path, "rb", encoding=None) + mock_file.assert_called_with(full_path, "rb", encoding=None, errors='strict') # correct write modes, incorrect scheme self.assertRaises(NotImplementedError, smart_open.smart_open, "hdfs:///blah.txt", "wb+") @@ -502,16 +503,16 @@ def test_file_mode_mock(self, mock_file, mock_boto): # correct write mode, correct file:// URI smart_open.smart_open("blah", "w") - mock_file.assert_called_with("blah", "w", encoding=None) + mock_file.assert_called_with("blah", "w", encoding=None, errors='strict') smart_open.smart_open("file:///some/file.txt", "wb") - mock_file.assert_called_with("/some/file.txt", "wb", encoding=None) + mock_file.assert_called_with("/some/file.txt", "wb", encoding=None, errors='strict') smart_open.smart_open("file:///some/file.txt", "wb+") - mock_file.assert_called_with("/some/file.txt", "wb+", encoding=None) + mock_file.assert_called_with("/some/file.txt", "wb+", encoding=None, errors='strict') smart_open.smart_open("file:///some/file.txt", "w+") - mock_file.assert_called_with("/some/file.txt", "w+", encoding=None) + mock_file.assert_called_with("/some/file.txt", "w+", encoding=None, errors='strict') @mock.patch('boto3.Session') def test_s3_mode_mock(self, mock_session): @@ -595,6 +596,32 @@ def test_s3_modes_moto(self): self.assertEqual(output, [test_string]) + @mock_s3 + def test_write_bad_encoding_strict(self): + """Should abort on encoding error.""" + text = u'欲しい気持ちが成長しすぎて' + + with self.assertRaises(UnicodeEncodeError): + with tempfile.NamedTemporaryFile('wb', delete=True) as infile: + with smart_open.smart_open(infile.name, 'w', encoding='koi8-r', + errors='strict') as fout: + fout.write(text) + + @mock_s3 + def test_write_bad_encoding_replace(self): + """Should replace characters that failed to encode.""" + text = u'欲しい気持ちが成長しすぎて' + expected = u'?' * len(text) + + with tempfile.NamedTemporaryFile('wb', delete=True) as infile: + with smart_open.smart_open(infile.name, 'w', encoding='koi8-r', + errors='replace') as fout: + fout.write(text) + with smart_open.smart_open(infile.name, 'r', encoding='koi8-r') as fin: + actual = fin.read() + + self.assertEqual(expected, actual) + class WebHdfsWriteTest(unittest.TestCase): """ @@ -951,12 +978,15 @@ def test_r(self): text = u"физкульт-привет!" 
        key.set_contents_from_string(text.encode("utf-8"))

-        with smart_open.s3_open_key(key, "r") as fin:
-            self.assertEqual(fin.read(), u"физкульт-привет!")
+        with smart_open.s3_open_key(key, "rb") as fin:
+            self.assertEqual(fin.read(), text.encode('utf-8'))
+
+        with smart_open.s3_open_key(key, "r", encoding='utf-8') as fin:
+            self.assertEqual(fin.read(), text)

         parsed_uri = smart_open.ParseUri("s3://bucket/key")
-        with smart_open.s3_open_uri(parsed_uri, "r") as fin:
-            self.assertEqual(fin.read(), u"физкульт-привет!")
+        with smart_open.s3_open_uri(parsed_uri, "r", encoding='utf-8') as fin:
+            self.assertEqual(fin.read(), text)

     def test_bad_mode(self):
         """Bad mode should raise an exception."""
@@ -1042,7 +1072,7 @@ def test_gzip_read_mode(self):
     def test_read_encoding(self):
         """Should open the file with the correct encoding, explicit text read."""
         conn = boto.connect_s3()
-        conn.create_bucket('test-bucket')
+        conn.create_bucket('bucket')
         key = "s3://bucket/key.txt"
         text = u'это знала ева, это знал адам, колеса любви едут прямо по нам'
         with smart_open.smart_open(key, 'wb') as fout:
@@ -1055,7 +1085,7 @@ def test_read_encoding_implicit_text(self):
         """Should open the file with the correct encoding, implicit text read."""
         conn = boto.connect_s3()
-        conn.create_bucket('test-bucket')
+        conn.create_bucket('bucket')
         key = "s3://bucket/key.txt"
         text = u'это знала ева, это знал адам, колеса любви едут прямо по нам'
         with smart_open.smart_open(key, 'wb') as fout:
@@ -1068,7 +1098,7 @@ def test_write_encoding(self):
         """Should open the file for writing with the correct encoding."""
         conn = boto.connect_s3()
-        conn.create_bucket('test-bucket')
+        conn.create_bucket('bucket')
         key = "s3://bucket/key.txt"
         text = u'какая боль, какая боль, аргентина - ямайка, 5-0'
@@ -1078,6 +1108,61 @@
+    @mock_s3
+    def test_write_bad_encoding_strict(self):
+        """Should raise UnicodeEncodeError when writing characters the encoding cannot represent."""
+        conn = boto.connect_s3()
+        conn.create_bucket('bucket')
+        key = "s3://bucket/key.txt"
+        text = u'欲しい気持ちが成長しすぎて'
+
+        with self.assertRaises(UnicodeEncodeError):
+            with smart_open.smart_open(key, 'w', encoding='koi8-r', errors='strict') as fout:
+                fout.write(text)
+
+    @mock_s3
+    def test_write_bad_encoding_replace(self):
+        """Should replace characters the encoding cannot represent."""
+        conn = boto.connect_s3()
+        conn.create_bucket('bucket')
+        key = "s3://bucket/key.txt"
+        text = u'欲しい気持ちが成長しすぎて'
+        expected = u'?' 
* len(text)
+
+        with smart_open.smart_open(key, 'w', encoding='koi8-r', errors='replace') as fout:
+            fout.write(text)
+        with smart_open.smart_open(key, encoding='koi8-r') as fin:
+            actual = fin.read()
+        self.assertEqual(expected, actual)
+
+    @mock_s3
+    def test_write_text_gzip(self):
+        """Should write gzipped text to S3 with the correct encoding."""
+        conn = boto.connect_s3()
+        conn.create_bucket('bucket')
+        key = "s3://bucket/key.txt.gz"
+        text = u'какая боль, какая боль, аргентина - ямайка, 5-0'
+
+        with smart_open.smart_open(key, 'w', encoding='utf-8') as fout:
+            fout.write(text)
+        with smart_open.smart_open(key, 'r', encoding='utf-8') as fin:
+            actual = fin.read()
+        self.assertEqual(text, actual)
+
+    @mock_s3
+    def test_write_text_gzip_key(self):
+        """Should write gzipped text to a boto S3 key with the correct encoding."""
+        conn = boto.connect_s3()
+        mybucket = conn.create_bucket('bucket')
+        mykey = boto.s3.key.Key(mybucket, 'key.txt.gz')
+        text = u'какая боль, какая боль, аргентина - ямайка, 5-0'
+
+        with smart_open.smart_open(mykey, 'w', encoding='utf-8') as fout:
+            fout.write(text)
+        with smart_open.smart_open(mykey, 'r', encoding='utf-8') as fin:
+            actual = fin.read()
+        self.assertEqual(text, actual)
+

 if __name__ == '__main__':
     logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
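The new `readline` support in `smart_open/s3.py` makes S3 readers directly iterable line by line. Below is a minimal usage sketch, assuming a hypothetical bucket `mybucket` holding a newline-delimited key `mykey`:

```python
import smart_open.s3

# Stream an S3 object line by line.  buffer_size bounds how many bytes each
# underlying read fetches, and line_terminator must be a bytes object.
# 'mybucket' and 'mykey' are hypothetical names used for illustration only.
with smart_open.s3.BufferedInputBase('mybucket', 'mykey',
                                     buffer_size=256 * 1024,
                                     line_terminator=b'\n') as fin:
    for line in fin:   # iteration goes through readline()
        print(line)    # bytes, terminator included
```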
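The `errors` parameter added in `smart_open_lib.py` is forwarded to the `codecs` reader/writer for both local files and S3. A short sketch of the behavior the new tests exercise, using a hypothetical local path:

```python
import smart_open

# With errors='replace', characters that cannot be represented in the target
# encoding are written as '?'; the Japanese text below has no koi8-r
# representation at all.  '/tmp/example.txt' is a hypothetical path.
with smart_open.smart_open('/tmp/example.txt', 'w',
                           encoding='koi8-r', errors='replace') as fout:
    fout.write(u'欲しい気持ちが成長しすぎて')

with smart_open.smart_open('/tmp/example.txt', 'r', encoding='koi8-r') as fin:
    print(fin.read())  # prints '?????????????'

# The default, errors='strict', raises UnicodeEncodeError instead.
```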
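Seeking is now confined to `SeekableBufferedInputBase`; the plain `BufferedInputBase` streams sequentially and reports `seekable() == False`. A sketch of the seekable variant, again with hypothetical bucket and key names:

```python
import smart_open.s3

# The seekable reader supports the usual whence values (START, CURRENT, END),
# at the cost of re-issuing ranged requests after each seek.
with smart_open.s3.SeekableBufferedInputBase('mybucket', 'mykey') as fin:
    head = fin.read(6)                # first 6 bytes
    fin.seek(0)                       # rewind to the beginning
    fin.seek(-4, smart_open.s3.END)   # jump to 4 bytes before end-of-file
    tail = fin.read()
    print(fin.tell(), head, tail)     # tell() now equals the content length
```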