-
-
Notifications
You must be signed in to change notification settings - Fork 383
/
compression.py
147 lines (112 loc) · 4.45 KB
/
compression.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# -*- coding: utf-8 -*-
#
# Copyright (C) 2020 Radim Rehurek <me@radimrehurek.com>
#
# This code is distributed under the terms and conditions
# from the MIT License (MIT).
#
"""Implements the compression layer of the ``smart_open`` library."""
import logging
import os.path
logger = logging.getLogger(__name__)

# Maps a lower-cased file extension (leading period included, e.g. '.gz')
# to the callback that wraps a stream; filled in by register_compressor().
_COMPRESSOR_REGISTRY = dict()

NO_COMPRESSION = 'disable'
"""Use no compression. Read/write the data as-is."""

INFER_FROM_EXTENSION = 'infer_from_extension'
"""Determine the compression to use from the file extension.
See get_supported_extensions().
"""
def get_supported_compression_types():
    """Return the list of supported compression types available to open.

    The list starts with the two special values ``NO_COMPRESSION`` and
    ``INFER_FROM_EXTENSION``, followed by every registered extension as
    returned by get_supported_extensions().

    See the compression parameter to smart_open.open().
    """
    return [NO_COMPRESSION, INFER_FROM_EXTENSION] + get_supported_extensions()
def get_supported_extensions():
    """List every file extension that currently has a compressor registered.

    A new, alphabetically sorted list is built on each call.  With only the
    built-in handlers registered, this is ``['.bz2', '.gz']``.
    """
    extensions = list(_COMPRESSOR_REGISTRY)
    extensions.sort()
    return extensions
def register_compressor(ext, callback):
    """Register a callback for transparently decompressing files with a specific extension.

    Parameters
    ----------
    ext: str
        The extension. Must include the leading period, e.g. ``.gz``.
        It is stored lower-cased.
    callback: callable
        The callback. It must accept two positional arguments, file_obj and mode.
        This function will be called when ``smart_open`` is opening a file with
        the specified extension.

    Raises
    ------
    ValueError
        If `ext` is not a string starting with a period.

    Notes
    -----
    Registering an extension that already has a handler replaces the old
    handler and logs a warning.

    Examples
    --------

    Instruct smart_open to use the `lzma` module whenever opening a file
    with a .xz extension (see README.rst for the complete example showing I/O):

    >>> def _handle_xz(file_obj, mode):
    ...     import lzma
    ...     return lzma.LZMAFile(filename=file_obj, mode=mode, format=lzma.FORMAT_XZ)
    >>>
    >>> register_compressor('.xz', _handle_xz)

    """
    # Check the type explicitly so that non-string input (ints, lists, ...)
    # raises the documented ValueError rather than a TypeError/AttributeError.
    if not (isinstance(ext, str) and ext.startswith('.')):
        # Wrap in a 1-tuple: a bare tuple `ext` would be unpacked by `%`.
        raise ValueError('ext must be a string starting with ., not %r' % (ext,))
    ext = ext.lower()
    if ext in _COMPRESSOR_REGISTRY:
        logger.warning('overriding existing compression handler for %r', ext)
    _COMPRESSOR_REGISTRY[ext] = callback
def tweak_close(outer, inner):
    """Ensure that closing the `outer` stream closes the `inner` stream as well.

    Use this when your compression library's `close` method does not
    automatically close the underlying filestream.  See
    https://github.com/RaRe-Technologies/smart_open/issues/630 for an
    explanation why that is a problem for smart_open.

    Parameters
    ----------
    outer: object
        The wrapping stream.  Its ``close`` attribute is replaced in place.
    inner: object
        The wrapped stream.  It is closed at most once, after the original
        ``outer.close`` has run, even if that call raises.
    """
    outer_close = outer.close

    def close_both(*args):
        # Any positional arguments are accepted and ignored.
        nonlocal inner
        try:
            outer_close()
        finally:
            # Compare against None instead of testing truthiness: a file-like
            # object defining __len__/__bool__ (e.g. an empty buffer) would
            # otherwise look "already closed" and never be closed.
            if inner is not None:
                # Drop the reference first so repeated close() calls do not
                # close the inner stream twice.
                inner, fp = None, inner
                fp.close()

    outer.close = close_both
def _handle_bz2(file_obj, mode):
    """Wrap `file_obj` in a ``bz2.BZ2File`` opened with `mode`.

    Closing the returned object also closes `file_obj` (via tweak_close).
    """
    import bz2
    wrapped = bz2.BZ2File(file_obj, mode)
    tweak_close(wrapped, file_obj)
    return wrapped
def _handle_gzip(file_obj, mode):
    """Wrap `file_obj` in a ``gzip.GzipFile`` opened with `mode`.

    Closing the returned object also closes `file_obj` (via tweak_close).
    """
    from gzip import GzipFile
    wrapped = GzipFile(mode=mode, fileobj=file_obj)
    tweak_close(wrapped, file_obj)
    return wrapped
def compression_wrapper(file_obj, mode, compression=INFER_FROM_EXTENSION, filename=None):
    """
    Wrap `file_obj` with an appropriate [de]compression mechanism based on its file extension.

    If the filename extension isn't recognized, simply return the original `file_obj` unchanged.

    `file_obj` must either be a filehandle object, or a class which behaves like one.

    Parameters
    ----------
    file_obj: file-like object
        The stream to wrap.
    mode: str
        The mode of the stream; passed through unchanged to the handler.
    compression: str, optional
        ``NO_COMPRESSION`` returns `file_obj` as-is.  ``INFER_FROM_EXTENSION``
        (the default) picks the handler from the extension of `filename`, or
        of ``file_obj.name`` if `filename` is not given.  Any other value is
        used directly as the registry key, e.g. ``'.gz'``.
    filename: str or os.PathLike, optional
        Only consulted when inferring from the extension.

    Returns
    -------
    The object returned by the registered handler, or `file_obj` itself when
    no handler applies.

    Raises
    ------
    ValueError
        If a handler applies and `mode` is an update mode (contains ``+``).
    """
    if compression == NO_COMPRESSION:
        return file_obj

    if compression == INFER_FROM_EXTENSION:
        try:
            # os.fspath also accepts os.PathLike objects (e.g. pathlib.Path);
            # non-path names such as integer file descriptors raise TypeError.
            filename = os.fspath(filename or file_obj.name).lower()
        except (AttributeError, TypeError):
            logger.warning(
                'unable to transparently decompress %r because it '
                'seems to lack a string-like .name', file_obj
            )
            return file_obj
        _, compression = os.path.splitext(filename)

    # '+' may appear anywhere in the mode string ('rb+', 'r+b'), so test
    # membership rather than only the last character.
    if compression in _COMPRESSOR_REGISTRY and '+' in mode:
        raise ValueError('transparent (de)compression unsupported for mode %r' % (mode,))

    try:
        callback = _COMPRESSOR_REGISTRY[compression]
    except KeyError:
        return file_obj
    return callback(file_obj, mode)
#
# NB. avoid using lambda here to make stack traces more readable.
#
# Register the built-in handlers at import time, so .bz2 and .gz files are
# handled by compression_wrapper() without any setup by the caller.
register_compressor('.bz2', _handle_bz2)
register_compressor('.gz', _handle_gzip)