Skip to content

Commit

Permalink
Add new compression parameter (#609)
Browse files Browse the repository at this point in the history
* Implement a compression argument to complement ignore_ext

* Add a deprecation warning for ignore_ext argument

* Fix flake8 linting errors

* It's better to mention the problem explicitly here

Co-authored-by: Michael Penkov <m@penkov.dev>

* Improve comprehensibility of documentation

Co-authored-by: Michael Penkov <m@penkov.dev>

* Add docstrings to public constant NO_COMPRESSION

Co-authored-by: Michael Penkov <m@penkov.dev>

* Add docstrings to public constant INFER_FROM_EXTENSION

Co-authored-by: Michael Penkov <m@penkov.dev>

* Add docstrings to public function get_supported_compression_types

* Avoid race condition with bucket creation.

Co-authored-by: Michael Penkov <m@penkov.dev>

* Fix flake8 linter error introduced

* Make tests work in terms of UTF-8 encoded data

* Refactor to simplify round-trip compression verification

* Fix flake8 linter error introduced (again)

* Split compression by extension test by implicit / explicit API

* Use pytest-like assertions

* Consolidate and unnest parameter validation

* Use lowercase "see" keyword

Co-authored-by: Michael Penkov <m@penkov.dev>

* Don't parameterize tests with essentially constant values

* (Safely) create bucket in test setUp

* Remove assertions for behavior not related explicitly to compression.

Co-authored-by: Michael Penkov <m@penkov.dev>
  • Loading branch information
dmcguire81 and mpenkov authored May 5, 2021
1 parent 9089289 commit 7716c7d
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 8 deletions.
17 changes: 17 additions & 0 deletions smart_open/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,23 @@
_COMPRESSOR_REGISTRY = {}


NO_COMPRESSION = 'none'
"""Use no compression. Read/write the data as-is."""
INFER_FROM_EXTENSION = 'extension'
"""Determine the compression to use from the file extension.
See get_supported_extensions().
"""


def get_supported_compression_types():
"""Return the list of supported compression types available to open.
See compression paratemeter to smart_open.open().
"""
return [NO_COMPRESSION, INFER_FROM_EXTENSION] + [ext[1:] for ext in get_supported_extensions()]


def get_supported_extensions():
"""Return the list of file extensions for which we have registered compressors."""
return sorted(_COMPRESSOR_REGISTRY.keys())
Expand Down
37 changes: 29 additions & 8 deletions smart_open/smart_open_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
# smart_open.submodule to reference to the submodules.
#
import smart_open.local_file as so_file
import smart_open.compression as so_compression

from smart_open import compression
from smart_open import doctools
from smart_open import transport

Expand Down Expand Up @@ -107,6 +107,7 @@ def open(
closefd=True,
opener=None,
ignore_ext=False,
compression=None,
transport_params=None,
):
r"""Open the URI object, returning a file-like object.
Expand Down Expand Up @@ -139,6 +140,9 @@ def open(
Mimicks built-in open parameter of the same name. Ignored.
ignore_ext: boolean, optional
Disable transparent compression/decompression based on the file extension.
compression: str, optional (see smart_open.compression.get_supported_compression_types)
Explicitly specify the compression/decompression behavior.
If you specify this parameter, then ignore_ext must not be specified.
transport_params: dict, optional
Additional parameters for the transport layer (see notes below).
Expand Down Expand Up @@ -168,13 +172,23 @@ def open(
if not isinstance(mode, str):
raise TypeError('mode should be a string')

if compression and ignore_ext:
raise ValueError('ignore_ext and compression parameters are mutually exclusive')
elif compression and compression not in so_compression.get_supported_compression_types():
raise ValueError(f'invalid compression type: {compression}')
elif ignore_ext:
compression = so_compression.NO_COMPRESSION
warnings.warn("'ignore_ext' will be deprecated in a future release", PendingDeprecationWarning)
elif compression is None:
compression = so_compression.INFER_FROM_EXTENSION

if transport_params is None:
transport_params = {}

fobj = _shortcut_open(
uri,
mode,
ignore_ext=ignore_ext,
compression=compression,
buffering=buffering,
encoding=encoding,
errors=errors,
Expand Down Expand Up @@ -219,10 +233,13 @@ def open(
raise NotImplementedError(ve.args[0])

binary = _open_binary_stream(uri, binary_mode, transport_params)
if ignore_ext:
if compression == so_compression.NO_COMPRESSION:
decompressed = binary
elif compression == so_compression.INFER_FROM_EXTENSION:
decompressed = so_compression.compression_wrapper(binary, binary_mode)
else:
decompressed = compression.compression_wrapper(binary, binary_mode)
faked_extension = f"{binary.name}.{compression.lower()}"
decompressed = so_compression.compression_wrapper(binary, binary_mode, filename=faked_extension)

if 'b' not in mode or explicit_encoding is not None:
decoded = _encoding_wrapper(
Expand Down Expand Up @@ -295,7 +312,7 @@ def transfer(char):
def _shortcut_open(
uri,
mode,
ignore_ext=False,
compression,
buffering=-1,
encoding=None,
errors=None,
Expand All @@ -309,12 +326,13 @@ def _shortcut_open(
This is only possible under the following conditions:
1. Opening a local file; and
2. Ignore extension is set to True
2. Compression is disabled
If it is not possible to use the built-in open for the specified URI, returns None.
:param str uri: A string indicating what to open.
:param str mode: The mode to pass to the open function.
:param str compression: The compression type selected.
:returns: The opened file
:rtype: file
"""
Expand All @@ -326,8 +344,11 @@ def _shortcut_open(
return None

local_path = so_file.extract_local_path(uri)
_, extension = P.splitext(local_path)
if extension in compression.get_supported_extensions() and not ignore_ext:
if compression == so_compression.INFER_FROM_EXTENSION:
_, extension = P.splitext(local_path)
if extension in so_compression.get_supported_extensions():
return None
elif compression != so_compression.NO_COMPRESSION:
return None

open_kwargs = {}
Expand Down
139 changes: 139 additions & 0 deletions smart_open/tests/test_smart_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import hashlib
import logging
import os
from smart_open.compression import INFER_FROM_EXTENSION, NO_COMPRESSION
import tempfile
import unittest
import warnings
Expand Down Expand Up @@ -1840,6 +1841,144 @@ def test(self):
self.assertEqual(expected, actual)


_RAW_DATA = "не слышны в саду даже шорохи".encode("utf-8")


@mock_s3
class HandleS3CompressionTestCase(parameterizedtestcase.ParameterizedTestCase):

def setUp(self):
s3 = boto3.resource("s3")
s3.create_bucket(Bucket="bucket").wait_until_exists()

# compression | ignore_ext | behavior |
# ----------- | ---------- | -------- |
# 'gz' | False | Override |
# 'bz2' | False | Override |
@parameterizedtestcase.ParameterizedTestCase.parameterize(
("_compression", "decompressor"),
[
("gz", gzip.decompress),
("bz2", bz2.decompress),
],
)
def test_rw_compression_prescribed(self, _compression, decompressor):
"""Should read/write files with `_compression`, as prescribed."""
key = "s3://bucket/key.txt"

with smart_open.open(key, "wb", compression=_compression) as fout:
fout.write(_RAW_DATA)

#
# Check that what we've created is compressed as expected.
#
with smart_open.open(key, "rb", compression=NO_COMPRESSION) as fin:
data = decompressor(fin.read())
assert data == _RAW_DATA

# compression | ignore_ext | behavior |
# ----------- | ---------- | -------- |
# 'extension' | False | Enable |
# 'none' | False | Disable |
@parameterizedtestcase.ParameterizedTestCase.parameterize(
("_compression", "decompressor"),
[
(
"gz",
gzip.decompress,
),
(
"bz2",
bz2.decompress,
)
],
)
def test_rw_compression_by_extension(
self, _compression, decompressor
):
"""Should read/write files with `_compression`, explicitily inferred by file extension."""
key = f"s3://bucket/key.{_compression}"

with smart_open.open(key, "wb", compression=INFER_FROM_EXTENSION) as fout:
fout.write(_RAW_DATA)

#
# Check that what we've created is compressed as expected.
#
with smart_open.open(key, "rb", compression=NO_COMPRESSION) as fin:
assert decompressor(fin.read()) == _RAW_DATA

# compression | ignore_ext | behavior |
# ----------- | ---------- | -------- |
# None | False | Enable |
# None | True | Disable |
@parameterizedtestcase.ParameterizedTestCase.parameterize(
("_compression", "decompressor"),
[
(
"gz",
gzip.decompress,
),
(
"bz2",
bz2.decompress,
),
],
)
def test_rw_compression_by_extension_deprecated(
self, _compression, decompressor
):
"""Should read/write files with `_compression`, implicitly inferred by file extension."""
key = f"s3://bucket/key.{_compression}"

with smart_open.open(key, "wb") as fout:
fout.write(_RAW_DATA)

#
# Check that what we've created is compressed as expected.
#
with smart_open.open(key, "rb", ignore_ext=True) as fin:
assert decompressor(fin.read()) == _RAW_DATA

# extension | compression | ignore_ext | behavior |
# ----------| ----------- | ---------- | -------- |
# <any> | <invalid> | <any> | Error |
# <any> | 'none' | True | Error |
# 'gz' | 'extension' | True | Error |
# 'bz2' | 'extension' | True | Error |
# <any> | 'gz' | True | Error |
# <any> | 'bz2' | True | Error |
@parameterizedtestcase.ParameterizedTestCase.parameterize(
("extension", "kwargs", "error"),
[
("", dict(compression="foo"), ValueError),
("", dict(compression="foo", ignore_ext=True), ValueError),
("", dict(compression=NO_COMPRESSION, ignore_ext=True), ValueError),
(
".gz",
dict(compression=INFER_FROM_EXTENSION, ignore_ext=True),
ValueError,
),
(
".bz2",
dict(compression=INFER_FROM_EXTENSION, ignore_ext=True),
ValueError,
),
("", dict(compression="gz", ignore_ext=True), ValueError),
("", dict(compression="bz2", ignore_ext=True), ValueError),
],
)
def test_compression_invalid(self, extension, kwargs, error):
"""Should detect and error on these invalid inputs"""
key = f"s3://bucket/key{extension}"

with pytest.raises(error):
smart_open.open(key, "wb", **kwargs)

with pytest.raises(error):
smart_open.open(key, "rb", **kwargs)


class GetBinaryModeTest(parameterizedtestcase.ParameterizedTestCase):
@parameterizedtestcase.ParameterizedTestCase.parameterize(
('mode', 'expected'),
Expand Down

0 comments on commit 7716c7d

Please sign in to comment.