add support for local file ignore_extension
shmuelamar committed Apr 1, 2018
1 parent ecb24c2 commit 080fcee
Showing 3 changed files with 118 additions and 63 deletions.
5 changes: 5 additions & 0 deletions README.rst
@@ -73,6 +73,11 @@ It is well tested (using `moto <https://github.com/spulec/moto>`_), well documented
>>> with smart_open.smart_open('/home/radim/foo.txt.bz2', 'wb') as fout:
... fout.write("some content\n")
>>> # the ignore_extension flag overrides the transparent compression implied by the file extension:
>>> with smart_open.smart_open('/home/radim/rawfile.gz', ignore_extension=True) as fin:
...     print 'file md5: {}'.format(hashlib.md5(fin.read()).hexdigest())
Since going over all (or select) keys in an S3 bucket is a very common operation,
there's also an extra method ``smart_open.s3_iter_bucket()`` that does this efficiently,
**processing the bucket keys in parallel** (using multiprocessing):
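The hunk is truncated at the colon above; in the full README that sentence is followed by a usage example along these lines — a hedged reconstruction based on the boto-era API (the bucket name, prefix, and key filter are illustrative):

>>> # get all JSON files under "mybucket/foo/"
>>> bucket = boto.connect_s3().get_bucket('mybucket')
>>> for key, content in smart_open.s3_iter_bucket(bucket, prefix='foo/', accept_key=lambda key: key.endswith('.json')):
...     print key, len(content)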
76 changes: 25 additions & 51 deletions smart_open/smart_open_lib.py
@@ -173,7 +173,11 @@ def smart_open(uri, mode="rb", **kw):
# compression, if any, is determined by the filename extension (.gz, .bz2)
encoding = kw.pop('encoding', None)
errors = kw.pop('errors', DEFAULT_ERRORS)
return file_smart_open(parsed_uri.uri_path, mode, encoding=encoding, errors=errors)
ignore_extension = kw.pop('ignore_extension', False)
return file_smart_open(
parsed_uri.uri_path, mode, encoding=encoding, errors=errors,
ignore_extension=ignore_extension,
)
elif parsed_uri.scheme in ("s3", "s3n", 's3u'):
return s3_open_uri(parsed_uri, mode, **kw)
elif parsed_uri.scheme in ("hdfs", ):
@@ -215,7 +219,7 @@ def smart_open(uri, mode="rb", **kw):
raise TypeError('don\'t know how to handle uri %s' % repr(uri))


def s3_open_uri(parsed_uri, mode, **kwargs):
def s3_open_uri(parsed_uri, mode, ignore_extension=False, **kwargs):
logger.debug('%r', locals())
if parsed_uri.access_id is not None:
kwargs['aws_access_key_id'] = parsed_uri.access_id
@@ -227,16 +231,6 @@ def s3_open_uri(parsed_uri, mode, **kwargs):
if host is not None:
kwargs['endpoint_url'] = 'http://' + host

#
# TODO: this is the wrong place to handle ignore_extension.
# It should happen at the highest level in the smart_open function, because
# it influences other file systems as well, not just S3.
#
if kwargs.pop("ignore_extension", False):
codec = None
else:
codec = _detect_codec(parsed_uri.key_id)

#
# Codecs work on a byte-level, so the underlying S3 object should
# always be reading bytes.
@@ -258,7 +252,7 @@ def _setup_unsecured_mode(parsed_uri, kwargs):
encoding = kwargs.get('encoding')
errors = kwargs.get('errors', DEFAULT_ERRORS)
fobj = smart_open_s3.open(parsed_uri.bucket_id, parsed_uri.key_id, s3_mode, **kwargs)
decompressed_fobj = _CODECS[codec](fobj, mode)
decompressed_fobj = compression_wrapper(fobj, parsed_uri.key_id, mode, ignore_extension)
decoded_fobj = encoding_wrapper(decompressed_fobj, mode, encoding=encoding, errors=errors)
return decoded_fobj

@@ -274,7 +268,7 @@ def _setup_unsecured_mode(parsed_uri, kwargs):
kwargs['calling_format'] = boto.s3.connection.OrdinaryCallingFormat()


def s3_open_key(key, mode, **kwargs):
def s3_open_key(key, mode, ignore_extension=False, **kwargs):
logger.debug('%r', locals())
#
# TODO: handle boto3 keys as well
@@ -283,11 +277,6 @@ def s3_open_key(key, mode, **kwargs):
if host is not None:
kwargs['endpoint_url'] = 'http://' + host

if kwargs.pop("ignore_extension", False):
codec = None
else:
codec = _detect_codec(key.name)

#
# Codecs work on a byte-level, so the underlying S3 object should
# always be reading bytes.
@@ -299,38 +288,15 @@ def s3_open_key(key, mode, **kwargs):
else:
raise NotImplementedError('mode %r not implemented for S3' % mode)

logging.debug('codec: %r mode: %r s3_mode: %r', codec, mode, s3_mode)
logging.debug('mode: %r s3_mode: %r', mode, s3_mode)
encoding = kwargs.get('encoding')
errors = kwargs.get('errors', DEFAULT_ERRORS)
fobj = smart_open_s3.open(key.bucket.name, key.name, s3_mode, **kwargs)
decompressed_fobj = _CODECS[codec](fobj, mode)
decompressed_fobj = compression_wrapper(fobj, key.name, mode, ignore_extension)
decoded_fobj = encoding_wrapper(decompressed_fobj, mode, encoding=encoding, errors=errors)
return decoded_fobj


def _detect_codec(filename):
if filename.endswith(".gz"):
return 'gzip'
return None


def _wrap_gzip(fileobj, mode):
return gzip.GzipFile(fileobj=fileobj, mode=mode)


def _wrap_none(fileobj, mode):
return fileobj


_CODECS = {
None: _wrap_none,
'gzip': _wrap_gzip,
#
# TODO: add support for other codecs here.
#
}


class ParseUri(object):
"""
Parse the given URI.
@@ -586,17 +552,23 @@ def close(self):
fileobj.close()


def compression_wrapper(file_obj, filename, mode):
def compression_wrapper(file_obj, filename, mode, ignore_extension=False):
"""
This function will wrap the file_obj with an appropriate
[de]compression mechanism based on the extension of the filename.
If `ignore_extension` is True, the original file_obj is always returned
without any [de]compression mechanism added.

file_obj must either be a filehandle object, or a class which behaves
like one.

If the filename extension isn't recognized, will simply return the original
file_obj.
"""
if ignore_extension:
return file_obj

_, ext = os.path.splitext(filename)
if ext == '.bz2':
return ClosingBZ2File(file_obj, mode)
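To make the new control flow easy to scan, here is a minimal, self-contained model of the wrapper's dispatch after this change — an editor's sketch, not the library code (the real function returns ClosingBZ2File and a closing GzipFile wrapper):

import gzip
import os

def compression_wrapper_sketch(file_obj, filename, mode, ignore_extension=False):
    """Simplified stand-in for smart_open's compression_wrapper."""
    # New short-circuit: with ignore_extension=True the object passes
    # through untouched, whatever the filename ends with.
    if ignore_extension:
        return file_obj
    _, ext = os.path.splitext(filename)
    if ext == '.gz':
        # transparent gzip [de]compression over the underlying stream
        return gzip.GzipFile(fileobj=file_obj, mode=mode)
    # '.bz2' is handled analogously; unrecognized extensions pass through
    return file_obj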
@@ -642,7 +614,7 @@ def encoding_wrapper(fileobj, mode, encoding=None, errors=DEFAULT_ERRORS):
return decoder(fileobj, errors=errors)


def file_smart_open(fname, mode='rb', encoding=None, errors=DEFAULT_ERRORS):
def file_smart_open(fname, mode='rb', encoding=None, errors=DEFAULT_ERRORS, ignore_extension=False):
"""
Stream from/to local filesystem, transparently (de)compressing gzip and bz2
files if necessary.
@@ -651,6 +623,8 @@ def file_smart_open(fname, mode='rb', encoding=None, errors=DEFAULT_ERRORS):
:arg str mode: The mode in which to open the file.
:arg str encoding: The text encoding to use.
:arg str errors: The method to use when handling encoding/decoding errors.
:arg bool ignore_extension: If True, skip compression auto-detection based
    on the file extension.
:returns: A file object
"""
#
@@ -669,7 +643,7 @@ def file_smart_open(fname, mode='rb', encoding=None, errors=DEFAULT_ERRORS):
except KeyError:
raw_mode = mode
raw_fobj = open(fname, raw_mode)
decompressed_fobj = compression_wrapper(raw_fobj, fname, raw_mode)
decompressed_fobj = compression_wrapper(raw_fobj, fname, raw_mode, ignore_extension)
decoded_fobj = encoding_wrapper(decompressed_fobj, mode, encoding=encoding, errors=errors)
return decoded_fobj

@@ -774,7 +748,7 @@ def __exit__(self, *args, **kwargs):
self.response.close()


def HttpOpenRead(parsed_uri, mode='r', **kwargs):
def HttpOpenRead(parsed_uri, mode='r', ignore_extension=False, **kwargs):
if parsed_uri.scheme not in ('http', 'https'):
raise TypeError("can only process http/https urls")
if mode not in ('r', 'rb'):
@@ -789,9 +763,9 @@ def HttpOpenRead(parsed_uri, mode='r', **kwargs):
if fname.endswith('.gz'):
# Gzip needs a seek-able filehandle, so we need to buffer it.
buffer = make_closing(io.BytesIO)(response.binary_content())
return compression_wrapper(buffer, fname, mode)
return compression_wrapper(buffer, fname, mode, ignore_extension)
else:
return compression_wrapper(response, fname, mode)
return compression_wrapper(response, fname, mode, ignore_extension)


class WebHdfsOpenWrite(object):
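Taken together, the library changes thread ignore_extension uniformly through the local-file, S3, and HTTP paths. A brief usage sketch — the path, bucket, and URL below are illustrative placeholders, not part of the commit:

import smart_open

# local file: skip the gzip codec despite the .gz suffix
with smart_open.smart_open('/tmp/rawfile.gz', 'rb', ignore_extension=True) as fin:
    raw = fin.read()

# S3 object: the flag is forwarded to compression_wrapper via s3_open_uri
with smart_open.smart_open('s3://mybucket/logs/raw.gz', 'rb', ignore_extension=True) as fin:
    raw = fin.read()

# HTTP resource: HttpOpenRead accepts the flag too, mirroring the tests below
fin = smart_open.HttpOpenRead(
    smart_open.ParseUri('http://example.com/data.gz'),
    ignore_extension=True,
)
raw = fin.read()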
100 changes: 88 additions & 12 deletions smart_open/tests/test_smart_open.py
@@ -143,6 +143,21 @@ def test_http_gz(self):
# decompress the gzip and get the same md5 hash
self.assertEqual(m.hexdigest(), '18473e60f8c7c98d29d65bf805736a0d')

@responses.activate
def test_http_gz_ignore_extension(self):
"""Can open gzip as raw via http?"""
fpath = os.path.join(CURR_DIR, 'test_data/crlf_at_1k_boundary.warc.gz')
with open(fpath, 'rb') as infile:
data = infile.read()

responses.add(responses.GET, "http://127.0.0.1/data.gz", body=data)
smart_open_object = smart_open.HttpOpenRead(
smart_open.ParseUri("http://127.0.0.1/data.gz?some_param=some_val"),
ignore_extension=True,
)

self.assertEqual(smart_open_object.read(), data)

@responses.activate
def test_http_bz2(self):
"""Can open bz2 via http?"""
@@ -167,6 +182,18 @@ def test_http_bz2(self):
# decompress the bz2 and compare the result with the original string
self.assertEqual(smart_open_object.read(), test_string)

@responses.activate
def test_http_bz2_ignore_extension(self):
"""Can open bz2 as raw via http?"""
body = b'not really bz2 but me hiding behind extension'
responses.add(responses.GET, "http://127.0.0.1/data.bz2", body=body)
smart_open_object = smart_open.HttpOpenRead(
smart_open.ParseUri("http://127.0.0.1/data.bz2?some_param=some_val"),
ignore_extension=True,
)

self.assertEqual(smart_open_object.read(), body)


class SmartOpenReadTest(unittest.TestCase):
"""
@@ -292,29 +319,40 @@ def test_file(self, mock_smart_open):
smart_open_object = smart_open.smart_open(prefix+full_path, read_mode)
smart_open_object.__iter__()
# called with the correct path?
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict')
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict', ignore_extension=False)

full_path = '/tmp/test#hash##more.txt'
read_mode = "rb"
smart_open_object = smart_open.smart_open(prefix+full_path, read_mode)
smart_open_object.__iter__()
# called with the correct path?
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict')
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict', ignore_extension=False)

full_path = 'aa#aa'
read_mode = "rb"
smart_open_object = smart_open.smart_open(full_path, read_mode)
smart_open_object.__iter__()
# called with the correct path?
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict')
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict', ignore_extension=False)

short_path = "~/tmp/test.txt"
full_path = os.path.expanduser(short_path)

smart_open_object = smart_open.smart_open(prefix+short_path, read_mode, errors='strict')
smart_open_object = smart_open.smart_open(prefix+short_path, read_mode, errors='strict', ignore_extension=False)
smart_open_object.__iter__()
# called with the correct expanded path?
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict')
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict', ignore_extension=False)

@mock.patch('smart_open.smart_open_lib.file_smart_open')
def test_file_ignore_extension(self, mock_smart_open):
prefix = 'file://'
full_path = '/tmp/test.gz'
read_mode = 'rb'
smart_open_object = smart_open.smart_open(prefix + full_path, read_mode, ignore_extension=True)
smart_open_object.__iter__()

# called with the correct path and ignore_extension?
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict', ignore_extension=True)

# couldn't find any project for mocking up HDFS data
# TODO: we also want to test the content of the files, not just the function call params
@@ -485,15 +523,15 @@ def test_file_mode_mock(self, mock_file, mock_boto):

# correct read modes
smart_open.smart_open("blah", "r")
mock_file.assert_called_with("blah", "r", encoding=None, errors='strict')
mock_file.assert_called_with("blah", "r", encoding=None, errors='strict', ignore_extension=False)

smart_open.smart_open("blah", "rb")
mock_file.assert_called_with("blah", "rb", encoding=None, errors='strict')
mock_file.assert_called_with("blah", "rb", encoding=None, errors='strict', ignore_extension=False)

short_path = "~/blah"
full_path = os.path.expanduser(short_path)
smart_open.smart_open(short_path, "rb")
mock_file.assert_called_with(full_path, "rb", encoding=None, errors='strict')
mock_file.assert_called_with(full_path, "rb", encoding=None, errors='strict', ignore_extension=False)

# correct write modes, incorrect scheme
self.assertRaises(NotImplementedError, smart_open.smart_open, "hdfs:///blah.txt", "wb+")
@@ -502,16 +540,16 @@ def test_file_mode_mock(self, mock_file, mock_boto):

# correct write mode, correct file:// URI
smart_open.smart_open("blah", "w")
mock_file.assert_called_with("blah", "w", encoding=None, errors='strict')
mock_file.assert_called_with("blah", "w", encoding=None, errors='strict', ignore_extension=False)

smart_open.smart_open("file:///some/file.txt", "wb")
mock_file.assert_called_with("/some/file.txt", "wb", encoding=None, errors='strict')
mock_file.assert_called_with("/some/file.txt", "wb", encoding=None, errors='strict', ignore_extension=False)

smart_open.smart_open("file:///some/file.txt", "wb+")
mock_file.assert_called_with("/some/file.txt", "wb+", encoding=None, errors='strict')
mock_file.assert_called_with("/some/file.txt", "wb+", encoding=None, errors='strict', ignore_extension=False)

smart_open.smart_open("file:///some/file.txt", "w+")
mock_file.assert_called_with("/some/file.txt", "w+", encoding=None, errors='strict')
mock_file.assert_called_with("/some/file.txt", "w+", encoding=None, errors='strict', ignore_extension=False)

@mock.patch('boto3.Session')
def test_s3_mode_mock(self, mock_session):
@@ -866,12 +904,50 @@ def test_write_read_gz(self):
test_file_name = infile.name
self.write_read_assertion(test_file_name)

def test_open_gz_read_ignore_extension(self):
fpath = os.path.join(CURR_DIR, 'test_data/crlf_at_1k_boundary.warc.gz')

with open(fpath, 'rb') as fp:
expected_data = fp.read()

with smart_open.smart_open(fpath, ignore_extension=True) as infile:
data = infile.read()
assert data == expected_data

with smart_open.smart_open(fpath, ignore_extension=True) as infile:
data = gzip.GzipFile(fileobj=infile).read()
m = hashlib.md5(data)
assert m.hexdigest() == '18473e60f8c7c98d29d65bf805736a0d', \
'Failed to read gzip file as raw'

def test_open_gz_rw_ignore_extension(self):
msg = b'wello horld!'

with tempfile.NamedTemporaryFile('wb', suffix='.gz', delete=False) as infile:
test_file_name = infile.name

with smart_open.smart_open(test_file_name, 'wb', ignore_extension=True) as fp:
fp.write(msg)

with smart_open.smart_open(test_file_name, 'rb', ignore_extension=True) as fp:
assert fp.read() == msg

def test_write_read_bz2(self):
"""Can write and read bz2?"""
with tempfile.NamedTemporaryFile('wb', suffix='.bz2', delete=False) as infile:
test_file_name = infile.name
self.write_read_assertion(test_file_name)

def test_open_bz2_ignore_extension(self):
msg = b'not really bz2 but me hiding behind extension'
with tempfile.NamedTemporaryFile('wb', suffix='.bz2', delete=False) as infile:
infile.write(msg)
fpath = infile.name

with smart_open.smart_open(fpath, ignore_extension=True) as infile:
data = infile.read()
assert data == msg


class MultistreamsBZ2Test(unittest.TestCase):
"""
