From e4c0da9e8154e02b56d08911e3ab6d07966f0c78 Mon Sep 17 00:00:00 2001
From: shmuelamar
Date: Thu, 1 Mar 2018 21:52:21 +0200
Subject: [PATCH] add support for local file ignore_extension

---
 README.rst                          |   5 ++
 smart_open/smart_open_lib.py        |  76 +++++++--------------
 smart_open/tests/test_smart_open.py | 100 ++++++++++++++++++++++++----
 3 files changed, 118 insertions(+), 63 deletions(-)

diff --git a/README.rst b/README.rst
index 64cb8f28..bd31e570 100644
--- a/README.rst
+++ b/README.rst
@@ -73,6 +73,11 @@ It is well tested (using `moto `_), well documen
    >>> with smart_open.smart_open('/home/radim/foo.txt.bz2', 'wb') as fout:
    ...     fout.write("some content\n")
 
+   >>> # use ignore_extension=True to read the file as-is, skipping the
+   >>> # transparent decompression otherwise implied by the .gz extension:
+   >>> with smart_open.smart_open('/home/radim/rawfile.gz', ignore_extension=True) as fin:
+   ...     print('file md5: {}'.format(hashlib.md5(fin.read()).hexdigest()))
+
 Since going over all (or select) keys in an S3 bucket is a very common operation,
 there's also an extra method ``smart_open.s3_iter_bucket()`` that does this efficiently,
 **processing the bucket keys in parallel** (using multiprocessing):

diff --git a/smart_open/smart_open_lib.py b/smart_open/smart_open_lib.py
index 52f5c97c..736b48b7 100644
--- a/smart_open/smart_open_lib.py
+++ b/smart_open/smart_open_lib.py
@@ -173,7 +173,11 @@ def smart_open(uri, mode="rb", **kw):
         # compression, if any, is determined by the filename extension (.gz, .bz2)
         encoding = kw.pop('encoding', None)
         errors = kw.pop('errors', DEFAULT_ERRORS)
-        return file_smart_open(parsed_uri.uri_path, mode, encoding=encoding, errors=errors)
+        ignore_extension = kw.pop('ignore_extension', False)
+        return file_smart_open(
+            parsed_uri.uri_path, mode, encoding=encoding, errors=errors,
+            ignore_extension=ignore_extension,
+        )
     elif parsed_uri.scheme in ("s3", "s3n", 's3u'):
         return s3_open_uri(parsed_uri, mode, **kw)
     elif parsed_uri.scheme in ("hdfs", ):
@@ -215,7 +219,7 @@ def smart_open(uri, mode="rb", **kw):
     raise TypeError('don\'t know how to handle uri %s' % repr(uri))
 
 
-def s3_open_uri(parsed_uri, mode, **kwargs):
+def s3_open_uri(parsed_uri, mode, ignore_extension=False, **kwargs):
     logger.debug('%r', locals())
     if parsed_uri.access_id is not None:
         kwargs['aws_access_key_id'] = parsed_uri.access_id
@@ -227,16 +231,6 @@ def s3_open_uri(parsed_uri, mode, **kwargs):
     if host is not None:
         kwargs['endpoint_url'] = 'http://' + host
 
-    #
-    # TODO: this is the wrong place to handle ignore_extension.
-    # It should happen at the highest level in the smart_open function, because
-    # it influences other file systems as well, not just S3.
-    #
-    if kwargs.pop("ignore_extension", False):
-        codec = None
-    else:
-        codec = _detect_codec(parsed_uri.key_id)
-
     #
     # Codecs work on a byte-level, so the underlying S3 object should
     # always be reading bytes.
@@ -258,7 +252,7 @@ def s3_open_uri(parsed_uri, mode, **kwargs):
     encoding = kwargs.get('encoding')
     errors = kwargs.get('errors', DEFAULT_ERRORS)
     fobj = smart_open_s3.open(parsed_uri.bucket_id, parsed_uri.key_id, s3_mode, **kwargs)
-    decompressed_fobj = _CODECS[codec](fobj, mode)
+    decompressed_fobj = compression_wrapper(fobj, parsed_uri.key_id, mode, ignore_extension)
     decoded_fobj = encoding_wrapper(decompressed_fobj, mode, encoding=encoding, errors=errors)
     return decoded_fobj
 
@@ -274,7 +268,7 @@ def _setup_unsecured_mode(parsed_uri, kwargs):
         kwargs['calling_format'] = boto.s3.connection.OrdinaryCallingFormat()
 
 
-def s3_open_key(key, mode, **kwargs):
+def s3_open_key(key, mode, ignore_extension=False, **kwargs):
     logger.debug('%r', locals())
     #
     # TODO: handle boto3 keys as well
@@ -283,11 +277,6 @@ def s3_open_key(key, mode, **kwargs):
     if host is not None:
         kwargs['endpoint_url'] = 'http://' + host
 
-    if kwargs.pop("ignore_extension", False):
-        codec = None
-    else:
-        codec = _detect_codec(key.name)
-
     #
     # Codecs work on a byte-level, so the underlying S3 object should
     # always be reading bytes.
@@ -299,38 +288,15 @@ def s3_open_key(key, mode, **kwargs):
     else:
         raise NotImplementedError('mode %r not implemented for S3' % mode)
 
-    logging.debug('codec: %r mode: %r s3_mode: %r', codec, mode, s3_mode)
+    logger.debug('mode: %r s3_mode: %r', mode, s3_mode)
     encoding = kwargs.get('encoding')
     errors = kwargs.get('errors', DEFAULT_ERRORS)
     fobj = smart_open_s3.open(key.bucket.name, key.name, s3_mode, **kwargs)
-    decompressed_fobj = _CODECS[codec](fobj, mode)
+    decompressed_fobj = compression_wrapper(fobj, key.name, mode, ignore_extension)
     decoded_fobj = encoding_wrapper(decompressed_fobj, mode, encoding=encoding, errors=errors)
     return decoded_fobj
 
 
-def _detect_codec(filename):
-    if filename.endswith(".gz"):
-        return 'gzip'
-    return None
-
-
-def _wrap_gzip(fileobj, mode):
-    return gzip.GzipFile(fileobj=fileobj, mode=mode)
-
-
-def _wrap_none(fileobj, mode):
-    return fileobj
-
-
-_CODECS = {
-    None: _wrap_none,
-    'gzip': _wrap_gzip,
-    #
-    # TODO: add support for other codecs here.
-    #
-}
-
-
 class ParseUri(object):
     """
     Parse the given URI.
@@ -586,17 +552,23 @@ def close(self):
             fileobj.close()
 
 
-def compression_wrapper(file_obj, filename, mode):
+def compression_wrapper(file_obj, filename, mode, ignore_extension=False):
     """
     This function will wrap the file_obj with an appropriate
     [de]compression mechanism based on the extension of the filename.
 
+    If `ignore_extension` is True, the original file_obj is returned as-is,
+    with no [de]compression mechanism added.
+
     file_obj must either be a filehandle object, or a class which behaves
-    like one.
+    like one.
 
     If the filename extension isn't recognized, will simply return the original
     file_obj.
     """
+    if ignore_extension:
+        return file_obj
+
     _, ext = os.path.splitext(filename)
     if ext == '.bz2':
         return ClosingBZ2File(file_obj, mode)
@@ -642,7 +614,7 @@ def encoding_wrapper(fileobj, mode, encoding=None, errors=DEFAULT_ERRORS):
         return decoder(fileobj, errors=errors)
 
 
-def file_smart_open(fname, mode='rb', encoding=None, errors=DEFAULT_ERRORS):
+def file_smart_open(fname, mode='rb', encoding=None, errors=DEFAULT_ERRORS, ignore_extension=False):
     """
     Stream from/to local filesystem, transparently (de)compressing gzip and bz2
     files if necessary.
@@ -651,6 +623,8 @@ def file_smart_open(fname, mode='rb', encoding=None, errors=DEFAULT_ERRORS):
     :arg str mode: The mode in which to open the file.
     :arg str encoding: The text encoding to use.
     :arg str errors: The method to use when handling encoding/decoding errors.
+    :arg bool ignore_extension: If True, skip auto-detection of compression
+        based on the file extension.
     :returns: A file object
     """
     #
@@ -669,7 +643,7 @@ def file_smart_open(fname, mode='rb', encoding=None, errors=DEFAULT_ERRORS):
     except KeyError:
         raw_mode = mode
     raw_fobj = open(fname, raw_mode)
-    decompressed_fobj = compression_wrapper(raw_fobj, fname, raw_mode)
+    decompressed_fobj = compression_wrapper(raw_fobj, fname, raw_mode, ignore_extension)
     decoded_fobj = encoding_wrapper(decompressed_fobj, mode, encoding=encoding, errors=errors)
     return decoded_fobj
 
@@ -774,7 +748,7 @@ def __exit__(self, *args, **kwargs):
         self.response.close()
 
 
-def HttpOpenRead(parsed_uri, mode='r', **kwargs):
+def HttpOpenRead(parsed_uri, mode='r', ignore_extension=False, **kwargs):
     if parsed_uri.scheme not in ('http', 'https'):
         raise TypeError("can only process http/https urls")
     if mode not in ('r', 'rb'):
@@ -789,9 +763,9 @@ def HttpOpenRead(parsed_uri, mode='r', **kwargs):
     if fname.endswith('.gz'):
         # Gzip needs a seek-able filehandle, so we need to buffer it.
         buffer = make_closing(io.BytesIO)(response.binary_content())
-        return compression_wrapper(buffer, fname, mode)
+        return compression_wrapper(buffer, fname, mode, ignore_extension)
     else:
-        return compression_wrapper(response, fname, mode)
+        return compression_wrapper(response, fname, mode, ignore_extension)
 
 
 class WebHdfsOpenWrite(object):

diff --git a/smart_open/tests/test_smart_open.py b/smart_open/tests/test_smart_open.py
index e37a6546..19ed0e53 100644
--- a/smart_open/tests/test_smart_open.py
+++ b/smart_open/tests/test_smart_open.py
@@ -143,6 +143,21 @@ def test_http_gz(self):
         # decompress the gzip and get the same md5 hash
         self.assertEqual(m.hexdigest(), '18473e60f8c7c98d29d65bf805736a0d')
 
+    @responses.activate
+    def test_http_gz_ignore_extension(self):
+        """Can open gzip as raw via http?"""
+        fpath = os.path.join(CURR_DIR, 'test_data/crlf_at_1k_boundary.warc.gz')
+        with open(fpath, 'rb') as infile:
+            data = infile.read()
+
+        responses.add(responses.GET, "http://127.0.0.1/data.gz", body=data)
+        smart_open_object = smart_open.HttpOpenRead(
+            smart_open.ParseUri("http://127.0.0.1/data.gz?some_param=some_val"),
+            ignore_extension=True,
+        )
+
+        self.assertEqual(smart_open_object.read(), data)
+
     @responses.activate
     def test_http_bz2(self):
         """Can open bz2 via http?"""
@@ -167,6 +182,18 @@ def test_http_bz2(self):
         # decompress the gzip and get the same md5 hash
         self.assertEqual(smart_open_object.read(), test_string)
 
+    @responses.activate
+    def test_http_bz2_ignore_extension(self):
+        """Can open bz2 as raw via http?"""
+        body = b'not really bz2 but me hiding behind extension'
+        responses.add(responses.GET, "http://127.0.0.1/data.bz2", body=body)
+        smart_open_object = smart_open.HttpOpenRead(
+            smart_open.ParseUri("http://127.0.0.1/data.bz2?some_param=some_val"),
+            ignore_extension=True,
+        )
+
+        self.assertEqual(smart_open_object.read(), body)
+
 
 class SmartOpenReadTest(unittest.TestCase):
     """
@@ -292,29 +319,40 @@ def test_file(self, mock_smart_open):
         smart_open_object = smart_open.smart_open(prefix+full_path, read_mode)
         smart_open_object.__iter__()
         # called with the correct path?
-        mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict')
+        mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict', ignore_extension=False)
 
         full_path = '/tmp/test#hash##more.txt'
         read_mode = "rb"
         smart_open_object = smart_open.smart_open(prefix+full_path, read_mode)
         smart_open_object.__iter__()
         # called with the correct path?
-        mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict')
+        mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict', ignore_extension=False)
 
         full_path = 'aa#aa'
         read_mode = "rb"
         smart_open_object = smart_open.smart_open(full_path, read_mode)
         smart_open_object.__iter__()
         # called with the correct path?
-        mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict')
+        mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict', ignore_extension=False)
 
         short_path = "~/tmp/test.txt"
         full_path = os.path.expanduser(short_path)
-        smart_open_object = smart_open.smart_open(prefix+short_path, read_mode, errors='strict')
+        smart_open_object = smart_open.smart_open(prefix+short_path, read_mode, errors='strict', ignore_extension=False)
         smart_open_object.__iter__()
         # called with the correct expanded path?
-        mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict')
+        mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict', ignore_extension=False)
+
+    @mock.patch('smart_open.smart_open_lib.file_smart_open')
+    def test_file_ignore_extension(self, mock_smart_open):
+        prefix = 'file://'
+        full_path = '/tmp/test.gz'
+        read_mode = 'rb'
+        smart_open_object = smart_open.smart_open(prefix + full_path, read_mode, ignore_extension=True)
+        smart_open_object.__iter__()
+
+        # called with the correct path and ignore_extension?
+        mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict', ignore_extension=True)
 
     # couldn't find any project for mocking up HDFS data
     # TODO: we want to test also a content of the files, not just fnc call params
 
@@ -485,15 +523,15 @@ def test_file_mode_mock(self, mock_file, mock_boto):
 
         # correct read modes
         smart_open.smart_open("blah", "r")
-        mock_file.assert_called_with("blah", "r", encoding=None, errors='strict')
+        mock_file.assert_called_with("blah", "r", encoding=None, errors='strict', ignore_extension=False)
 
         smart_open.smart_open("blah", "rb")
-        mock_file.assert_called_with("blah", "rb", encoding=None, errors='strict')
+        mock_file.assert_called_with("blah", "rb", encoding=None, errors='strict', ignore_extension=False)
 
         short_path = "~/blah"
         full_path = os.path.expanduser(short_path)
         smart_open.smart_open(short_path, "rb")
-        mock_file.assert_called_with(full_path, "rb", encoding=None, errors='strict')
+        mock_file.assert_called_with(full_path, "rb", encoding=None, errors='strict', ignore_extension=False)
 
         # correct write modes, incorrect scheme
         self.assertRaises(NotImplementedError, smart_open.smart_open, "hdfs:///blah.txt", "wb+")
@@ -502,16 +540,16 @@ def test_file_mode_mock(self, mock_file, mock_boto):
 
         # correct write mode, correct file:// URI
         smart_open.smart_open("blah", "w")
-        mock_file.assert_called_with("blah", "w", encoding=None, errors='strict')
+        mock_file.assert_called_with("blah", "w", encoding=None, errors='strict', ignore_extension=False)
 
         smart_open.smart_open("file:///some/file.txt", "wb")
-        mock_file.assert_called_with("/some/file.txt", "wb", encoding=None, errors='strict')
+        mock_file.assert_called_with("/some/file.txt", "wb", encoding=None, errors='strict', ignore_extension=False)
 
         smart_open.smart_open("file:///some/file.txt", "wb+")
-        mock_file.assert_called_with("/some/file.txt", "wb+", encoding=None, errors='strict')
+        mock_file.assert_called_with("/some/file.txt", "wb+", encoding=None, errors='strict', ignore_extension=False)
 
         smart_open.smart_open("file:///some/file.txt", "w+")
-        mock_file.assert_called_with("/some/file.txt", "w+", encoding=None, errors='strict')
+        mock_file.assert_called_with("/some/file.txt", "w+", encoding=None, errors='strict', ignore_extension=False)
 
     @mock.patch('boto3.Session')
     def test_s3_mode_mock(self, mock_session):
@@ -866,12 +904,50 @@ def test_write_read_gz(self):
             test_file_name = infile.name
         self.write_read_assertion(test_file_name)
 
+    def test_open_gz_read_ignore_extension(self):
+        fpath = os.path.join(CURR_DIR, 'test_data/crlf_at_1k_boundary.warc.gz')
+
+        with open(fpath, 'rb') as fp:
+            expected_data = fp.read()
+
+        with smart_open.smart_open(fpath, 'rb', ignore_extension=True) as infile:
+            data = infile.read()
+        assert data == expected_data
+
+        with smart_open.smart_open(fpath, 'rb', ignore_extension=True) as infile:
+            data = gzip.GzipFile(fileobj=infile).read()
+        m = hashlib.md5(data)
+        assert m.hexdigest() == '18473e60f8c7c98d29d65bf805736a0d', \
+            'Failed to read gzip file as raw'
+
+    def test_open_gz_rw_ignore_extension(self):
+        msg = b'wello horld!'
+
+        with tempfile.NamedTemporaryFile('wb', suffix='.gz', delete=False) as infile:
+            test_file_name = infile.name
+
+        with smart_open.smart_open(test_file_name, 'wb', ignore_extension=True) as fp:
+            fp.write(msg)
+
+        with smart_open.smart_open(test_file_name, 'rb', ignore_extension=True) as fp:
+            assert fp.read() == msg
+
     def test_write_read_bz2(self):
         """Can write and read bz2?"""
         with tempfile.NamedTemporaryFile('wb', suffix='.bz2', delete=False) as infile:
             test_file_name = infile.name
         self.write_read_assertion(test_file_name)
 
+    def test_open_bz2_ignore_extension(self):
+        msg = b'not really bz2 but me hiding behind extension'
+        with tempfile.NamedTemporaryFile('wb', suffix='.bz2', delete=False) as infile:
+            infile.write(msg)
+            fpath = infile.name
+
+        with smart_open.smart_open(fpath, ignore_extension=True) as infile:
+            data = infile.read()
+        assert data == msg
+
 
 class MultistreamsBZ2Test(unittest.TestCase):
     """
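
A minimal usage sketch of the new flag (illustrative only, not part of the diff
above; '/tmp/raw_payload.gz' is a hypothetical path, and this assumes a
smart_open build with this patch applied):

    import smart_open

    payload = b'plain bytes, not gzip data'

    # Write raw bytes to a ".gz"-named file: ignore_extension=True skips the
    # gzip wrapper that compression_wrapper would otherwise add for the extension.
    with smart_open.smart_open('/tmp/raw_payload.gz', 'wb', ignore_extension=True) as fout:
        fout.write(payload)

    # Read the bytes back as-is; without the flag, smart_open would try to gunzip them.
    with smart_open.smart_open('/tmp/raw_payload.gz', 'rb', ignore_extension=True) as fin:
        assert fin.read() == payload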