From 7716c7db6c6e7f823c34ade2f40dc7dbc2d2de5a Mon Sep 17 00:00:00 2001 From: David McGuire Date: Wed, 5 May 2021 15:54:20 -0700 Subject: [PATCH] Add new compression parameter (#609) * Implement a compression argument to complement ignore_ext * Add a deprecation warning for ignore_ext argument * Fix flake8 linting errors * It's better to mention the problem explicitly here Co-authored-by: Michael Penkov * Improve comprehensibility of documentation Co-authored-by: Michael Penkov * Add docstrings to public constant NO_COMPRESSION Co-authored-by: Michael Penkov * Add docstrings to public constant INFER_FROM_EXTENSION Co-authored-by: Michael Penkov * Add docstrings to public function get_supported_compression_types * Avoid race condition with bucket creation. Co-authored-by: Michael Penkov * Fix flake8 linter error introduced * Make tests work in terms of UTF-8 encoded data * Refactor to simplify round-trip compression verification * Fix flake8 linter error introduced (again) * Split compression by extension test by implicit / explicit API * Use pytest-like assertions * Consolidate and unnest parameter validation * Use lowercase "see" keyword Co-authored-by: Michael Penkov * Don't parameterize tests with essentially constant values * (Safely) create bucket in test setUp * Remove assertions for behavior not related explicitly to compression. Co-authored-by: Michael Penkov --- smart_open/compression.py | 17 ++++ smart_open/smart_open_lib.py | 37 ++++++-- smart_open/tests/test_smart_open.py | 139 ++++++++++++++++++++++++++++ 3 files changed, 185 insertions(+), 8 deletions(-) diff --git a/smart_open/compression.py b/smart_open/compression.py index aa8b689c2..f7ee16fde 100644 --- a/smart_open/compression.py +++ b/smart_open/compression.py @@ -15,6 +15,23 @@ _COMPRESSOR_REGISTRY = {} +NO_COMPRESSION = 'none' +"""Use no compression. Read/write the data as-is.""" +INFER_FROM_EXTENSION = 'extension' +"""Determine the compression to use from the file extension. + +See get_supported_extensions(). +""" + + +def get_supported_compression_types(): + """Return the list of supported compression types available to open. + + See compression paratemeter to smart_open.open(). + """ + return [NO_COMPRESSION, INFER_FROM_EXTENSION] + [ext[1:] for ext in get_supported_extensions()] + + def get_supported_extensions(): """Return the list of file extensions for which we have registered compressors.""" return sorted(_COMPRESSOR_REGISTRY.keys()) diff --git a/smart_open/smart_open_lib.py b/smart_open/smart_open_lib.py index bf25f5cc3..9f970d71a 100644 --- a/smart_open/smart_open_lib.py +++ b/smart_open/smart_open_lib.py @@ -30,8 +30,8 @@ # smart_open.submodule to reference to the submodules. # import smart_open.local_file as so_file +import smart_open.compression as so_compression -from smart_open import compression from smart_open import doctools from smart_open import transport @@ -107,6 +107,7 @@ def open( closefd=True, opener=None, ignore_ext=False, + compression=None, transport_params=None, ): r"""Open the URI object, returning a file-like object. @@ -139,6 +140,9 @@ def open( Mimicks built-in open parameter of the same name. Ignored. ignore_ext: boolean, optional Disable transparent compression/decompression based on the file extension. + compression: str, optional (see smart_open.compression.get_supported_compression_types) + Explicitly specify the compression/decompression behavior. + If you specify this parameter, then ignore_ext must not be specified. transport_params: dict, optional Additional parameters for the transport layer (see notes below). @@ -168,13 +172,23 @@ def open( if not isinstance(mode, str): raise TypeError('mode should be a string') + if compression and ignore_ext: + raise ValueError('ignore_ext and compression parameters are mutually exclusive') + elif compression and compression not in so_compression.get_supported_compression_types(): + raise ValueError(f'invalid compression type: {compression}') + elif ignore_ext: + compression = so_compression.NO_COMPRESSION + warnings.warn("'ignore_ext' will be deprecated in a future release", PendingDeprecationWarning) + elif compression is None: + compression = so_compression.INFER_FROM_EXTENSION + if transport_params is None: transport_params = {} fobj = _shortcut_open( uri, mode, - ignore_ext=ignore_ext, + compression=compression, buffering=buffering, encoding=encoding, errors=errors, @@ -219,10 +233,13 @@ def open( raise NotImplementedError(ve.args[0]) binary = _open_binary_stream(uri, binary_mode, transport_params) - if ignore_ext: + if compression == so_compression.NO_COMPRESSION: decompressed = binary + elif compression == so_compression.INFER_FROM_EXTENSION: + decompressed = so_compression.compression_wrapper(binary, binary_mode) else: - decompressed = compression.compression_wrapper(binary, binary_mode) + faked_extension = f"{binary.name}.{compression.lower()}" + decompressed = so_compression.compression_wrapper(binary, binary_mode, filename=faked_extension) if 'b' not in mode or explicit_encoding is not None: decoded = _encoding_wrapper( @@ -295,7 +312,7 @@ def transfer(char): def _shortcut_open( uri, mode, - ignore_ext=False, + compression, buffering=-1, encoding=None, errors=None, @@ -309,12 +326,13 @@ def _shortcut_open( This is only possible under the following conditions: 1. Opening a local file; and - 2. Ignore extension is set to True + 2. Compression is disabled If it is not possible to use the built-in open for the specified URI, returns None. :param str uri: A string indicating what to open. :param str mode: The mode to pass to the open function. + :param str compression: The compression type selected. :returns: The opened file :rtype: file """ @@ -326,8 +344,11 @@ def _shortcut_open( return None local_path = so_file.extract_local_path(uri) - _, extension = P.splitext(local_path) - if extension in compression.get_supported_extensions() and not ignore_ext: + if compression == so_compression.INFER_FROM_EXTENSION: + _, extension = P.splitext(local_path) + if extension in so_compression.get_supported_extensions(): + return None + elif compression != so_compression.NO_COMPRESSION: return None open_kwargs = {} diff --git a/smart_open/tests/test_smart_open.py b/smart_open/tests/test_smart_open.py index 060454439..98c2ea494 100644 --- a/smart_open/tests/test_smart_open.py +++ b/smart_open/tests/test_smart_open.py @@ -14,6 +14,7 @@ import hashlib import logging import os +from smart_open.compression import INFER_FROM_EXTENSION, NO_COMPRESSION import tempfile import unittest import warnings @@ -1840,6 +1841,144 @@ def test(self): self.assertEqual(expected, actual) +_RAW_DATA = "не слышны в саду даже шорохи".encode("utf-8") + + +@mock_s3 +class HandleS3CompressionTestCase(parameterizedtestcase.ParameterizedTestCase): + + def setUp(self): + s3 = boto3.resource("s3") + s3.create_bucket(Bucket="bucket").wait_until_exists() + + # compression | ignore_ext | behavior | + # ----------- | ---------- | -------- | + # 'gz' | False | Override | + # 'bz2' | False | Override | + @parameterizedtestcase.ParameterizedTestCase.parameterize( + ("_compression", "decompressor"), + [ + ("gz", gzip.decompress), + ("bz2", bz2.decompress), + ], + ) + def test_rw_compression_prescribed(self, _compression, decompressor): + """Should read/write files with `_compression`, as prescribed.""" + key = "s3://bucket/key.txt" + + with smart_open.open(key, "wb", compression=_compression) as fout: + fout.write(_RAW_DATA) + + # + # Check that what we've created is compressed as expected. + # + with smart_open.open(key, "rb", compression=NO_COMPRESSION) as fin: + data = decompressor(fin.read()) + assert data == _RAW_DATA + + # compression | ignore_ext | behavior | + # ----------- | ---------- | -------- | + # 'extension' | False | Enable | + # 'none' | False | Disable | + @parameterizedtestcase.ParameterizedTestCase.parameterize( + ("_compression", "decompressor"), + [ + ( + "gz", + gzip.decompress, + ), + ( + "bz2", + bz2.decompress, + ) + ], + ) + def test_rw_compression_by_extension( + self, _compression, decompressor + ): + """Should read/write files with `_compression`, explicitily inferred by file extension.""" + key = f"s3://bucket/key.{_compression}" + + with smart_open.open(key, "wb", compression=INFER_FROM_EXTENSION) as fout: + fout.write(_RAW_DATA) + + # + # Check that what we've created is compressed as expected. + # + with smart_open.open(key, "rb", compression=NO_COMPRESSION) as fin: + assert decompressor(fin.read()) == _RAW_DATA + + # compression | ignore_ext | behavior | + # ----------- | ---------- | -------- | + # None | False | Enable | + # None | True | Disable | + @parameterizedtestcase.ParameterizedTestCase.parameterize( + ("_compression", "decompressor"), + [ + ( + "gz", + gzip.decompress, + ), + ( + "bz2", + bz2.decompress, + ), + ], + ) + def test_rw_compression_by_extension_deprecated( + self, _compression, decompressor + ): + """Should read/write files with `_compression`, implicitly inferred by file extension.""" + key = f"s3://bucket/key.{_compression}" + + with smart_open.open(key, "wb") as fout: + fout.write(_RAW_DATA) + + # + # Check that what we've created is compressed as expected. + # + with smart_open.open(key, "rb", ignore_ext=True) as fin: + assert decompressor(fin.read()) == _RAW_DATA + + # extension | compression | ignore_ext | behavior | + # ----------| ----------- | ---------- | -------- | + # | | | Error | + # | 'none' | True | Error | + # 'gz' | 'extension' | True | Error | + # 'bz2' | 'extension' | True | Error | + # | 'gz' | True | Error | + # | 'bz2' | True | Error | + @parameterizedtestcase.ParameterizedTestCase.parameterize( + ("extension", "kwargs", "error"), + [ + ("", dict(compression="foo"), ValueError), + ("", dict(compression="foo", ignore_ext=True), ValueError), + ("", dict(compression=NO_COMPRESSION, ignore_ext=True), ValueError), + ( + ".gz", + dict(compression=INFER_FROM_EXTENSION, ignore_ext=True), + ValueError, + ), + ( + ".bz2", + dict(compression=INFER_FROM_EXTENSION, ignore_ext=True), + ValueError, + ), + ("", dict(compression="gz", ignore_ext=True), ValueError), + ("", dict(compression="bz2", ignore_ext=True), ValueError), + ], + ) + def test_compression_invalid(self, extension, kwargs, error): + """Should detect and error on these invalid inputs""" + key = f"s3://bucket/key{extension}" + + with pytest.raises(error): + smart_open.open(key, "wb", **kwargs) + + with pytest.raises(error): + smart_open.open(key, "rb", **kwargs) + + class GetBinaryModeTest(parameterizedtestcase.ParameterizedTestCase): @parameterizedtestcase.ParameterizedTestCase.parameterize( ('mode', 'expected'),