Skip to content

Commit

Permalink
Propagate __exit__ call to underlying filestream
Browse files Browse the repository at this point in the history
  • Loading branch information
ddelange committed Oct 3, 2023
1 parent 2894d20 commit 0586dce
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 14 deletions.
19 changes: 5 additions & 14 deletions smart_open/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,36 +71,27 @@ def register_compressor(ext, callback):
def tweak_close(outer, inner):
"""Ensure that closing the `outer` stream closes the `inner` stream as well.
Deprecated: smart_open.open().__exit__ now always calls __exit__ on the
underlying filestream.
Use this when your compression library's `close` method does not
automatically close the underlying filestream. See
https://github.com/RaRe-Technologies/smart_open/issues/630 for an
explanation why that is a problem for smart_open.
"""
outer_close = outer.close

def close_both(*args):
nonlocal inner
try:
outer_close()
finally:
if inner:
inner, fp = None, inner
fp.close()

outer.close = close_both
from smart_open.utils import propagate_wrapped_method as _propagate_wrapped_method
_propagate_wrapped_method(outer, inner, "close")


def _handle_bz2(file_obj, mode):
from bz2 import BZ2File
result = BZ2File(file_obj, mode)
tweak_close(result, file_obj)
return result


def _handle_gzip(file_obj, mode):
import gzip
result = gzip.GzipFile(fileobj=file_obj, mode=mode)
tweak_close(result, file_obj)
return result


Expand Down
3 changes: 3 additions & 0 deletions smart_open/smart_open_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from smart_open.compression import register_compressor # noqa: F401
from smart_open.utils import check_kwargs as _check_kwargs # noqa: F401
from smart_open.utils import inspect_kwargs as _inspect_kwargs # noqa: F401
from smart_open.utils import propagate_wrapped_method as _propagate_wrapped_method

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -223,6 +224,7 @@ def open(

binary = _open_binary_stream(uri, binary_mode, transport_params)
decompressed = so_compression.compression_wrapper(binary, binary_mode, compression)
_propagate_wrapped_method(decompressed, binary, "__exit__")

if 'b' not in mode or explicit_encoding is not None:
decoded = _encoding_wrapper(
Expand All @@ -232,6 +234,7 @@ def open(
errors=errors,
newline=newline,
)
_propagate_wrapped_method(decoded, decompressed, "__exit__")
else:
decoded = decompressed

Expand Down
24 changes: 24 additions & 0 deletions smart_open/tests/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,30 @@ def test_writebuffer(self):

assert actual == contents

def test_write_gz_with_error(self):
"""Does s3 multipart upload abort when for a failed compressed file upload?"""
with self.assertRaises(ValueError):
with smart_open.open(
f's3://{BUCKET_NAME}/{WRITE_KEY_NAME}',
mode="wb",
compression='.gz',
transport_params={
"multipart_upload": True,
"min_part_size": 10,
}
) as fout:
fout.write(b"test12345678test12345678")
fout.write(b"test\n")

raise ValueError("some error")

# no multipart upload was committed:
# smart_open.s3.MultipartWriter.__exit__ was called
with self.assertRaises(OSError) as cm:
smart_open.s3.open(BUCKET_NAME, WRITE_KEY_NAME, 'rb')

assert 'The specified key does not exist.' in cm.exception.args[0]


@moto.mock_s3
class SinglepartWriterTest(unittest.TestCase):
Expand Down
22 changes: 22 additions & 0 deletions smart_open/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import inspect
import logging
import types
import urllib.parse

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -189,3 +190,24 @@ def safe_urlsplit(url):

path = sr.path.replace(placeholder, '?')
return urllib.parse.SplitResult(sr.scheme, sr.netloc, path, '', '')


def propagate_wrapped_method(outer, inner, method):
"""Patch `outer` object to also execute `method` of the wrapped `inner` object."""
if outer is inner:
return
method_outer = getattr(outer, method)
method_inner = getattr(inner, method)

def method_new(*args, **kwargs):
try:
method_outer(*args, **kwargs)
finally:
method_inner(*args, **kwargs)

if isinstance(method_outer, types.BuiltinMethodType):
# e.g. __exit__()
setattr(outer, method, types.BuiltinMethodType(method_new, outer))
else:
# e.g. close()
setattr(outer, method, types.MethodType(method_new, outer))

0 comments on commit 0586dce

Please sign in to comment.