Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for local file ignore_extension #173

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ It is well tested (using `moto <https://github.com/spulec/moto>`_), well documen
>>> with smart_open.smart_open('/home/radim/foo.txt.bz2', 'wb') as fout:
... fout.write("some content\n")

>>> # ignore_extension flag can be used for overriding compression extension:
>>> with smart_open.smart_open('/home/radim/rawfile.gz', ignore_extension=True) as fin:
...     print 'file md5: {}'.format(hashlib.md5(fin.read()).hexdigest())


Since going over all (or select) keys in an S3 bucket is a very common operation,
there's also an extra method ``smart_open.s3_iter_bucket()`` that does this efficiently,
**processing the bucket keys in parallel** (using multiprocessing):
Expand Down
76 changes: 25 additions & 51 deletions smart_open/smart_open_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,11 @@ def smart_open(uri, mode="rb", **kw):
# compression, if any, is determined by the filename extension (.gz, .bz2)
encoding = kw.pop('encoding', None)
errors = kw.pop('errors', DEFAULT_ERRORS)
return file_smart_open(parsed_uri.uri_path, mode, encoding=encoding, errors=errors)
ignore_extension = kw.pop('ignore_extension', False)
return file_smart_open(
parsed_uri.uri_path, mode, encoding=encoding, errors=errors,
ignore_extension=ignore_extension,
)
elif parsed_uri.scheme in ("s3", "s3n", 's3u'):
return s3_open_uri(parsed_uri, mode, **kw)
elif parsed_uri.scheme in ("hdfs", ):
Expand Down Expand Up @@ -215,7 +219,7 @@ def smart_open(uri, mode="rb", **kw):
raise TypeError('don\'t know how to handle uri %s' % repr(uri))


def s3_open_uri(parsed_uri, mode, **kwargs):
def s3_open_uri(parsed_uri, mode, ignore_extension=False, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should ignore_extension be an explicit parameter (or this should be part of kwargs), same question for s3_open_key, @mpenkov thought?

logger.debug('%r', locals())
if parsed_uri.access_id is not None:
kwargs['aws_access_key_id'] = parsed_uri.access_id
Expand All @@ -227,16 +231,6 @@ def s3_open_uri(parsed_uri, mode, **kwargs):
if host is not None:
kwargs['endpoint_url'] = 'http://' + host

#
# TODO: this is the wrong place to handle ignore_extension.
# It should happen at the highest level in the smart_open function, because
# it influences other file systems as well, not just S3.
#
if kwargs.pop("ignore_extension", False):
codec = None
else:
codec = _detect_codec(parsed_uri.key_id)

#
# Codecs work on a byte-level, so the underlying S3 object should
# always be reading bytes.
Expand All @@ -258,7 +252,7 @@ def s3_open_uri(parsed_uri, mode, **kwargs):
encoding = kwargs.get('encoding')
errors = kwargs.get('errors', DEFAULT_ERRORS)
fobj = smart_open_s3.open(parsed_uri.bucket_id, parsed_uri.key_id, s3_mode, **kwargs)
decompressed_fobj = _CODECS[codec](fobj, mode)
decompressed_fobj = compression_wrapper(fobj, parsed_uri.key_id, mode, ignore_extension)
decoded_fobj = encoding_wrapper(decompressed_fobj, mode, encoding=encoding, errors=errors)
return decoded_fobj

Expand All @@ -274,7 +268,7 @@ def _setup_unsecured_mode(parsed_uri, kwargs):
kwargs['calling_format'] = boto.s3.connection.OrdinaryCallingFormat()


def s3_open_key(key, mode, **kwargs):
def s3_open_key(key, mode, ignore_extension=False, **kwargs):
logger.debug('%r', locals())
#
# TODO: handle boto3 keys as well
Expand All @@ -283,11 +277,6 @@ def s3_open_key(key, mode, **kwargs):
if host is not None:
kwargs['endpoint_url'] = 'http://' + host

if kwargs.pop("ignore_extension", False):
codec = None
else:
codec = _detect_codec(key.name)

#
# Codecs work on a byte-level, so the underlying S3 object should
# always be reading bytes.
Expand All @@ -299,38 +288,15 @@ def s3_open_key(key, mode, **kwargs):
else:
raise NotImplementedError('mode %r not implemented for S3' % mode)

logging.debug('codec: %r mode: %r s3_mode: %r', codec, mode, s3_mode)
logging.debug('mode: %r s3_mode: %r', mode, s3_mode)
encoding = kwargs.get('encoding')
errors = kwargs.get('errors', DEFAULT_ERRORS)
fobj = smart_open_s3.open(key.bucket.name, key.name, s3_mode, **kwargs)
decompressed_fobj = _CODECS[codec](fobj, mode)
decompressed_fobj = compression_wrapper(fobj, key.name, mode, ignore_extension)
decoded_fobj = encoding_wrapper(decompressed_fobj, mode, encoding=encoding, errors=errors)
return decoded_fobj


def _detect_codec(filename):
if filename.endswith(".gz"):
return 'gzip'
return None


def _wrap_gzip(fileobj, mode):
return gzip.GzipFile(fileobj=fileobj, mode=mode)


def _wrap_none(fileobj, mode):
return fileobj


# Maps a codec name (as produced by _detect_codec) to a wrapper factory
# with signature (fileobj, mode).  The None key is the pass-through case
# used when no codec was detected.
_CODECS = {
    None: _wrap_none,
    'gzip': _wrap_gzip,
    #
    # TODO: add support for other codecs here.
    #
}


class ParseUri(object):
"""
Parse the given URI.
Expand Down Expand Up @@ -586,17 +552,23 @@ def close(self):
fileobj.close()


def compression_wrapper(file_obj, filename, mode):
def compression_wrapper(file_obj, filename, mode, ignore_extension=False):
"""
This function will wrap the file_obj with an appropriate
[de]compression mechanism based on the extension of the filename.

if `ignore_extension` is True, the original file_obj is always returned
without any [de]compression mechanism added.

file_obj must either be a filehandle object, or a class which behaves
like one.
like one.

If the filename extension isn't recognized, will simply return the original
file_obj.
"""
if ignore_extension:
return file_obj

_, ext = os.path.splitext(filename)
if ext == '.bz2':
return ClosingBZ2File(file_obj, mode)
Expand Down Expand Up @@ -642,7 +614,7 @@ def encoding_wrapper(fileobj, mode, encoding=None, errors=DEFAULT_ERRORS):
return decoder(fileobj, errors=errors)


def file_smart_open(fname, mode='rb', encoding=None, errors=DEFAULT_ERRORS):
def file_smart_open(fname, mode='rb', encoding=None, errors=DEFAULT_ERRORS, ignore_extension=False):
"""
Stream from/to local filesystem, transparently (de)compressing gzip and bz2
files if necessary.
Expand All @@ -651,6 +623,8 @@ def file_smart_open(fname, mode='rb', encoding=None, errors=DEFAULT_ERRORS):
:arg str mode: The mode in which to open the file.
:arg str encoding: The text encoding to use.
:arg str errors: The method to use when handling encoding/decoding errors.
:arg bool ignore_extension: If True, skips the auto detect compression
by file extension.
:returns: A file object
"""
#
Expand All @@ -669,7 +643,7 @@ def file_smart_open(fname, mode='rb', encoding=None, errors=DEFAULT_ERRORS):
except KeyError:
raw_mode = mode
raw_fobj = open(fname, raw_mode)
decompressed_fobj = compression_wrapper(raw_fobj, fname, raw_mode)
decompressed_fobj = compression_wrapper(raw_fobj, fname, raw_mode, ignore_extension)
decoded_fobj = encoding_wrapper(decompressed_fobj, mode, encoding=encoding, errors=errors)
return decoded_fobj

Expand Down Expand Up @@ -774,7 +748,7 @@ def __exit__(self, *args, **kwargs):
self.response.close()


def HttpOpenRead(parsed_uri, mode='r', **kwargs):
def HttpOpenRead(parsed_uri, mode='r', ignore_extension=False, **kwargs):
if parsed_uri.scheme not in ('http', 'https'):
raise TypeError("can only process http/https urls")
if mode not in ('r', 'rb'):
Expand All @@ -789,9 +763,9 @@ def HttpOpenRead(parsed_uri, mode='r', **kwargs):
if fname.endswith('.gz'):
# Gzip needs a seek-able filehandle, so we need to buffer it.
buffer = make_closing(io.BytesIO)(response.binary_content())
return compression_wrapper(buffer, fname, mode)
return compression_wrapper(buffer, fname, mode, ignore_extension)
else:
return compression_wrapper(response, fname, mode)
return compression_wrapper(response, fname, mode, ignore_extension)


class WebHdfsOpenWrite(object):
Expand Down
100 changes: 88 additions & 12 deletions smart_open/tests/test_smart_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,21 @@ def test_http_gz(self):
# decompress the gzip and get the same md5 hash
self.assertEqual(m.hexdigest(), '18473e60f8c7c98d29d65bf805736a0d')

@responses.activate
def test_http_gz_ignore_extension(self):
    """Can open gzip as raw via http?"""
    # Use the on-disk (still compressed) bytes of a known gzip fixture
    # as the HTTP response body.
    fpath = os.path.join(CURR_DIR, 'test_data/crlf_at_1k_boundary.warc.gz')
    with open(fpath, 'rb') as infile:
        data = infile.read()

    # Serve those bytes from a .gz URL; ignore_extension should suppress
    # the automatic gzip decompression the suffix would normally trigger.
    responses.add(responses.GET, "http://127.0.0.1/data.gz", body=data)
    smart_open_object = smart_open.HttpOpenRead(
        smart_open.ParseUri("http://127.0.0.1/data.gz?some_param=some_val"),
        ignore_extension=True,
    )

    # The stream must yield the compressed payload verbatim.
    self.assertEqual(smart_open_object.read(), data)

@responses.activate
def test_http_bz2(self):
"""Can open bz2 via http?"""
Expand All @@ -167,6 +182,18 @@ def test_http_bz2(self):
# decompress the gzip and get the same md5 hash
self.assertEqual(smart_open_object.read(), test_string)

@responses.activate
def test_http_bz2_ignore_extension(self):
    """Can open bz2 as raw via http?"""
    # Serve arbitrary non-bz2 bytes from a URL carrying a .bz2 suffix.
    body = b'not really bz2 but me hiding behind extension'
    responses.add(responses.GET, "http://127.0.0.1/data.bz2", body=body)
    smart_open_object = smart_open.HttpOpenRead(
        smart_open.ParseUri("http://127.0.0.1/data.bz2?some_param=some_val"),
        ignore_extension=True,
    )

    # ignore_extension suppresses bz2 decompression, so the raw body
    # comes back unchanged (decompressing would have raised instead).
    self.assertEqual(smart_open_object.read(), body)


class SmartOpenReadTest(unittest.TestCase):
"""
Expand Down Expand Up @@ -292,29 +319,40 @@ def test_file(self, mock_smart_open):
smart_open_object = smart_open.smart_open(prefix+full_path, read_mode)
smart_open_object.__iter__()
# called with the correct path?
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict')
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict', ignore_extension=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also need to add a test that checks this feature (with ignore_extension=True)


full_path = '/tmp/test#hash##more.txt'
read_mode = "rb"
smart_open_object = smart_open.smart_open(prefix+full_path, read_mode)
smart_open_object.__iter__()
# called with the correct path?
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict')
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict', ignore_extension=False)

full_path = 'aa#aa'
read_mode = "rb"
smart_open_object = smart_open.smart_open(full_path, read_mode)
smart_open_object.__iter__()
# called with the correct path?
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict')
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict', ignore_extension=False)

short_path = "~/tmp/test.txt"
full_path = os.path.expanduser(short_path)

smart_open_object = smart_open.smart_open(prefix+short_path, read_mode, errors='strict')
smart_open_object = smart_open.smart_open(prefix+short_path, read_mode, errors='strict', ignore_extension=False)
smart_open_object.__iter__()
# called with the correct expanded path?
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict')
mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict', ignore_extension=False)

@mock.patch('smart_open.smart_open_lib.file_smart_open')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should it be maybe_.. decorator instead?

def test_file_ignore_extension(self, mock_smart_open):
    """ignore_extension=True is forwarded to file_smart_open for local paths."""
    # NOTE(review): mock_smart_open is injected by the
    # @mock.patch('smart_open.smart_open_lib.file_smart_open') decorator above.
    prefix = 'file://'
    full_path = '/tmp/test.gz'
    read_mode = 'rb'
    smart_open_object = smart_open.smart_open(prefix + full_path, read_mode, ignore_extension=True)
    smart_open_object.__iter__()

    # called with the correct path and ignore_extension?
    mock_smart_open.assert_called_with(full_path, read_mode, encoding=None, errors='strict', ignore_extension=True)

# couldn't find any project for mocking up HDFS data
# TODO: we want to test also a content of the files, not just fnc call params
Expand Down Expand Up @@ -485,15 +523,15 @@ def test_file_mode_mock(self, mock_file, mock_boto):

# correct read modes
smart_open.smart_open("blah", "r")
mock_file.assert_called_with("blah", "r", encoding=None, errors='strict')
mock_file.assert_called_with("blah", "r", encoding=None, errors='strict', ignore_extension=False)

smart_open.smart_open("blah", "rb")
mock_file.assert_called_with("blah", "rb", encoding=None, errors='strict')
mock_file.assert_called_with("blah", "rb", encoding=None, errors='strict', ignore_extension=False)

short_path = "~/blah"
full_path = os.path.expanduser(short_path)
smart_open.smart_open(short_path, "rb")
mock_file.assert_called_with(full_path, "rb", encoding=None, errors='strict')
mock_file.assert_called_with(full_path, "rb", encoding=None, errors='strict', ignore_extension=False)

# correct write modes, incorrect scheme
self.assertRaises(NotImplementedError, smart_open.smart_open, "hdfs:///blah.txt", "wb+")
Expand All @@ -502,16 +540,16 @@ def test_file_mode_mock(self, mock_file, mock_boto):

# correct write mode, correct file:// URI
smart_open.smart_open("blah", "w")
mock_file.assert_called_with("blah", "w", encoding=None, errors='strict')
mock_file.assert_called_with("blah", "w", encoding=None, errors='strict', ignore_extension=False)

smart_open.smart_open("file:///some/file.txt", "wb")
mock_file.assert_called_with("/some/file.txt", "wb", encoding=None, errors='strict')
mock_file.assert_called_with("/some/file.txt", "wb", encoding=None, errors='strict', ignore_extension=False)

smart_open.smart_open("file:///some/file.txt", "wb+")
mock_file.assert_called_with("/some/file.txt", "wb+", encoding=None, errors='strict')
mock_file.assert_called_with("/some/file.txt", "wb+", encoding=None, errors='strict', ignore_extension=False)

smart_open.smart_open("file:///some/file.txt", "w+")
mock_file.assert_called_with("/some/file.txt", "w+", encoding=None, errors='strict')
mock_file.assert_called_with("/some/file.txt", "w+", encoding=None, errors='strict', ignore_extension=False)

@mock.patch('boto3.Session')
def test_s3_mode_mock(self, mock_session):
Expand Down Expand Up @@ -866,12 +904,50 @@ def test_write_read_gz(self):
test_file_name = infile.name
self.write_read_assertion(test_file_name)

def test_open_gz_read_ignore_extension(self):
    """Reading a .gz file with ignore_extension=True yields the raw bytes."""
    fpath = os.path.join(CURR_DIR, 'test_data/crlf_at_1k_boundary.warc.gz')

    # The on-disk (compressed) bytes are exactly what we expect back.
    with open(fpath, 'rb') as fp:
        expected_data = fp.read()

    with smart_open.smart_open(fpath, 'rb', ignore_extension=True) as infile:
        data = infile.read()
    assert data == expected_data

    # Sanity check: the raw stream is still valid gzip, and decompressing
    # it manually reproduces the fixture's known md5.
    with smart_open.smart_open(fpath, 'rb', ignore_extension=True) as infile:
        data = gzip.GzipFile(fileobj=infile).read()
    m = hashlib.md5(data)
    assert m.hexdigest() == '18473e60f8c7c98d29d65bf805736a0d', \
        'Failed to read gzip file as raw'

def test_open_gz_rw_ignore_extension(self):
    """Round-trip raw bytes through a .gz-named file with ignore_extension."""
    payload = b'wello horld!'

    # We only need a unique .gz-suffixed path; content is written below.
    with tempfile.NamedTemporaryFile('wb', suffix='.gz', delete=False) as handle:
        path = handle.name

    # Neither the write nor the read should engage the gzip codec.
    with smart_open.smart_open(path, 'wb', ignore_extension=True) as writer:
        writer.write(payload)

    with smart_open.smart_open(path, 'rb', ignore_extension=True) as reader:
        assert reader.read() == payload

def test_write_read_bz2(self):
"""Can write and read bz2?"""
with tempfile.NamedTemporaryFile('wb', suffix='.bz2', delete=False) as infile:
test_file_name = infile.name
self.write_read_assertion(test_file_name)

def test_open_bz2_ignore_extension(self):
    """A .bz2-suffixed file is read verbatim when ignore_extension=True."""
    payload = b'not really bz2 but me hiding behind extension'

    # Write raw (non-bz2) bytes into a file whose name suggests bz2.
    with tempfile.NamedTemporaryFile('wb', suffix='.bz2', delete=False) as handle:
        handle.write(payload)
        path = handle.name

    # With ignore_extension set, no decompression is attempted, so the
    # bytes come back exactly as written.
    with smart_open.smart_open(path, ignore_extension=True) as handle:
        assert handle.read() == payload


class MultistreamsBZ2Test(unittest.TestCase):
"""
Expand Down