diff --git a/.travis.yml b/.travis.yml
index a53759ba..ad0597a6 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -8,13 +8,6 @@ jobs:
       name: "flake8"
       env: TOXENV="flake8"
 
-    - python: '2.7'
-      env:
-        - SO_S3_URL: "s3://smart-open-py27-benchmark"
-        - SO_S3_RESULT_URL: "s3://smart-open-py27-benchmark-results"
-        - BOTO_CONFIG: "/dev/null"
-        - TOXENV: "check_keys,py27-test,py27-benchmark,py27-integration"
-
     - python: '3.5'
       env: TOXENV="check_keys,py35-test,py35-integration"
diff --git a/README.rst b/README.rst
index 9899ea0f..51bc27f7 100644
--- a/README.rst
+++ b/README.rst
@@ -14,7 +14,7 @@ smart_open — utils for streaming large files in Python
 What?
 =====
 
-``smart_open`` is a Python 2 & Python 3 library for **efficient streaming of very large files** from/to storages such as S3, GCS, HDFS, WebHDFS, HTTP, HTTPS, SFTP, or local filesystem. It supports transparent, on-the-fly (de-)compression for a variety of different formats.
+``smart_open`` is a Python 3 library for **efficient streaming of very large files** from/to storages such as S3, GCS, HDFS, WebHDFS, HTTP, HTTPS, SFTP, or local filesystem. It supports transparent, on-the-fly (de-)compression for a variety of different formats.
 
 ``smart_open`` is a drop-in replacement for Python's built-in ``open()``: it can do anything ``open`` can (100% compatible, falls back to native ``open`` wherever possible), plus lots of nifty extra stuff on top.
 
@@ -77,6 +77,8 @@ How?
     ...     break
     '\n'
 
+.. _doctools_after_examples:
+
 Other examples of URLs that ``smart_open`` accepts::
 
     s3://my_bucket/my_key
@@ -96,8 +98,6 @@ Other examples of URLs that ``smart_open`` accepts::
     [ssh|scp|sftp]://username@host/path/file
     [ssh|scp|sftp]://username:password@host/path/file
 
-.. _doctools_after_examples:
-
 Documentation
 =============
 
@@ -407,6 +407,11 @@ This can be helpful when e.g. working with compressed files.
     ...     print(infile.readline()[:41])
     В начале июля, в чрезвычайно жаркое время
 
+Extending ``smart_open``
+========================
+
+See `this document <extending.md>`__.
+
 Comments, bug reports
 =====================
 
diff --git a/extending.md b/extending.md
new file mode 100644
index 00000000..b205b2b7
--- /dev/null
+++ b/extending.md
@@ -0,0 +1,142 @@
+# Extending `smart_open`
+
+This document targets potential contributors to `smart_open`.
+Currently, there are two main directions for extending existing `smart_open` functionality:
+
+1. Add a new transport mechanism
+2. Add a new compression format
+
+The first is by far the more challenging, and also the more welcome.
+
+## New transport mechanisms
+
+Each transport mechanism lives in its own submodule.
+For example, currently we have:
+
+- [smart_open.local_file](smart_open/local_file.py)
+- [smart_open.s3](smart_open/s3.py)
+- [smart_open.ssh](smart_open/ssh.py)
+- ... and others
+
+So, to implement a new transport mechanism, you need to create a new module.
+Your module must expose the following (see [smart_open.http](smart_open/http.py) for the full implementation):
+
+```python
+SCHEME = ...
+"""The name of the mechanism, e.g. s3, ssh, etc.
+
+This is the part that goes before the `://` in a URL, e.g. `s3://`."""
+
+URI_EXAMPLES = ('xxx://foo/bar', 'zzz://baz/boz')
+"""This will appear in the documentation of the `parse_uri` function."""
+
+
+def parse_uri(uri_as_str):
+    """Parse the specified URI into a dict.
+
+    At a bare minimum, the dict must have a `scheme` member.
+    """
+    return dict(scheme=SCHEME, ...)
+
+
+def open_uri(uri_as_str, mode, transport_params):
+    """Return a file-like object pointing to the URI.
+
+    Parameters:
+
+    uri_as_str: str
+        The URI to open
+    mode: str
+        Either "rb" or "wb".  You don't need to implement text modes,
+        `smart_open` does that for you, outside of the transport layer.
+    transport_params: dict
+        Any additional parameters to pass to the `open` function (see below).
+
+    """
+    #
+    # Parse the URI using parse_uri
+    # Consolidate the parsed URI with transport_params, if needed
+    # Pass everything to the open function (see below).
+    #
+    ...
+
+
+def open(..., mode, param1=None, param2=None, paramN=None):
+    """This function does the hard work.
+
+    The keyword parameters are the transport_params from the `open_uri`
+    function.
+
+    """
+    ...
+```
+
+Have a look at the existing mechanisms to see how they work.
+You may define other functions and classes as necessary for your implementation.
+
+Once your module is working, register it in the [smart_open.transport](smart_open/transport.py) submodule.
+The `register_transport()` function updates a mapping from schemes to the modules that implement functionality for them.
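+
+For example, registering a hypothetical `smart_open.myscheme` module might look
+like this (a sketch only -- check [smart_open.transport](smart_open/transport.py)
+for the actual signature before relying on it):
+
+```python
+from smart_open import transport
+
+transport.register_transport('smart_open.myscheme')
+```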
+
+Once you've registered your new transport module, the following will happen automagically:
+
+1. `smart_open` will be able to open any URI supported by your module
+2. The docstring for the `smart_open.open` function will contain a section
+   detailing the parameters for your transport module.
+3. The docstring for the `parse_uri` function will include the schemes and
+   examples supported by your module.
+
+You can confirm the documentation changes by running:
+
+    python -c 'help("smart_open")'
+
+and verify that documentation for your new submodule shows up.
+
+### What's the difference between the `open_uri` and `open` functions?
+
+There are several key differences between the two.
+
+First, the parameters to `open_uri` are the same for _all transports_.
+On the other hand, the parameters to the `open` function can differ from transport to transport.
+
+Second, the responsibilities of the two functions are also different.
+The `open` function opens the remote object.
+The `open_uri` function deals with parsing transport-specific details out of the URI, and then delegates to `open`.
+
+The `open` function contains documentation for transport parameters.
+This documentation gets parsed by the `doctools` module and appears in various docstrings.
+
+Some of these differences are by design; others are a consequence of evolution.
+
+## New compression mechanisms
+
+The compression layer is self-contained in the `smart_open.compression` submodule.
+
+To add support for a new compressor:
+
+- Create a new function to handle your compression format (given an extension)
+- Add your compressor to the registry
+
+For example:
+
+```python
+def _handle_xz(file_obj, mode):
+    import lzma
+    return lzma.LZMAFile(filename=file_obj, mode=mode, format=lzma.FORMAT_XZ)
+
+
+register_compressor('.xz', _handle_xz)
+```
+
+There are many compression formats out there, and supporting all of them is beyond the scope of `smart_open`.
+We want our code's functionality to cover the bare minimum required to satisfy 80% of our users.
+We leave the remaining 20% of users with the ability to deal with compression in their own code, using the trivial mechanism described above.
+A quick end-to-end check of a new compressor appears at the end of this document.
+
+## Documentation
+
+Once you've contributed your extension, please add it to the documentation so that it is discoverable for other users.
+Some notable files:
+
+- setup.py: See the `description` keyword.  Not all contributions will affect this.
+- README.rst
+- howto.md (if your extension solves a specific problem that doesn't get covered by other documentation)
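+
+Finally, the end-to-end check promised above: once a compressor for the
+hypothetical `.xz` extension is registered, transparent decompression can be
+verified with a few lines (a sketch; `example.xz` is a placeholder file name):
+
+```python
+import smart_open
+
+with smart_open.open('example.xz', 'rb') as fin:
+    print(fin.read(100))  # the first 100 decompressed bytes
+```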
diff --git a/setup.py b/setup.py
index 01e065f5..a33a24b3 100644
--- a/setup.py
+++ b/setup.py
@@ -9,7 +9,6 @@
 
 import io
 import os
-import sys
 
 from setuptools import setup, find_packages
 
@@ -17,14 +16,12 @@ def _get_version():
     curr_dir = os.path.dirname(os.path.abspath(__file__))
     with open(os.path.join(curr_dir, 'smart_open', 'version.py')) as fin:
-        #
-        # __version__ = '1.8.4'
-        #
         line = fin.readline().strip()
         parts = line.split(' ')
+        assert len(parts) == 3
         assert parts[0] == '__version__'
         assert parts[1] == '='
-        return parts[2][1:-1]
+        return parts[2].strip('\'"')
 
 
 #
@@ -59,8 +56,6 @@ def read(fname):
     'boto3',
     'google-cloud-storage',
 ]
-if sys.version_info[0] == 2:
-    install_requires.append('bz2file')
 
 setup(
     name='smart_open',
@@ -100,7 +95,6 @@ def read(fname):
         'Intended Audience :: Developers',
         'License :: OSI Approved :: MIT License',
         'Operating System :: OS Independent',
-        'Programming Language :: Python :: 2.7',
         'Programming Language :: Python :: 3.5',
         'Programming Language :: Python :: 3.6',
         'Programming Language :: Python :: 3.7',
diff --git a/smart_open/__init__.py b/smart_open/__init__.py
index 3d7ed155..4f41fb88 100644
--- a/smart_open/__init__.py
+++ b/smart_open/__init__.py
@@ -16,6 +16,7 @@ The main functions are:
 
 * `open()`, which opens the given file for reading/writing
+* `parse_uri()`
 * `s3_iter_bucket()`, which goes over all keys in an S3 bucket in parallel
 * `register_compressor()`, which registers callbacks for transparent compressor handling
 
@@ -24,9 +25,16 @@
 import logging
 
 from smart_open import version
-from .smart_open_lib import open, smart_open, register_compressor
+from .smart_open_lib import open, parse_uri, smart_open, register_compressor
 from .s3 import iter_bucket as s3_iter_bucket
 
-__all__ = ['open', 'smart_open', 's3_iter_bucket', 'register_compressor']
+
+__all__ = [
+    'open',
+    'parse_uri',
+    'register_compressor',
+    's3_iter_bucket',
+    'smart_open',
+]
 
 __version__ = version.__version__
diff --git a/smart_open/compression.py b/smart_open/compression.py
new file mode 100644
index 00000000..d459761c
--- /dev/null
+++ b/smart_open/compression.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2020 Radim Rehurek
+#
+# This code is distributed under the terms and conditions
+# from the MIT License (MIT).
+#
+"""Implements the compression layer of the ``smart_open`` library."""
+import logging
+import os.path
+
+logger = logging.getLogger(__name__)
+
+
+_COMPRESSOR_REGISTRY = {}
+
+
+def get_supported_extensions():
+    """Return the list of file extensions for which we have registered compressors."""
+    return sorted(_COMPRESSOR_REGISTRY.keys())
+
+
+def register_compressor(ext, callback):
+    """Register a callback for transparently decompressing files with a specific extension.
+
+    Parameters
+    ----------
+    ext: str
+        The extension.  Must include the leading period, e.g. ``.gz``.
+    callback: callable
+        The callback.  It must accept two positional arguments, file_obj and mode.
+        This function will be called when ``smart_open`` is opening a file with
+        the specified extension.
+
+    Examples
+    --------
+
+    Instruct smart_open to use the `lzma` module whenever opening a file
+    with a .xz extension (see README.rst for the complete example showing I/O):
+
+    >>> def _handle_xz(file_obj, mode):
+    ...     import lzma
+    ...     return lzma.LZMAFile(filename=file_obj, mode=mode, format=lzma.FORMAT_XZ)
+    >>>
+    >>> register_compressor('.xz', _handle_xz)
+
+    """
+    if not (ext and ext[0] == '.'):
+        raise ValueError('ext must be a string starting with ., not %r' % ext)
+    if ext in _COMPRESSOR_REGISTRY:
+        logger.warning('overriding existing compression handler for %r', ext)
+    _COMPRESSOR_REGISTRY[ext] = callback
+
+
+def _handle_bz2(file_obj, mode):
+    from bz2 import BZ2File
+    return BZ2File(file_obj, mode)
+
+
+def _handle_gzip(file_obj, mode):
+    import gzip
+    return gzip.GzipFile(fileobj=file_obj, mode=mode)
+
+
+def compression_wrapper(file_obj, mode):
+    """
+    This function will wrap the file_obj with an appropriate
+    [de]compression mechanism based on the extension of the filename.
+
+    file_obj must either be a filehandle object, or a class which behaves
+    like one.  It must have a .name attribute.
+
+    If the filename extension isn't recognized, the function will simply
+    return the original file_obj.
+    """
+
+    try:
+        _, ext = os.path.splitext(file_obj.name)
+    except (AttributeError, TypeError):
+        logger.warning(
+            'unable to transparently decompress %r because it '
+            'seems to lack a string-like .name', file_obj
+        )
+        return file_obj
+
+    if ext in _COMPRESSOR_REGISTRY and mode.endswith('+'):
+        raise ValueError('transparent (de)compression unsupported for mode %r' % mode)
+
+    try:
+        callback = _COMPRESSOR_REGISTRY[ext]
+    except KeyError:
+        return file_obj
+    else:
+        return callback(file_obj, mode)
+
+
+#
+# NB. avoid using lambda here to make stack traces more readable.
+#
+register_compressor('.bz2', _handle_bz2)
+register_compressor('.gz', _handle_gzip)
diff --git a/smart_open/concurrency.py b/smart_open/concurrency.py
new file mode 100644
index 00000000..4e72aec7
--- /dev/null
+++ b/smart_open/concurrency.py
@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2020 Radim Rehurek
+#
+# This code is distributed under the terms and conditions
+# from the MIT License (MIT).
+#
+
+"""Common functionality for concurrent processing.
+
+The main entry point is :func:`create_pool`.
+"""
+
+import contextlib
+import logging
+import warnings
+
+logger = logging.getLogger(__name__)
+
+# AWS Lambda environments do not support multiprocessing.Queue or multiprocessing.Pool.
+# However they do support Threads and therefore concurrent.futures's ThreadPoolExecutor.
+# We use this flag to degrade gracefully on the rare runtimes where
+# concurrent.futures is unavailable.
+_CONCURRENT_FUTURES = False
+try:
+    import concurrent.futures
+    _CONCURRENT_FUTURES = True
+except ImportError:
+    warnings.warn("concurrent.futures could not be imported and won't be used")
+
+# Multiprocessing is unavailable in App Engine (and possibly other sandboxes).
+# The only method currently relying on it is iter_bucket, which is instructed
+# whether to use it by the MULTIPROCESSING flag.
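+#
+# Roughly how this module gets consumed (create_pool is defined below); this
+# is approximately what s3.iter_bucket does:
+#
+#   with create_pool(processes=workers) as pool:
+#       for key, content in pool.imap_unordered(download_key, key_iterator):
+#           ...
+#
+# create_pool falls back from multiprocessing to threads to a serial
+# DummyPool, depending on what the runtime supports.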
+_MULTIPROCESSING = False +try: + import multiprocessing.pool + _MULTIPROCESSING = True +except ImportError: + warnings.warn("multiprocessing could not be imported and won't be used") + + +class DummyPool(object): + """A class that mimics multiprocessing.pool.Pool for our purposes.""" + def imap_unordered(self, function, items): + return map(function, items) + + def terminate(self): + pass + + +class ConcurrentFuturesPool(object): + """A class that mimics multiprocessing.pool.Pool but uses concurrent futures instead of processes.""" + def __init__(self, max_workers): + self.executor = concurrent.futures.ThreadPoolExecutor(max_workers) + + def imap_unordered(self, function, items): + futures = [self.executor.submit(function, item) for item in items] + for future in concurrent.futures.as_completed(futures): + yield future.result() + + def terminate(self): + self.executor.shutdown(wait=True) + + +@contextlib.contextmanager +def create_pool(processes=1): + if _MULTIPROCESSING and processes: + logger.info("creating multiprocessing pool with %i workers", processes) + pool = multiprocessing.pool.Pool(processes=processes) + elif _CONCURRENT_FUTURES and processes: + logger.info("creating concurrent futures pool with %i workers", processes) + pool = ConcurrentFuturesPool(max_workers=processes) + else: + logger.info("creating dummy pool") + pool = DummyPool() + yield pool + pool.terminate() diff --git a/smart_open/constants.py b/smart_open/constants.py new file mode 100644 index 00000000..1ffa14e3 --- /dev/null +++ b/smart_open/constants.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2020 Radim Rehurek +# +# This code is distributed under the terms and conditions +# from the MIT License (MIT). +# + +"""Some universal constants that are common to I/O operations.""" + + +READ_BINARY = 'rb' + +WRITE_BINARY = 'wb' + +BINARY_MODES = (READ_BINARY, WRITE_BINARY) + +BINARY_NEWLINE = b'\n' + +WHENCE_START = 0 + +WHENCE_CURRENT = 1 + +WHENCE_END = 2 + +WHENCE_CHOICES = (WHENCE_START, WHENCE_CURRENT, WHENCE_END) diff --git a/smart_open/doctools.py b/smart_open/doctools.py index dd5d7490..bf9c3c7e 100644 --- a/smart_open/doctools.py +++ b/smart_open/doctools.py @@ -15,6 +15,10 @@ import io import os.path import re +import warnings + +from . import compression +from . import transport def extract_kwargs(docstring): @@ -73,8 +77,12 @@ def extract_kwargs(docstring): # 1. Find the underlined 'Parameters' section # 2. Once there, continue parsing parameters until we hit an empty line # - while lines[0] != 'Parameters': + while lines and lines[0] != 'Parameters': lines.pop(0) + + if not lines: + return [] + lines.pop(0) lines.pop(0) @@ -156,3 +164,56 @@ def extract_examples_from_readme_rst(indent=' '): return ''.join([indent + re.sub('^ ', '', l) for l in lines]) except Exception: return indent + 'See README.rst' + + +def tweak_docstrings(open_function, parse_uri_function): + # + # The docstring can be None if -OO was passed to the interpreter. 
+    #
+    if not (open_function.__doc__ and parse_uri_function.__doc__):
+        warnings.warn(
+            'docstrings for the smart_open functions are missing, '
+            'see https://github.com/RaRe-Technologies/smart_open'
+            '/blob/master/README.rst if you need documentation'
+        )
+        return
+
+    substrings = {}
+    schemes = io.StringIO()
+    seen_examples = set()
+    uri_examples = io.StringIO()
+
+    for scheme, submodule in sorted(transport._REGISTRY.items()):
+        if scheme == transport.NO_SCHEME:
+            continue
+
+        schemes.write('    * %s\n' % scheme)
+
+        try:
+            fn = submodule.open
+        except AttributeError:
+            substrings[scheme] = ''
+        else:
+            kwargs = extract_kwargs(fn.__doc__)
+            substrings[scheme] = to_docstring(kwargs, lpad=u'    ')
+
+        try:
+            examples = submodule.URI_EXAMPLES
+        except AttributeError:
+            continue
+        else:
+            for e in examples:
+                if e not in seen_examples:
+                    uri_examples.write('    * %s\n' % e)
+                    seen_examples.add(e)
+
+    substrings['codecs'] = '\n'.join(
+        ['    * %s' % e for e in compression.get_supported_extensions()]
+    )
+    substrings['examples'] = extract_examples_from_readme_rst()
+
+    open_function.__doc__ = open_function.__doc__ % substrings
+    parse_uri_function.__doc__ = parse_uri_function.__doc__ % dict(
+        schemes=schemes.getvalue(),
+        uri_examples=uri_examples.getvalue(),
+    )
diff --git a/smart_open/gcs.py b/smart_open/gcs.py
index c0c76cbc..9583fc7c 100644
--- a/smart_open/gcs.py
+++ b/smart_open/gcs.py
@@ -9,35 +9,25 @@
 
 import io
 import logging
-import sys
+import urllib.parse
 
 import google.cloud.exceptions
 import google.cloud.storage
 import google.auth.transport.requests as google_requests
-import six
 
 import smart_open.bytebuffer
-import smart_open.s3
+import smart_open.utils
 
-logger = logging.getLogger(__name__)
-
-_READ_BINARY = 'rb'
-_WRITE_BINARY = 'wb'
+from smart_open import constants
 
-_MODES = (_READ_BINARY, _WRITE_BINARY)
-"""Allowed I/O modes for working with GCS."""
+logger = logging.getLogger(__name__)
 
-_BINARY_TYPES = (six.binary_type, bytearray)
+_BINARY_TYPES = (bytes, bytearray, memoryview)
 """Allowed binary buffer types for writing to the underlying GCS stream"""
 
-if sys.version_info >= (2, 7):
-    _BINARY_TYPES = (six.binary_type, bytearray, memoryview)
-
-_BINARY_NEWLINE = b'\n'
-
 _UNKNOWN_FILE_SIZE = '*'
 
-SUPPORTED_SCHEME = "gs"
+SCHEME = "gs"
 """Supported scheme for GCS"""
 
 _MIN_MIN_PART_SIZE = _REQUIRED_CHUNK_MULTIPLE = 256 * 1024
@@ -49,22 +39,14 @@
 DEFAULT_BUFFER_SIZE = 256 * 1024
 """Default buffer size for working with GCS"""
 
-START = 0
-"""Seek to the absolute start of a GCS file"""
-
-CURRENT = 1
-"""Seek relative to the current positive of a GCS file"""
-
-END = 2
-"""Seek relative to the end of a GCS file"""
-
-_WHENCE_CHOICES = (START, CURRENT, END)
-
 _UPLOAD_INCOMPLETE_STATUS_CODE = 308
 _UPLOAD_COMPLETE_STATUS_CODES = (200, 201)
 
 
 def _make_range_string(start, stop=None, end=_UNKNOWN_FILE_SIZE):
+    #
+    # GCS seems to violate RFC-2616 (see utils.make_range_string), so we
+    # need a separate implementation.
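+    #
+    # Roughly (hypothetical values): utils.make_range_string(10, 19) yields
+    # the RFC-2616 Range form 'bytes=10-19', whereas GCS resumable uploads
+    # expect a Content-Range-style value such as 'bytes 10-19/100', with
+    # '*' standing in for a total size that is not yet known.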
# # https://cloud.google.com/storage/docs/xml-api/resumable-upload#step_3upload_the_file_blocks # @@ -104,6 +86,20 @@ def from_response(cls, response, part_num, content_length, total_size, headers): return cls(msg, response.status_code, response.text) +def parse_uri(uri_as_string): + sr = urllib.parse.urlsplit(uri_as_string) + assert sr.scheme == SCHEME + bucket_id = sr.netloc + blob_id = sr.path.lstrip('/') + return dict(scheme=SCHEME, bucket_id=bucket_id, blob_id=blob_id) + + +def open_uri(uri, mode, transport_params): + parsed_uri = parse_uri(uri) + kwargs = smart_open.utils.check_kwargs(open, transport_params) + return open(parsed_uri['bucket_id'], parsed_uri['blob_id'], mode, **kwargs) + + def open( bucket_id, blob_id, @@ -130,15 +126,15 @@ def open( The GCS client to use when working with google-cloud-storage. """ - if mode == _READ_BINARY: + if mode == constants.READ_BINARY: return SeekableBufferedInputBase( bucket_id, blob_id, buffer_size=buffer_size, - line_terminator=_BINARY_NEWLINE, + line_terminator=constants.BINARY_NEWLINE, client=client, ) - elif mode == _WRITE_BINARY: + elif mode == constants.WRITE_BINARY: return BufferedOutputBase( bucket_id, blob_id, @@ -204,7 +200,7 @@ def __init__( bucket, key, buffer_size=DEFAULT_BUFFER_SIZE, - line_terminator=_BINARY_NEWLINE, + line_terminator=constants.BINARY_NEWLINE, client=None, # type: google.cloud.storage.Client ): if client is None: @@ -256,7 +252,7 @@ def detach(self): """Unsupported.""" raise io.UnsupportedOperation - def seek(self, offset, whence=START): + def seek(self, offset, whence=constants.WHENCE_START): """Seek to the specified position. :param int offset: The offset in bytes. @@ -264,16 +260,16 @@ def seek(self, offset, whence=START): Returns the position after seeking.""" logger.debug('seeking to offset: %r whence: %r', offset, whence) - if whence not in _WHENCE_CHOICES: - raise ValueError('invalid whence, expected one of %r' % _WHENCE_CHOICES) + if whence not in constants.WHENCE_CHOICES: + raise ValueError('invalid whence, expected one of %r' % constants.WHENCE_CHOICES) - if whence == START: + if whence == constants.WHENCE_START: new_position = offset - elif whence == CURRENT: + elif whence == constants.WHENCE_CURRENT: new_position = self._current_pos + offset else: new_position = self._size + offset - new_position = smart_open.s3.clamp(new_position, 0, self._size) + new_position = smart_open.utils.clamp(new_position, 0, self._size) self._current_pos = new_position self._raw_reader.seek(new_position) logger.debug('current_pos: %r', self._current_pos) diff --git a/smart_open/hdfs.py b/smart_open/hdfs.py index 2485685f..a4d892cd 100644 --- a/smart_open/hdfs.py +++ b/smart_open/hdfs.py @@ -17,9 +17,40 @@ import io import logging import subprocess +import urllib.parse + +from smart_open import utils logger = logging.getLogger(__name__) +SCHEME = 'hdfs' + +URI_EXAMPLES = ( + 'hdfs:///path/file', + 'hdfs://path/file', +) + + +def parse_uri(uri_as_string): + split_uri = urllib.parse.urlsplit(uri_as_string) + assert split_uri.scheme == SCHEME + + uri_path = split_uri.netloc + split_uri.path + uri_path = "/" + uri_path.lstrip("/") + if not uri_path: + raise RuntimeError("invalid HDFS URI: %r" % uri_as_string) + + return dict(scheme=SCHEME, uri_path=uri_path) + + +def open_uri(uri, mode, transport_params): + utils.check_kwargs(open, transport_params) + + parsed_uri = parse_uri(uri) + fobj = open(parsed_uri['uri_path'], mode) + fobj.name = parsed_uri['uri_path'].split('/')[-1] + return fobj + def open(uri, mode): if mode == 
'rb': diff --git a/smart_open/http.py b/smart_open/http.py index 7530a942..975ec262 100644 --- a/smart_open/http.py +++ b/smart_open/http.py @@ -9,12 +9,16 @@ import io import logging +import os.path +import urllib.parse import requests -from smart_open import bytebuffer, s3 +from smart_open import bytebuffer, constants +import smart_open.utils DEFAULT_BUFFER_SIZE = 128 * 1024 +SCHEMES = ('http', 'https') logger = logging.getLogger(__name__) @@ -28,6 +32,20 @@ """ +def parse_uri(uri_as_string): + split_uri = urllib.parse.urlsplit(uri_as_string) + assert split_uri.scheme in SCHEMES + + uri_path = split_uri.netloc + split_uri.path + uri_path = "/" + uri_path.lstrip("/") + return dict(scheme=split_uri.scheme, uri_path=uri_path) + + +def open_uri(uri, mode, transport_params): + kwargs = smart_open.utils.check_kwargs(open, transport_params) + return open(uri, mode, **kwargs) + + def open(uri, mode, kerberos=False, user=None, password=None, headers=None): """Implement streamed reader from a web site. @@ -56,11 +74,13 @@ def open(uri, mode, kerberos=False, user=None, password=None, headers=None): unauthenticated, unless set separately in headers. """ - if mode == 'rb': - return SeekableBufferedInputBase( + if mode == constants.READ_BINARY: + fobj = SeekableBufferedInputBase( uri, mode, kerberos=kerberos, user=user, password=password, headers=headers ) + fobj.name = os.path.basename(urllib.parse.urlparse(uri).path) + return fobj else: raise NotImplementedError('http support for mode %r not implemented' % mode) @@ -233,20 +253,20 @@ def seek(self, offset, whence=0): Returns the position after seeking.""" logger.debug('seeking to offset: %r whence: %r', offset, whence) - if whence not in s3.WHENCE_CHOICES: - raise ValueError('invalid whence, expected one of %r' % s3.WHENCE_CHOICES) + if whence not in constants.WHENCE_CHOICES: + raise ValueError('invalid whence, expected one of %r' % constants.WHENCE_CHOICES) if not self.seekable(): raise OSError - if whence == s3.START: + if whence == constants.WHENCE_START: new_pos = offset - elif whence == s3.CURRENT: + elif whence == constants.WHENCE_CURRENT: new_pos = self._current_pos + offset - elif whence == s3.END: + elif whence == constants.WHENCE_END: new_pos = self.content_length + offset - new_pos = s3.clamp(new_pos, 0, self.content_length) + new_pos = smart_open.utils.clamp(new_pos, 0, self.content_length) if self._current_pos == new_pos: return self._current_pos @@ -282,7 +302,7 @@ def truncate(self, size=None): def _partial_request(self, start_pos=None): if start_pos is not None: - self.headers.update({"range": s3.make_range_string(start_pos)}) + self.headers.update({"range": smart_open.utils.make_range_string(start_pos)}) response = requests.get(self.url, auth=self.auth, stream=True, headers=self.headers) return response diff --git a/smart_open/local_file.py b/smart_open/local_file.py new file mode 100644 index 00000000..e5f5c5aa --- /dev/null +++ b/smart_open/local_file.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2020 Radim Rehurek +# +# This code is distributed under the terms and conditions +# from the MIT License (MIT). 
+# +"""Implements the transport for the file:// schema.""" +import io +import os.path + +SCHEME = 'file' + +URI_EXAMPLES = ( + './local/path/file', + '~/local/path/file', + 'local/path/file', + './local/path/file.gz', + 'file:///home/user/file', + 'file:///home/user/file.bz2', +) + + +open = io.open + + +def parse_uri(uri_as_string): + local_path = extract_local_path(uri_as_string) + return dict(scheme=SCHEME, uri_path=local_path) + + +def open_uri(uri_as_string, mode, transport_params): + parsed_uri = parse_uri(uri_as_string) + fobj = io.open(parsed_uri['uri_path'], mode) + return fobj + + +def extract_local_path(uri_as_string): + if uri_as_string.startswith('file://'): + local_path = uri_as_string.replace('file://', '', 1) + else: + local_path = uri_as_string + return os.path.expanduser(local_path) diff --git a/smart_open/s3.py b/smart_open/s3.py index 51f31eb0..faa6ed20 100644 --- a/smart_open/s3.py +++ b/smart_open/s3.py @@ -8,77 +8,184 @@ """Implements file-like objects for reading and writing from/to S3.""" import io -import contextlib import functools import logging import time -import warnings +import urllib.parse +import boto import boto3 import botocore.client import botocore.exceptions -import six import smart_open.bytebuffer +import smart_open.concurrency +import smart_open.utils -logger = logging.getLogger(__name__) - -# AWS Lambda environments do not support multiprocessing.Queue or multiprocessing.Pool. -# However they do support Threads and therefore concurrent.futures's ThreadPoolExecutor. -# We use this flag to allow python 2 backward compatibility, where concurrent.futures doesn't exist. -_CONCURRENT_FUTURES = False -try: - import concurrent.futures - _CONCURRENT_FUTURES = True -except ImportError: - warnings.warn("concurrent.futures could not be imported and won't be used") - -# Multiprocessing is unavailable in App Engine (and possibly other sandboxes). -# The only method currently relying on it is iter_bucket, which is instructed -# whether to use it by the MULTIPROCESSING flag. -_MULTIPROCESSING = False -try: - import multiprocessing.pool - _MULTIPROCESSING = True -except ImportError: - warnings.warn("multiprocessing could not be imported and won't be used") +from smart_open import constants +logger = logging.getLogger(__name__) DEFAULT_MIN_PART_SIZE = 50 * 1024**2 """Default minimum part size for S3 multipart uploads""" MIN_MIN_PART_SIZE = 5 * 1024 ** 2 """The absolute minimum permitted by Amazon.""" -READ_BINARY = 'rb' -WRITE_BINARY = 'wb' -MODES = (READ_BINARY, WRITE_BINARY) -"""Allowed I/O modes for working with S3.""" - -BINARY_NEWLINE = b'\n' -SUPPORTED_SCHEMES = ("s3", "s3n", 's3u', "s3a") +SCHEMES = ("s3", "s3n", 's3u', "s3a") +DEFAULT_PORT = 443 +DEFAULT_HOST = 's3.amazonaws.com' DEFAULT_BUFFER_SIZE = 128 * 1024 -START = 0 -CURRENT = 1 -END = 2 -WHENCE_CHOICES = [START, CURRENT, END] +URI_EXAMPLES = ( + 's3://my_bucket/my_key', + 's3://my_key:my_secret@my_bucket/my_key', + 's3://my_key:my_secret@my_server:my_port@my_bucket/my_key', +) _UPLOAD_ATTEMPTS = 6 _SLEEP_SECONDS = 10 -def clamp(value, minval, maxval): - return max(min(value, maxval), minval) +def _safe_urlsplit(url): + """This is a hack to prevent the regular urlsplit from splitting around question marks. + + A question mark (?) in a URL typically indicates the start of a + querystring, and the standard library's urlparse function handles the + querystring separately. Unfortunately, question marks can also appear + _inside_ the actual URL for some schemas like S3. 
+
+    Replaces question marks with newlines prior to splitting.  This is safe because:
+
+    1. The standard library's urlsplit completely ignores newlines
+    2. Raw newlines will never occur in innocuous URLs.  They are always URL-encoded.
+
+    See Also
+    --------
+    https://github.com/python/cpython/blob/3.7/Lib/urllib/parse.py
+    https://github.com/RaRe-Technologies/smart_open/issues/285
+    """
+    sr = urllib.parse.urlsplit(url.replace('?', '\n'), allow_fragments=False)
+    return urllib.parse.SplitResult(sr.scheme, sr.netloc, sr.path.replace('\n', '?'), '', '')
+
+
+def parse_uri(uri_as_string):
+    #
+    # Restrictions on bucket names and labels:
+    #
+    # - Bucket names must be at least 3 and no more than 63 characters long.
+    # - Bucket names must be a series of one or more labels.
+    # - Adjacent labels are separated by a single period (.).
+    # - Bucket names can contain lowercase letters, numbers, and hyphens.
+    # - Each label must start and end with a lowercase letter or a number.
+    #
+    # We use the above as a guide only, and do not perform any validation.  We
+    # let boto3 take care of that for us.
+    #
+    split_uri = _safe_urlsplit(uri_as_string)
+    assert split_uri.scheme in SCHEMES
+
+    port = DEFAULT_PORT
+    host = boto.config.get('s3', 'host', DEFAULT_HOST)
+    ordinary_calling_format = False
+    #
+    # These defaults tell boto3 to look for credentials elsewhere
+    #
+    access_id, access_secret = None, None
+
+    #
+    # Common URI template [secret:key@][host[:port]@]bucket/object
+    #
+    # The urlparse function doesn't handle the above schema, so we have to do
+    # it ourselves.
+    #
+    uri = split_uri.netloc + split_uri.path
+
+    if '@' in uri and ':' in uri.split('@')[0]:
+        auth, uri = uri.split('@', 1)
+        access_id, access_secret = auth.split(':')
+
+    head, key_id = uri.split('/', 1)
+    if '@' in head and ':' in head:
+        ordinary_calling_format = True
+        host_port, bucket_id = head.split('@')
+        host, port = host_port.split(':', 1)
+        port = int(port)
+    elif '@' in head:
+        ordinary_calling_format = True
+        host, bucket_id = head.split('@')
+    else:
+        bucket_id = head
+
+    return dict(
+        scheme=split_uri.scheme,
+        bucket_id=bucket_id,
+        key_id=key_id,
+        port=port,
+        host=host,
+        ordinary_calling_format=ordinary_calling_format,
+        access_id=access_id,
+        access_secret=access_secret,
+    )
+
+
+def _consolidate_params(uri, transport_params):
+    """Consolidates the parsed Uri with the additional parameters.
+
+    This is necessary because the user can pass some of the parameters in
+    two different ways:
+
+    1) Via the URI itself
+    2) Via the transport parameters
+
+    These are not mutually exclusive, but we have to pick one over the other
+    in a sensible way in order to proceed.
+
+    """
+    transport_params = dict(transport_params)
+
+    session = transport_params.get('session')
+    if session is not None and (uri['access_id'] or uri['access_secret']):
+        logger.warning(
+            'ignoring credentials parsed from URL because they conflict with '
+            'transport_params.session. Set transport_params.session to None '
+            'to suppress this warning.'
+ ) + uri.update(access_id=None, access_secret=None) + elif (uri['access_id'] and uri['access_secret']): + transport_params['session'] = boto3.Session( + aws_access_key_id=uri['access_id'], + aws_secret_access_key=uri['access_secret'], + ) + uri.update(access_id=None, access_secret=None) + + if uri['host'] != DEFAULT_HOST: + endpoint_url = 'https://%(host)s:%(port)d' % uri + _override_endpoint_url(transport_params, endpoint_url) + + return uri, transport_params + + +def _override_endpoint_url(transport_params, url): + try: + resource_kwargs = transport_params['resource_kwargs'] + except KeyError: + resource_kwargs = transport_params['resource_kwargs'] = {} + + if resource_kwargs.get('endpoint_url'): + logger.warning( + 'ignoring endpoint_url parsed from URL because it conflicts ' + 'with transport_params.resource_kwargs.endpoint_url. ' + ) + else: + resource_kwargs.update(endpoint_url=url) + + +def open_uri(uri, mode, transport_params): + parsed_uri = parse_uri(uri) + parsed_uri, transport_params = _consolidate_params(parsed_uri, transport_params) + kwargs = smart_open.utils.check_kwargs(open, transport_params) + return open(parsed_uri['bucket_id'], parsed_uri['key_id'], mode, **kwargs) def open( @@ -135,13 +242,13 @@ def open( """ logger.debug('%r', locals()) - if mode not in MODES: - raise NotImplementedError('bad mode: %r expected one of %r' % (mode, MODES)) + if mode not in constants.BINARY_MODES: + raise NotImplementedError('bad mode: %r expected one of %r' % (mode, constants.BINARY_MODES)) - if (mode == WRITE_BINARY) and (version_id is not None): + if (mode == constants.WRITE_BINARY) and (version_id is not None): raise ValueError("version_id must be None when writing") - if mode == READ_BINARY: + if mode == constants.READ_BINARY: fileobj = Reader( bucket_id, key_id, @@ -151,7 +258,7 @@ def open( resource_kwargs=resource_kwargs, object_kwargs=object_kwargs, ) - elif mode == WRITE_BINARY: + elif mode == constants.WRITE_BINARY: if multipart_upload: fileobj = MultipartWriter( bucket_id, @@ -171,6 +278,8 @@ def open( ) else: assert False, 'unexpected mode: %r' % mode + + fileobj.name = key_id return fileobj @@ -218,7 +327,7 @@ def seek(self, position): def _load_body(self): """Build a continuous connection with the remote peer starts from the current postion. """ - range_string = make_range_string(self._position) + range_string = smart_open.utils.make_range_string(self._position) logger.debug('content_length: %r range_string: %r', self._content_length, range_string) if self._position == self._content_length == 0 or self._position == self._content_length: @@ -266,7 +375,7 @@ class Reader(io.BufferedIOBase): Implements the io.BufferedIOBase interface of the standard library.""" def __init__(self, bucket, key, version_id=None, buffer_size=DEFAULT_BUFFER_SIZE, - line_terminator=BINARY_NEWLINE, session=None, resource_kwargs=None, + line_terminator=constants.BINARY_NEWLINE, session=None, resource_kwargs=None, object_kwargs=None): self._buffer_size = buffer_size @@ -384,7 +493,7 @@ def seekable(self): We offer only seek support, and no truncate support.""" return True - def seek(self, offset, whence=START): + def seek(self, offset, whence=constants.WHENCE_START): """Seek to the specified position. :param int offset: The offset in bytes. 
@@ -392,16 +501,16 @@ def seek(self, offset, whence=START): Returns the position after seeking.""" logger.debug('seeking to offset: %r whence: %r', offset, whence) - if whence not in WHENCE_CHOICES: - raise ValueError('invalid whence, expected one of %r' % WHENCE_CHOICES) + if whence not in constants.WHENCE_CHOICES: + raise ValueError('invalid whence, expected one of %r' % constants.WHENCE_CHOICES) - if whence == START: + if whence == constants.WHENCE_START: new_position = offset - elif whence == CURRENT: + elif whence == constants.WHENCE_CURRENT: new_position = self._current_pos + offset else: new_position = self._content_length + offset - new_position = clamp(new_position, 0, self._content_length) + new_position = smart_open.utils.clamp(new_position, 0, self._content_length) self._current_pos = new_position self._raw_reader.seek(new_position) logger.debug('new_position: %r', self._current_pos) @@ -911,7 +1020,7 @@ def iter_bucket( retries=retries, **session_kwargs) - with _create_process_pool(processes=workers) as pool: + with smart_open.concurrency.create_pool(processes=workers) as pool: result_iterator = pool.imap_unordered(download_key, key_iterator) for key_no, (key, content) in enumerate(result_iterator): if True or key_no % 1000 == 0: @@ -994,41 +1103,3 @@ def _download_fileobj(bucket, key_name): buf = io.BytesIO() bucket.download_fileobj(key_name, buf) return buf.getvalue() - - -class DummyPool(object): - """A class that mimics multiprocessing.pool.Pool for our purposes.""" - def imap_unordered(self, function, items): - return six.moves.map(function, items) - - def terminate(self): - pass - - -class ConcurrentFuturesPool(object): - """A class that mimics multiprocessing.pool.Pool but uses concurrent futures instead of processes.""" - def __init__(self, max_workers): - self.executor = concurrent.futures.ThreadPoolExecutor(max_workers) - - def imap_unordered(self, function, items): - futures = [self.executor.submit(function, item) for item in items] - for future in concurrent.futures.as_completed(futures): - yield future.result() - - def terminate(self): - self.executor.shutdown(wait=True) - - -@contextlib.contextmanager -def _create_process_pool(processes=1): - if _MULTIPROCESSING and processes: - logger.info("creating multiprocessing pool with %i workers", processes) - pool = multiprocessing.pool.Pool(processes=processes) - elif _CONCURRENT_FUTURES and processes: - logger.info("creating concurrent futures pool with %i workers", processes) - pool = ConcurrentFuturesPool(max_workers=processes) - else: - logger.info("creating dummy pool") - pool = DummyPool() - yield pool - pool.terminate() diff --git a/smart_open/smart_open_lib.py b/smart_open/smart_open_lib.py index 741cf5bd..7b2a7bd6 100644 --- a/smart_open/smart_open_lib.py +++ b/smart_open/smart_open_lib.py @@ -10,190 +10,108 @@ The main functions are: - * `open()` - * `register_compressor()` + * ``parse_uri()`` + * ``open()`` """ import codecs import collections import logging -import io -import importlib -import inspect import os import os.path as P +import pathlib +import urllib.parse import warnings import sys import boto3 -import six - -from six.moves.urllib import parse as urlparse # # This module defines a function called smart_open so we cannot use # smart_open.submodule to reference to the submodules. 
# -import smart_open.s3 as smart_open_s3 -import smart_open.hdfs as smart_open_hdfs -import smart_open.webhdfs as smart_open_webhdfs -import smart_open.http as smart_open_http -import smart_open.ssh as smart_open_ssh -import smart_open.gcs as smart_open_gcs +import smart_open.local_file as so_file +from smart_open import compression from smart_open import doctools +from smart_open import transport +from smart_open import utils -# Import ``pathlib`` if the builtin ``pathlib`` or the backport ``pathlib2`` are -# available. The builtin ``pathlib`` will be imported with higher precedence. -for pathlib_module in ('pathlib', 'pathlib2'): - try: - pathlib = importlib.import_module(pathlib_module) - PATHLIB_SUPPORT = True - break - except ImportError: - PATHLIB_SUPPORT = False +# +# For backwards compatibility and keeping old unit tests happy. +# +from smart_open.compression import register_compressor # noqa: F401 +from smart_open.utils import check_kwargs as _check_kwargs # noqa: F401 +from smart_open.utils import inspect_kwargs as _inspect_kwargs # noqa: F401 logger = logging.getLogger(__name__) SYSTEM_ENCODING = sys.getdefaultencoding() -_ISSUE_189_URL = 'https://github.com/RaRe-Technologies/smart_open/issues/189' +_TO_BINARY_LUT = { + 'r': 'rb', 'r+': 'rb+', 'rt': 'rb', 'rt+': 'rb+', + 'w': 'wb', 'w+': 'wb+', 'wt': 'wb', "wt+": 'wb+', + 'a': 'ab', 'a+': 'ab+', 'at': 'ab', 'at+': 'ab+', +} -_DEFAULT_S3_HOST = 's3.amazonaws.com' -_COMPRESSOR_REGISTRY = {} +def _sniff_scheme(uri_as_string): + """Returns the scheme of the URL only, as a string.""" + # + # urlsplit doesn't work on Windows -- it parses the drive as the scheme... + # no protocol given => assume a local file + # + if os.name == 'nt' and '://' not in uri_as_string: + uri_as_string = 'file://' + uri_as_string + + return urllib.parse.urlsplit(uri_as_string).scheme -def register_compressor(ext, callback): - """Register a callback for transparently decompressing files with a specific extension. +def parse_uri(uri_as_string): + """ + Parse the given URI from a string. Parameters ---------- - ext: str - The extension. - callback: callable - The callback. It must accept two position arguments, file_obj and mode. - - Examples - -------- - - Instruct smart_open to use the identity function whenever opening a file - with a .xz extension (see README.rst for the complete example showing I/O): - - >>> def _handle_xz(file_obj, mode): - ... import lzma - ... return lzma.LZMAFile(filename=file_obj, mode=mode, format=lzma.FORMAT_XZ) - >>> - >>> register_compressor('.xz', _handle_xz) - - """ - if not (ext and ext[0] == '.'): - raise ValueError('ext must be a string starting with ., not %r' % ext) - if ext in _COMPRESSOR_REGISTRY: - logger.warning('overriding existing compression handler for %r', ext) - _COMPRESSOR_REGISTRY[ext] = callback + uri_as_string: str + The URI to parse. + Returns + ------- + collections.namedtuple + The parsed URI. -def _handle_bz2(file_obj, mode): - if six.PY2: - from bz2file import BZ2File - else: - from bz2 import BZ2File - return BZ2File(file_obj, mode) + Notes + ----- + Supported URI schemes are: -def _handle_gzip(file_obj, mode): - import gzip - return gzip.GzipFile(fileobj=file_obj, mode=mode) +%(schemes)s + s3, s3a and s3n are treated the same way. s3u is s3 but without SSL. + Valid URI examples:: -# -# NB. avoid using lambda here to make stack traces more readable. 
-# -register_compressor('.bz2', _handle_bz2) -register_compressor('.gz', _handle_gzip) - - -Uri = collections.namedtuple( - 'Uri', - ( - 'scheme', - 'uri_path', - 'bucket_id', - 'key_id', - 'blob_id', - 'port', - 'host', - 'ordinary_calling_format', - 'access_id', - 'access_secret', - 'user', - 'password', - ) -) -"""Represents all the options that we parse from user input. +%(uri_examples)s -Some of the above options only make sense for certain protocols, e.g. -bucket_id is only for S3 and GCS. -""" -# -# Set the default values for all Uri fields to be None. This allows us to only -# specify the relevant fields when constructing a Uri. -# -# https://stackoverflow.com/questions/11351032/namedtuple-and-default-values-for-optional-keyword-arguments -# -Uri.__new__.__defaults__ = (None,) * len(Uri._fields) + """ + scheme = _sniff_scheme(uri_as_string) + submodule = transport.get_transport(scheme) + as_dict = submodule.parse_uri(uri_as_string) -def _inspect_kwargs(kallable): # - # inspect.getargspec got deprecated in Py3.4, and calling it spews - # deprecation warnings that we'd prefer to avoid. Unfortunately, older - # versions of Python (<3.3) did not have inspect.signature, so we need to - # handle them the old-fashioned getargspec way. + # The conversion to a namedtuple is just to keep the old tests happy while + # I'm still refactoring. # - try: - signature = inspect.signature(kallable) - except AttributeError: - args, varargs, keywords, defaults = inspect.getargspec(kallable) - if not defaults: - return {} - supported_keywords = args[-len(defaults):] - return dict(zip(supported_keywords, defaults)) - else: - return { - name: param.default - for name, param in signature.parameters.items() - if param.default != inspect.Parameter.empty - } - + Uri = collections.namedtuple('Uri', sorted(as_dict.keys())) + return Uri(**as_dict) -def _check_kwargs(kallable, kwargs): - """Check which keyword arguments the callable supports. - - Parameters - ---------- - kallable: callable - A function or method to test - kwargs: dict - The keyword arguments to check. If the callable doesn't support any - of these, a warning message will get printed. - - Returns - ------- - dict - A dictionary of argument names and values supported by the callable. - """ - supported_keywords = sorted(_inspect_kwargs(kallable)) - unsupported_keywords = [k for k in sorted(kwargs) if k not in supported_keywords] - supported_kwargs = {k: v for (k, v) in kwargs.items() if k in supported_keywords} - - if unsupported_keywords: - logger.warning('ignoring unsupported keyword arguments: %r', unsupported_keywords) - - return supported_kwargs +# +# To keep old unit tests happy while I'm refactoring. +# +_parse_uri = parse_uri _builtin_open = open @@ -212,13 +130,8 @@ def open( ): r"""Open the URI object, returning a file-like object. - The URI is usually a string in a variety of formats: - - 1. a URI for the local filesystem: `./lines.txt`, `/home/joe/lines.txt.gz`, - `file:///home/joe/lines.txt.bz2` - 2. a URI for HDFS: `hdfs:///some/path/lines.txt` - 3. a URI for Amazon's S3 (can also supply credentials inside the URI): - `s3://my_bucket/lines.txt`, `s3://my_aws_key_id:key_secret@my_bucket/lines.txt` + The URI is usually a string in a variety of formats. + For a full list of examples, see the :func:`parse_uri` function. 
The URI may also be one of: @@ -226,10 +139,9 @@ def open( - a stream (anything that implements io.IOBase-like functionality) This function supports transparent compression and decompression using the - following codec: + following codecs: - - ``.gz`` - - ``.bz2`` +%(codecs)s The function depends on the file extension to determine the appropriate codec. @@ -294,7 +206,7 @@ def open( """ logger.debug('%r', locals()) - if not isinstance(mode, six.string_types): + if not isinstance(mode, str): raise TypeError('mode should be a string') if transport_params is None: @@ -322,8 +234,7 @@ def open( if encoding is not None and 'b' in mode: mode = mode.replace('b', '') - # Support opening ``pathlib.Path`` objects by casting them to strings. - if PATHLIB_SUPPORT and isinstance(uri, pathlib.Path): + if isinstance(uri, pathlib.Path): uri = str(uri) explicit_encoding = encoding @@ -341,20 +252,12 @@ def open( # filename ---------------> bytes -------------> bytes ---------> text # binary decompressed decode # - try: - binary_mode = {'r': 'rb', 'r+': 'rb+', - 'rt': 'rb', 'rt+': 'rb+', - 'w': 'wb', 'w+': 'wb+', - 'wt': 'wb', "wt+": 'wb+', - 'a': 'ab', 'a+': 'ab+', - 'at': 'ab', 'at+': 'ab+'}[mode] - except KeyError: - binary_mode = mode - binary, filename = _open_binary_stream(uri, binary_mode, transport_params) + binary_mode = _TO_BINARY_LUT.get(mode, mode) + binary = _open_binary_stream(uri, binary_mode, transport_params) if ignore_ext: decompressed = binary else: - decompressed = _compression_wrapper(binary, filename, mode) + decompressed = compression.compression_wrapper(binary, mode) if 'b' not in mode or explicit_encoding is not None: decoded = _encoding_wrapper(decompressed, mode, encoding=encoding, errors=errors) @@ -364,34 +267,6 @@ def open( return decoded -# -# The docstring can be None if -OO was passed to the interpreter. 
-# -open.__doc__ = None if open.__doc__ is None else open.__doc__ % { - 's3': doctools.to_docstring( - doctools.extract_kwargs(smart_open_s3.open.__doc__), - lpad=u' ', - ), - 'http': doctools.to_docstring( - doctools.extract_kwargs(smart_open_http.open.__doc__), - lpad=u' ', - ), - 'webhdfs': doctools.to_docstring( - doctools.extract_kwargs(smart_open_webhdfs.open.__doc__), - lpad=u' ', - ), - 'ssh': doctools.to_docstring( - doctools.extract_kwargs(smart_open_ssh.open.__doc__), - lpad=u' ', - ), - 'gcs': doctools.to_docstring( - doctools.extract_kwargs(smart_open_gcs.open.__doc__), - lpad=u' ', - ), - 'examples': doctools.extract_examples_from_readme_rst(), -} - - _MIGRATION_NOTES_URL = ( 'https://github.com/RaRe-Technologies/smart_open/blob/master/README.rst' '#migrating-to-the-new-open-function' @@ -415,7 +290,7 @@ def smart_open(uri, mode="rb", **kw): # ignore_extension = kw.pop('ignore_extension', False) - expected_kwargs = _inspect_kwargs(open) + expected_kwargs = utils.inspect_kwargs(open) scrubbed_kwargs = {} transport_params = {} @@ -492,15 +367,16 @@ def _shortcut_open( :returns: The opened file :rtype: file """ - if not isinstance(uri, six.string_types): + if not isinstance(uri, str): return None - parsed_uri = _parse_uri(uri) - if parsed_uri.scheme != 'file': + scheme = _sniff_scheme(uri) + if scheme not in (transport.NO_SCHEME, so_file.SCHEME): return None - _, extension = P.splitext(parsed_uri.uri_path) - if extension in _COMPRESSOR_REGISTRY and not ignore_ext: + local_path = so_file.extract_local_path(uri) + _, extension = P.splitext(local_path) + if extension in compression.get_supported_extensions() and not ignore_ext: return None open_kwargs = {} @@ -515,17 +391,7 @@ def _shortcut_open( if errors and 'b' not in mode: open_kwargs['errors'] = errors - # - # Under Py3, the built-in open accepts kwargs, and it's OK to use that. - # Under Py2, the built-in open _doesn't_ accept kwargs, but we still use it - # whenever possible (see issue #207). If we're under Py2 and have to use - # kwargs, then we have no option other to use io.open. - # - if six.PY3: - return _builtin_open(parsed_uri.uri_path, mode, buffering=buffering, **open_kwargs) - elif not open_kwargs: - return _builtin_open(parsed_uri.uri_path, mode, buffering=buffering) - return io.open(parsed_uri.uri_path, mode, buffering=buffering, **open_kwargs) + return _builtin_open(local_path, mode, buffering=buffering, **open_kwargs) def _open_binary_stream(uri, mode, transport_params): @@ -536,8 +402,8 @@ def _open_binary_stream(uri, mode, transport_params): :arg uri: The URI to open. May be a string, or something else. :arg str mode: The mode to open with. Must be rb, wb or ab. :arg transport_params: Keyword argumens for the transport layer. 
- :returns: A file object and the filename - :rtype: tuple + :returns: A named file object + :rtype: file-like object with a .name attribute """ if mode not in ('rb', 'rb+', 'wb', 'wb+', 'ab', 'ab+'): # @@ -546,351 +412,28 @@ def _open_binary_stream(uri, mode, transport_params): # raise NotImplementedError('unsupported mode: %r' % mode) - if isinstance(uri, six.string_types): - # this method just routes the request to classes handling the specific storage - # schemes, depending on the URI protocol in `uri` - filename = uri.split('/')[-1] - parsed_uri = _parse_uri(uri) - - if parsed_uri.scheme == "file": - fobj = io.open(parsed_uri.uri_path, mode) - return fobj, filename - elif parsed_uri.scheme in smart_open_ssh.SCHEMES: - fobj = smart_open_ssh.open( - parsed_uri.uri_path, - mode, - host=parsed_uri.host, - user=parsed_uri.user, - port=parsed_uri.port, - password=parsed_uri.password, - transport_params=transport_params, - ) - return fobj, filename - elif parsed_uri.scheme in smart_open_s3.SUPPORTED_SCHEMES: - return _s3_open_uri(parsed_uri, mode, transport_params), filename - elif parsed_uri.scheme == "hdfs": - _check_kwargs(smart_open_hdfs.open, transport_params) - return smart_open_hdfs.open(parsed_uri.uri_path, mode), filename - elif parsed_uri.scheme == "webhdfs": - kw = _check_kwargs(smart_open_webhdfs.open, transport_params) - http_uri = smart_open_webhdfs.convert_to_http_uri(parsed_uri) - return smart_open_webhdfs.open(http_uri, mode, **kw), filename - elif parsed_uri.scheme.startswith('http'): - # - # The URI may contain a query string and fragments, which interfere - # with our compressed/uncompressed estimation, so we strip them. - # - filename = P.basename(urlparse.urlparse(uri).path) - kw = _check_kwargs(smart_open_http.open, transport_params) - return smart_open_http.open(uri, mode, **kw), filename - elif parsed_uri.scheme == smart_open_gcs.SUPPORTED_SCHEME: - kw = _check_kwargs(smart_open_gcs.open, transport_params) - return smart_open_gcs.open(parsed_uri.bucket_id, parsed_uri.blob_id, mode, **kw), filename - else: - raise NotImplementedError("scheme %r is not supported", parsed_uri.scheme) - elif hasattr(uri, 'read'): + if hasattr(uri, 'read'): # simply pass-through if already a file-like # we need to return something as the file name, but we don't know what # so we probe for uri.name (e.g., this works with open() or tempfile.NamedTemporaryFile) - # if the value ends with COMPRESSED_EXT, we will note it in _compression_wrapper() + # if the value ends with COMPRESSED_EXT, we will note it in compression_wrapper() # if there is no such an attribute, we return "unknown" - this # effectively disables any compression - filename = getattr(uri, 'name', 'unknown') - return uri, filename - else: - raise TypeError("don't know how to handle uri %r" % uri) - - -def _s3_open_uri(uri, mode, transport_params): - logger.debug('s3_open_uri: %r', locals()) - if mode in ('r', 'w'): - raise ValueError('this function can only open binary streams. ' - 'Use smart_open.smart_open() to open text streams.') - elif mode not in ('rb', 'wb'): - raise NotImplementedError('unsupported mode: %r', mode) - - # - # There are two explicit ways we can receive session parameters from the user. - # - # 1. Via the session keyword argument (transport_params) - # 2. Via the URI itself - # - # They are not mutually exclusive, but we have to pick one of the two. - # Go with 1). 
- # - if transport_params.get('session') is not None and (uri.access_id or uri.access_secret): - logger.warning( - 'ignoring credentials parsed from URL because they conflict with ' - 'transport_params.session. Set transport_params.session to None ' - 'to suppress this warning.' - ) - elif (uri.access_id and uri.access_secret): - transport_params['session'] = boto3.Session( - aws_access_key_id=uri.access_id, - aws_secret_access_key=uri.access_secret, - ) - - # - # There are two explicit ways the user can provide the endpoint URI: - # - # 1. Via the URL. The protocol is implicit, and we assume HTTPS in this case. - # 2. Via the resource_kwargs and multipart_upload_kwargs endpoint_url parameter. - # - # Again, these are not mutually exclusive: the user can specify both. We - # have to pick one to proceed, however, and we go with 2. - # - if uri.host != _DEFAULT_S3_HOST: - endpoint_url = 'https://%s:%d' % (uri.host, uri.port) - _override_endpoint_url(transport_params, endpoint_url) - - kwargs = _check_kwargs(smart_open_s3.open, transport_params) - return smart_open_s3.open(uri.bucket_id, uri.key_id, mode, **kwargs) + if not hasattr(uri, 'name'): + uri.name = getattr(uri, 'name', 'unknown') + return uri + if not isinstance(uri, str): + raise TypeError("don't know how to handle uri %r" % uri) -def _override_endpoint_url(tp, url): - try: - resource_kwargs = tp['resource_kwargs'] - except KeyError: - resource_kwargs = tp['resource_kwargs'] = {} - - if resource_kwargs.get('endpoint_url'): - logger.warning( - 'ignoring endpoint_url parsed from URL because it conflicts ' - 'with transport_params.resource_kwargs.endpoint_url. ' - ) - else: - resource_kwargs.update(endpoint_url=url) - - -def _my_urlsplit(url): - """This is a hack to prevent the regular urlsplit from splitting around question marks. - - A question mark (?) in a URL typically indicates the start of a - querystring, and the standard library's urlparse function handles the - querystring separately. Unfortunately, question marks can also appear - _inside_ the actual URL for some schemas like S3. - - Replaces question marks with newlines prior to splitting. This is safe because: - - 1. The standard library's urlsplit completely ignores newlines - 2. Raw newlines will never occur in innocuous URLs. They are always URL-encoded. - - See Also - -------- - https://github.com/python/cpython/blob/3.7/Lib/urllib/parse.py - https://github.com/RaRe-Technologies/smart_open/issues/285 - """ - parsed_url = urlparse.urlsplit(url, allow_fragments=False) - if parsed_url.scheme not in smart_open_s3.SUPPORTED_SCHEMES or '?' not in url: - return parsed_url - - sr = urlparse.urlsplit(url.replace('?', '\n'), allow_fragments=False) - return urlparse.SplitResult(sr.scheme, sr.netloc, sr.path.replace('\n', '?'), '', '') - - -def _parse_uri(uri_as_string): - """ - Parse the given URI from a string. - - Supported URI schemes are: - - * file - * gs - * hdfs - * http - * https - * s3 - * s3a - * s3n - * s3u - * webhdfs - - .s3, s3a and s3n are treated the same way. s3u is s3 but without SSL. 
- - Valid URI examples:: - - * s3://my_bucket/my_key - * s3://my_key:my_secret@my_bucket/my_key - * s3://my_key:my_secret@my_server:my_port@my_bucket/my_key - * hdfs:///path/file - * hdfs://path/file - * webhdfs://host:port/path/file - * ./local/path/file - * ~/local/path/file - * local/path/file - * ./local/path/file.gz - * file:///home/user/file - * file:///home/user/file.bz2 - * [ssh|scp|sftp]://username@host//path/file - * [ssh|scp|sftp]://username@host/path/file - * gs://my_bucket/my_blob - - """ - if os.name == 'nt': - # urlsplit doesn't work on Windows -- it parses the drive as the scheme... - if '://' not in uri_as_string: - # no protocol given => assume a local file - uri_as_string = 'file://' + uri_as_string - - parsed_uri = _my_urlsplit(uri_as_string) - - if parsed_uri.scheme == "hdfs": - return _parse_uri_hdfs(parsed_uri) - elif parsed_uri.scheme == "webhdfs": - return parsed_uri - elif parsed_uri.scheme in smart_open_s3.SUPPORTED_SCHEMES: - return _parse_uri_s3x(parsed_uri) - elif parsed_uri.scheme == 'file': - return _parse_uri_file(parsed_uri.netloc + parsed_uri.path) - elif parsed_uri.scheme in ('', None): - return _parse_uri_file(uri_as_string) - elif parsed_uri.scheme.startswith('http'): - return Uri(scheme=parsed_uri.scheme, uri_path=uri_as_string) - elif parsed_uri.scheme in smart_open_ssh.SCHEMES: - return _parse_uri_ssh(parsed_uri) - elif parsed_uri.scheme == smart_open_gcs.SUPPORTED_SCHEME: - return _parse_uri_gcs(parsed_uri) - else: - raise NotImplementedError( - "unknown URI scheme %r in %r" % (parsed_uri.scheme, uri_as_string) - ) - - -def _parse_uri_hdfs(parsed_uri): - assert parsed_uri.scheme == 'hdfs' - uri_path = parsed_uri.netloc + parsed_uri.path - uri_path = "/" + uri_path.lstrip("/") - if not uri_path: - raise RuntimeError("invalid HDFS URI: %s" % str(parsed_uri)) - - return Uri(scheme='hdfs', uri_path=uri_path) - - -def _parse_uri_s3x(parsed_uri): - # - # Restrictions on bucket names and labels: - # - # - Bucket names must be at least 3 and no more than 63 characters long. - # - Bucket names must be a series of one or more labels. - # - Adjacent labels are separated by a single period (.). - # - Bucket names can contain lowercase letters, numbers, and hyphens. - # - Each label must start and end with a lowercase letter or a number. - # - # We use the above as a guide only, and do not perform any validation. We - # let boto3 take care of that for us. - # - assert parsed_uri.scheme in smart_open_s3.SUPPORTED_SCHEMES - - port = 443 - host = _DEFAULT_S3_HOST - ordinary_calling_format = False - # - # These defaults tell boto3 to look for credentials elsewhere - # - access_id, access_secret = None, None - - # - # Common URI template [secret:key@][host[:port]@]bucket/object - # - # The urlparse function doesn't handle the above schema, so we have to do - # it ourselves. 
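Concretely, the hand-rolled parse below splits on `@`, `:` and `/` in three steps. Running the same logic on a fully-populated example (illustration only; like the original, it performs no validation):

```python
# netloc + path for 's3://my_key:my_secret@my_server:1234@my_bucket/my_object'
uri = 'my_key:my_secret@my_server:1234@my_bucket/my_object'

auth, rest = uri.split('@', 1)               # credentials, if present
access_id, access_secret = auth.split(':')   # ('my_key', 'my_secret')

head, key_id = rest.split('/', 1)            # key_id == 'my_object'
host_port, bucket_id = head.split('@')       # bucket_id == 'my_bucket'
host, port = host_port.split(':', 1)         # ('my_server', '1234')

print(access_id, access_secret, host, int(port), bucket_id, key_id)
```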
-    #
-    uri = parsed_uri.netloc + parsed_uri.path
-
-    if '@' in uri and ':' in uri.split('@')[0]:
-        auth, uri = uri.split('@', 1)
-        access_id, access_secret = auth.split(':')
-
-    head, key_id = uri.split('/', 1)
-    if '@' in head and ':' in head:
-        ordinary_calling_format = True
-        host_port, bucket_id = head.split('@')
-        host, port = host_port.split(':', 1)
-        port = int(port)
-    elif '@' in head:
-        ordinary_calling_format = True
-        host, bucket_id = head.split('@')
-    else:
-        bucket_id = head
-
-    return Uri(
-        scheme=parsed_uri.scheme, bucket_id=bucket_id, key_id=key_id,
-        port=port, host=host, ordinary_calling_format=ordinary_calling_format,
-        access_id=access_id, access_secret=access_secret
-    )
-
-
-def _parse_uri_file(input_path):
-    # '~/tmp' may be expanded to '/Users/username/tmp'
-    uri_path = os.path.expanduser(input_path)
-
-    if not uri_path:
-        raise RuntimeError("invalid file URI: %s" % input_path)
-
-    return Uri(scheme='file', uri_path=uri_path)
-
-
-def _parse_uri_ssh(unt):
-    """Parse a Uri from a urllib namedtuple."""
-    return Uri(
-        scheme=unt.scheme,
-        uri_path=_unquote(unt.path),
-        user=_unquote(unt.username),
-        host=unt.hostname,
-        port=int(unt.port or smart_open_ssh.DEFAULT_PORT),
-        password=_unquote(unt.password),
-    )
-
-
-def _unquote(text):
-    return text and urlparse.unquote(text)
-
-
-def _parse_uri_gcs(parsed_uri):
-    assert parsed_uri.scheme == smart_open_gcs.SUPPORTED_SCHEME
-    bucket_id, blob_id = parsed_uri.netloc, parsed_uri.path[1:]
-
-    return Uri(scheme=parsed_uri.scheme, bucket_id=bucket_id, blob_id=blob_id)
-
-
-def _need_to_buffer(file_obj, mode, ext):
-    """Returns True if we need to buffer the whole file in memory in order to proceed."""
-    try:
-        is_seekable = file_obj.seekable()
-    except AttributeError:
-        #
-        # Under Py2, built-in file objects returned by open do not have
-        # .seekable, but have a .seek method instead.
-        #
-        is_seekable = hasattr(file_obj, 'seek')
-    return six.PY2 and mode.startswith('r') and ext in _COMPRESSOR_REGISTRY and not is_seekable
-
-
-def _compression_wrapper(file_obj, filename, mode):
-    """
-    This function will wrap the file_obj with an appropriate
-    [de]compression mechanism based on the extension of the filename.
-
-    file_obj must either be a filehandle object, or a class which behaves
-    like one.
+    scheme = _sniff_scheme(uri)
+    submodule = transport.get_transport(scheme)
+    fobj = submodule.open_uri(uri, mode, transport_params)
+    if not hasattr(fobj, 'name'):
+        logger.warning('could not determine the name of %r; assuming "unknown"', fobj)
+        fobj.name = 'unknown'

-    If the filename extension isn't recognized, will simply return the original
-    file_obj.
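Stepping back: every string URI now takes the same three-step path shown above — sniff the scheme, look up the transport submodule in the registry, delegate to its `open_uri`. The sniffing step amounts to something like this simplified stand-in (the real `_sniff_scheme` presumably also special-cases Windows drive letters, as the deleted `_parse_uri` did):

```python
import urllib.parse


def sniff_scheme(uri_as_str):
    """Simplified stand-in: return the scheme part of a URI, '' for local paths."""
    return urllib.parse.urlsplit(uri_as_str).scheme


assert sniff_scheme('s3://bucket/key') == 's3'
assert sniff_scheme('webhdfs://namenode:50070/path/file') == 'webhdfs'
assert sniff_scheme('/local/path/file.txt') == ''  # NO_SCHEME -> local_file transport
```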
- """ - _, ext = os.path.splitext(filename) - - if _need_to_buffer(file_obj, mode, ext): - warnings.warn('streaming gzip support unavailable, see %s' % _ISSUE_189_URL) - file_obj = io.BytesIO(file_obj.read()) - if ext in _COMPRESSOR_REGISTRY and mode.endswith('+'): - raise ValueError('transparent (de)compression unsupported for mode %r' % mode) - - try: - callback = _COMPRESSOR_REGISTRY[ext] - except KeyError: - return file_obj - else: - return callback(file_obj, mode) + return fobj def _encoding_wrapper(fileobj, mode, encoding=None, errors=None): @@ -945,10 +488,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): def _patch_pathlib(func): """Replace `Path.open` with `func`""" - if not PATHLIB_SUPPORT: - raise RuntimeError('install pathlib (or pathlib2) before using this function') - if six.PY2: - raise RuntimeError('this monkey patch does not work on Py2') old_impl = pathlib.Path.open pathlib.Path.open = func return old_impl + + +doctools.tweak_docstrings(open, parse_uri) diff --git a/smart_open/ssh.py b/smart_open/ssh.py index 6290f839..75f2c8fa 100644 --- a/smart_open/ssh.py +++ b/smart_open/ssh.py @@ -24,8 +24,11 @@ import getpass import logging +import urllib.parse import warnings +import smart_open.utils + logger = logging.getLogger(__name__) # @@ -38,6 +41,38 @@ DEFAULT_PORT = 22 +URI_EXAMPLES = ( + 'ssh://username@host/path/file', + 'ssh://username@host//path/file', + 'scp://username@host/path/file', + 'sftp://username@host/path/file', +) + + +def _unquote(text): + return text and urllib.parse.unquote(text) + + +def parse_uri(uri_as_string): + split_uri = urllib.parse.urlsplit(uri_as_string) + assert split_uri.scheme in SCHEMES + return dict( + scheme=split_uri.scheme, + uri_path=_unquote(split_uri.path), + user=_unquote(split_uri.username), + host=split_uri.hostname, + port=int(split_uri.port or DEFAULT_PORT), + password=_unquote(split_uri.password), + ) + + +def open_uri(uri, mode, transport_params): + smart_open.utils.check_kwargs(open, transport_params) + parsed_uri = parse_uri(uri) + uri_path = parsed_uri.pop('uri_path') + parsed_uri.pop('scheme') + return open(uri_path, mode, transport_params=transport_params, **parsed_uri) + def _connect(hostname, username, port, password, transport_params): try: @@ -106,4 +141,6 @@ def open(path, mode='r', host=None, user=None, password=None, port=DEFAULT_PORT, conn = _connect(host, user, port, password, transport_params) sftp_client = conn.get_transport().open_sftp_client() - return sftp_client.open(path, mode) + fobj = sftp_client.open(path, mode) + fobj.name = path + return fobj diff --git a/smart_open/tests/test_bytebuffer.py b/smart_open/tests/test_bytebuffer.py index 7b0344d5..42aa1687 100644 --- a/smart_open/tests/test_bytebuffer.py +++ b/smart_open/tests/test_bytebuffer.py @@ -9,23 +9,25 @@ import random import unittest -import six - import smart_open.bytebuffer CHUNK_SIZE = 1024 +def int2byte(i): + return bytes((i, )) + + def random_byte_string(length=CHUNK_SIZE): - rand_bytes = [six.int2byte(random.randint(0, 255)) for _ in range(length)] + rand_bytes = [int2byte(random.randint(0, 255)) for _ in range(length)] return b''.join(rand_bytes) def bytebuffer_and_random_contents(): buf = smart_open.bytebuffer.ByteBuffer(CHUNK_SIZE) contents = random_byte_string(CHUNK_SIZE) - content_reader = six.BytesIO(contents) + content_reader = io.BytesIO(contents) buf.fill(content_reader) return [buf, contents] @@ -47,7 +49,7 @@ def test_len(self): def test_fill_from_reader(self): buf = smart_open.bytebuffer.ByteBuffer(CHUNK_SIZE) contents = 
random_byte_string(CHUNK_SIZE) - content_reader = six.BytesIO(contents) + content_reader = io.BytesIO(contents) bytes_filled = buf.fill(content_reader) self.assertEqual(bytes_filled, CHUNK_SIZE) @@ -77,7 +79,7 @@ def test_fill_from_list(self): def test_fill_multiple(self): buf = smart_open.bytebuffer.ByteBuffer(CHUNK_SIZE) long_contents = random_byte_string(CHUNK_SIZE * 4) - long_content_reader = six.BytesIO(long_contents) + long_content_reader = io.BytesIO(long_contents) first_bytes_filled = buf.fill(long_content_reader) self.assertEqual(first_bytes_filled, CHUNK_SIZE) @@ -89,7 +91,7 @@ def test_fill_multiple(self): def test_fill_size(self): buf = smart_open.bytebuffer.ByteBuffer(CHUNK_SIZE) contents = random_byte_string(CHUNK_SIZE * 2) - content_reader = six.BytesIO(contents) + content_reader = io.BytesIO(contents) fill_size = int(CHUNK_SIZE / 2) bytes_filled = buf.fill(content_reader, size=fill_size) @@ -105,7 +107,7 @@ def test_fill_reader_exhaustion(self): buf = smart_open.bytebuffer.ByteBuffer(CHUNK_SIZE) short_content_size = int(CHUNK_SIZE / 4) short_contents = random_byte_string(short_content_size) - short_content_reader = six.BytesIO(short_contents) + short_content_reader = io.BytesIO(short_contents) bytes_filled = buf.fill(short_content_reader) self.assertEqual(bytes_filled, short_content_size) diff --git a/smart_open/tests/test_gcs.py b/smart_open/tests/test_gcs.py index c955f482..3fcd5a07 100644 --- a/smart_open/tests/test_gcs.py +++ b/smart_open/tests/test_gcs.py @@ -22,9 +22,9 @@ import google.cloud import google.api_core.exceptions -import six import smart_open +import smart_open.constants BUCKET_NAME = 'test-smartopen-{}'.format(uuid.uuid4().hex) BLOB_NAME = 'test-blob' @@ -42,8 +42,6 @@ def ignore_resource_warnings(): - if six.PY2: - return warnings.filterwarnings("ignore", category=ResourceWarning, message="unclosed.*") # noqa @@ -174,8 +172,8 @@ def exists(self, client=None): def upload_from_string(self, data): # mimics Google's API by accepting bytes or str, despite the method name # https://google-cloud-python.readthedocs.io/en/0.32.0/storage/blobs.html#google.cloud.storage.blob.Blob.upload_from_string - if isinstance(data, six.string_types): - data = bytes(data) if six.PY2 else bytes(data, 'utf8') + if isinstance(data, str): + data = bytes(data, 'utf8') self.__contents = io.BytesIO(data) self.__contents.seek(0, io.SEEK_END) @@ -448,18 +446,30 @@ def mock_gcs(class_or_func): def mock_gcs_func(func): """Mock the function and provide additional required arguments.""" + assert callable(func), '%r is not a callable function' % func + def inner(*args, **kwargs): - with mock.patch('google.cloud.storage.Client', return_value=storage_client), \ - mock.patch( - 'smart_open.gcs.google_requests.AuthorizedSession', - return_value=FakeAuthorizedSession(storage_client._credentials), - ): - assert callable(func), 'you didn\'t provide a function!' - try: # is it a method that needs a self arg? - self_arg = inspect.signature(func).self - func(self_arg, *args, **kwargs) - except AttributeError: - func(*args, **kwargs) + # + # Is it a function or a method? The latter requires a self parameter. 
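One caveat worth noting here: `inspect.Signature` objects do not ordinarily expose a `self` attribute, so the `hasattr(signature, 'self')` probe below is expected to take the plain-function branch; the conventional way to tell the two cases apart is to look for `self` among the parameters. Under stock CPython:

```python
import inspect


class Reader:
    def read(self, size):
        return b'x' * size


def read(size):
    return b'x' * size


# Plain function: no 'self' parameter.
assert 'self' not in inspect.signature(read).parameters

# Function accessed through its class (not yet bound): 'self' is a parameter.
assert 'self' in inspect.signature(Reader.read).parameters

# Bound method: 'self' is already baked in and disappears from the signature.
assert 'self' not in inspect.signature(Reader().read).parameters
```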
+ # + signature = inspect.signature(func) + + fake_session = FakeAuthorizedSession(storage_client._credentials) + patched_client = mock.patch( + 'google.cloud.storage.Client', + return_value=storage_client, + ) + patched_session = mock.patch( + 'smart_open.gcs.google_requests.AuthorizedSession', + return_value=fake_session, + ) + + with patched_client, patched_session: + if not hasattr(signature, 'self'): + return func(*args, **kwargs) + else: + return func(signature.self, *args, **kwargs) + return inner @@ -564,7 +574,7 @@ def test_seek_current(self): fin = smart_open.gcs.SeekableBufferedInputBase(BUCKET_NAME, BLOB_NAME) self.assertEqual(fin.read(5), b'hello') - seek = fin.seek(1, whence=smart_open.gcs.CURRENT) + seek = fin.seek(1, whence=smart_open.constants.WHENCE_CURRENT) self.assertEqual(seek, 6) self.assertEqual(fin.read(6), u'wořld'.encode('utf-8')) @@ -574,7 +584,7 @@ def test_seek_end(self): put_to_bucket(contents=content) fin = smart_open.gcs.SeekableBufferedInputBase(BUCKET_NAME, BLOB_NAME) - seek = fin.seek(-4, whence=smart_open.gcs.END) + seek = fin.seek(-4, whence=smart_open.constants.WHENCE_END) self.assertEqual(seek, len(content) - 4) self.assertEqual(fin.read(), b'you?') @@ -586,7 +596,7 @@ def test_detect_eof(self): fin.read() eof = fin.tell() self.assertEqual(eof, len(content)) - fin.seek(0, whence=smart_open.gcs.END) + fin.seek(0, whence=smart_open.constants.WHENCE_END) self.assertEqual(eof, fin.tell()) def test_read_gzip(self): @@ -680,11 +690,12 @@ def test_write_01(self): with smart_open.gcs.BufferedOutputBase(BUCKET_NAME, WRITE_BLOB_NAME) as fout: fout.write(test_string) - output = list(smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME), "rb")) + with smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME), "rb") as fin: + output = list(fin) self.assertEqual(output, [test_string]) - def test_write_01a(self): + def test_incorrect_input(self): """Does gcs write fail on incorrect input?""" try: with smart_open.gcs.BufferedOutputBase(BUCKET_NAME, WRITE_BLOB_NAME) as fin: diff --git a/smart_open/tests/test_hdfs.py b/smart_open/tests/test_hdfs.py index c0fb8aab..a28ab991 100644 --- a/smart_open/tests/test_hdfs.py +++ b/smart_open/tests/test_hdfs.py @@ -14,7 +14,6 @@ import unittest import mock -import six import smart_open.hdfs @@ -56,7 +55,6 @@ def test_read_100(self): expected = 'В начале июля, в чрезвычайно жаркое время' self.assertEqual(expected, as_text) - @unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet') def test_unzip(self): path = P.join(CURR_DIR, 'test_data/crime-and-punishment.txt.gz') cat = subprocess.Popen(['cat', path], stdout=subprocess.PIPE) @@ -93,7 +91,6 @@ def test_write(self): actual = cat.stdout.read().decode('utf-8') self.assertEqual(as_text, actual) - @unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet') def test_zip(self): cat = subprocess.Popen(['cat'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) as_text = 'мы в ответе за тех, кого приручили' diff --git a/smart_open/tests/test_http.py b/smart_open/tests/test_http.py index e61e9dea..3527295d 100644 --- a/smart_open/tests/test_http.py +++ b/smart_open/tests/test_http.py @@ -11,6 +11,7 @@ import smart_open.http import smart_open.s3 +import smart_open.constants BYTES = b'i tried so hard and got so far but in the end it doesn\'t even matter' @@ -75,7 +76,7 @@ def test_seek_from_current(self): self.assertEqual(BYTES[10:20], read_bytes) self.assertEqual(reader.tell(), 20) - reader.seek(10, whence=smart_open.s3.CURRENT) + reader.seek(10, 
whence=smart_open.constants.WHENCE_CURRENT) self.assertEqual(reader.tell(), 30) read_bytes = reader.read(size=10) self.assertEqual(reader.tell(), 40) @@ -86,7 +87,7 @@ def test_seek_from_end(self): responses.add_callback(responses.GET, URL, callback=request_callback) reader = smart_open.http.SeekableBufferedInputBase(URL) - reader.seek(-10, whence=smart_open.s3.END) + reader.seek(-10, whence=smart_open.constants.WHENCE_END) self.assertEqual(reader.tell(), len(BYTES) - 10) read_bytes = reader.read(size=10) self.assertEqual(reader.tell(), len(BYTES)) @@ -144,6 +145,6 @@ def test_https_seek_reverse(self): with smart_open.open(HTTPS_URL, "rb") as fin: read_bytes_1 = fin.read(size=10) - fin.seek(-10, whence=smart_open.s3.CURRENT) + fin.seek(-10, whence=smart_open.constants.WHENCE_CURRENT) read_bytes_2 = fin.read(size=10) self.assertEqual(read_bytes_1, read_bytes_2) diff --git a/smart_open/tests/test_s3.py b/smart_open/tests/test_s3.py index d86c9924..223098a1 100644 --- a/smart_open/tests/test_s3.py +++ b/smart_open/tests/test_s3.py @@ -18,7 +18,6 @@ import botocore.client import mock import moto -import six import smart_open import smart_open.s3 @@ -77,12 +76,13 @@ def ignore_resource_warnings(): # https://github.com/boto/boto3/issues/454 # Py2 doesn't have ResourceWarning, so do nothing. # - if six.PY2: - return warnings.filterwarnings("ignore", category=ResourceWarning, message="unclosed.*") # noqa -@unittest.skipIf(not ENABLE_MOTO_SERVER, 'The test case needs a Moto server running on the local 5000 port.') +@unittest.skipUnless( + ENABLE_MOTO_SERVER, + 'The test case needs a Moto server running on the local 5000 port.' +) class SeekableRawReaderTest(unittest.TestCase): def setUp(self): @@ -180,7 +180,7 @@ def test_seek_current(self): fin = smart_open.s3.SeekableBufferedInputBase(BUCKET_NAME, KEY_NAME) self.assertEqual(fin.read(5), b'hello') - seek = fin.seek(1, whence=smart_open.s3.CURRENT) + seek = fin.seek(1, whence=smart_open.constants.WHENCE_CURRENT) self.assertEqual(seek, 6) self.assertEqual(fin.read(6), u'wořld'.encode('utf-8')) @@ -190,7 +190,7 @@ def test_seek_end(self): put_to_bucket(contents=content) fin = smart_open.s3.SeekableBufferedInputBase(BUCKET_NAME, KEY_NAME) - seek = fin.seek(-4, whence=smart_open.s3.END) + seek = fin.seek(-4, whence=smart_open.constants.WHENCE_END) self.assertEqual(seek, len(content) - 4) self.assertEqual(fin.read(), b'you?') @@ -202,7 +202,7 @@ def test_detect_eof(self): fin.read() eof = fin.tell() self.assertEqual(eof, len(content)) - fin.seek(0, whence=smart_open.s3.END) + fin.seek(0, whence=smart_open.constants.WHENCE_END) self.assertEqual(eof, fin.tell()) def test_read_gzip(self): @@ -522,13 +522,6 @@ def test_flush_close(self): fout.close() -class ClampTest(unittest.TestCase): - def test(self): - self.assertEqual(smart_open.s3.clamp(5, 0, 10), 5) - self.assertEqual(smart_open.s3.clamp(11, 0, 10), 10) - self.assertEqual(smart_open.s3.clamp(-1, 0, 10), 0) - - ARBITRARY_CLIENT_ERROR = botocore.client.ClientError(error_response={}, operation_name='bar') @@ -615,15 +608,15 @@ def test_old(self): @moto.mock_s3 -@unittest.skipIf(not smart_open.s3._CONCURRENT_FUTURES, 'concurrent.futures unavailable') +@unittest.skipIf(not smart_open.concurrency._CONCURRENT_FUTURES, 'concurrent.futures unavailable') class IterBucketConcurrentFuturesTest(unittest.TestCase): def setUp(self): - self.old_flag_multi = smart_open.s3._MULTIPROCESSING - smart_open.s3._MULTIPROCESSING = False + self.old_flag_multi = smart_open.concurrency._MULTIPROCESSING + 
smart_open.concurrency._MULTIPROCESSING = False ignore_resource_warnings() def tearDown(self): - smart_open.s3._MULTIPROCESSING = self.old_flag_multi + smart_open.concurrency._MULTIPROCESSING = self.old_flag_multi cleanup_bucket() def test(self): @@ -637,15 +630,15 @@ def test(self): @moto.mock_s3 -@unittest.skipIf(not smart_open.s3._MULTIPROCESSING, 'multiprocessing unavailable') +@unittest.skipIf(not smart_open.concurrency._MULTIPROCESSING, 'multiprocessing unavailable') class IterBucketMultiprocessingTest(unittest.TestCase): def setUp(self): - self.old_flag_concurrent = smart_open.s3._CONCURRENT_FUTURES - smart_open.s3._CONCURRENT_FUTURES = False + self.old_flag_concurrent = smart_open.concurrency._CONCURRENT_FUTURES + smart_open.concurrency._CONCURRENT_FUTURES = False ignore_resource_warnings() def tearDown(self): - smart_open.s3._CONCURRENT_FUTURES = self.old_flag_concurrent + smart_open.concurrency._CONCURRENT_FUTURES = self.old_flag_concurrent cleanup_bucket() def test(self): @@ -661,16 +654,16 @@ def test(self): @moto.mock_s3 class IterBucketSingleProcessTest(unittest.TestCase): def setUp(self): - self.old_flag_multi = smart_open.s3._MULTIPROCESSING - self.old_flag_concurrent = smart_open.s3._CONCURRENT_FUTURES - smart_open.s3._MULTIPROCESSING = False - smart_open.s3._CONCURRENT_FUTURES = False + self.old_flag_multi = smart_open.concurrency._MULTIPROCESSING + self.old_flag_concurrent = smart_open.concurrency._CONCURRENT_FUTURES + smart_open.concurrency._MULTIPROCESSING = False + smart_open.concurrency._CONCURRENT_FUTURES = False ignore_resource_warnings() def tearDown(self): - smart_open.s3._MULTIPROCESSING = self.old_flag_multi - smart_open.s3._CONCURRENT_FUTURES = self.old_flag_concurrent + smart_open.concurrency._MULTIPROCESSING = self.old_flag_multi + smart_open.concurrency._CONCURRENT_FUTURES = self.old_flag_concurrent cleanup_bucket() def test(self): diff --git a/smart_open/tests/test_smart_open.py b/smart_open/tests/test_smart_open.py index c55be3e1..32e83bf8 100644 --- a/smart_open/tests/test_smart_open.py +++ b/smart_open/tests/test_smart_open.py @@ -19,7 +19,6 @@ from moto import mock_s3 import responses import gzip -import six import smart_open from smart_open import smart_open_lib @@ -288,7 +287,6 @@ def test_gs_uri_contains_slash(self): self.assertEqual(parsed_uri.bucket_id, "mybucket") self.assertEqual(parsed_uri.blob_id, "mydir/myblob") - @unittest.skipUnless(smart_open_lib.six.PY3, "our monkey patch only works on Py3") def test_pathlib_monkeypatch(self): from smart_open.smart_open_lib import pathlib @@ -305,7 +303,6 @@ def test_pathlib_monkeypatch(self): _patch_pathlib(obj.old_impl) assert pathlib.Path.open != smart_open.open - @unittest.skipUnless(smart_open_lib.six.PY3, "our monkey patch only works on Py3") def test_pathlib_monkeypath_read_gz(self): from smart_open.smart_open_lib import pathlib @@ -325,11 +322,6 @@ def test_pathlib_monkeypath_read_gz(self): finally: _patch_pathlib(obj.old_impl) - @unittest.skipUnless(smart_open_lib.six.PY2, 'this test is for Py2 only') - def test_monkey_patch_raises_exception_py2(self): - with self.assertRaises(RuntimeError): - patch_pathlib() - class SmartOpenHttpTest(unittest.TestCase): """ @@ -399,7 +391,6 @@ def _test_compressed_http(self, suffix, query): # decompress the file and get the same md5 hash self.assertEqual(smart_open_object.read(), raw_data) - @unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet') def test_http_gz(self): """Can open gzip via http?""" self._test_compressed_http(".gz", False) @@ 
-408,7 +399,6 @@ def test_http_bz2(self): """Can open bzip2 via http?""" self._test_compressed_http(".bz2", False) - @unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet') def test_http_gz_query(self): """Can open gzip via http with a query appended to URI?""" self._test_compressed_http(".gz", True) @@ -418,7 +408,7 @@ def test_http_bz2_query(self): self._test_compressed_http(".bz2", True) -def make_buffer(cls=six.BytesIO, initial_value=None, name=None, noclose=False): +def make_buffer(cls=io.BytesIO, initial_value=None, name=None, noclose=False): """ Construct a new in-memory file object aka "buf". @@ -431,9 +421,6 @@ def make_buffer(cls=six.BytesIO, initial_value=None, name=None, noclose=False): buf = cls(initial_value) if initial_value else cls() if name is not None: buf.name = name - if six.PY2: - buf.__enter__ = lambda: buf - buf.__exit__ = lambda exc_type, exc_val, exc_tb: None if noclose: buf.close = lambda: None return buf @@ -515,7 +502,6 @@ def test_write_bytes(self): sf.write(SAMPLE_BYTES) self.assertEqual(buf.getvalue(), SAMPLE_BYTES) - @unittest.skipIf(six.PY2, "Python 2 does not differentiate between str and bytes") def test_read_text_stream_fails(self): """Attempts to read directly from a text stream should fail. @@ -523,14 +509,13 @@ def test_read_text_stream_fails(self): If you have a text stream, there's no point passing it to smart_open: you can read from it directly. """ - buf = make_buffer(six.StringIO, initial_value=SAMPLE_TEXT) + buf = make_buffer(io.StringIO, initial_value=SAMPLE_TEXT) with smart_open.smart_open(buf, 'r') as sf: self.assertRaises(TypeError, sf.read) # we expect binary mode - @unittest.skipIf(six.PY2, "Python 2 does not differentiate between str and bytes") def test_write_text_stream_fails(self): """Attempts to write directly to a text stream should fail.""" - buf = make_buffer(six.StringIO) + buf = make_buffer(io.StringIO) with smart_open.smart_open(buf, 'w') as sf: self.assertRaises(TypeError, sf.write, SAMPLE_TEXT) # we expect binary mode @@ -649,7 +634,6 @@ def test_open_with_keywords_explicit_r(self): actual = fin.read() self.assertEqual(expected, actual) - @unittest.skipUnless(smart_open_lib.PATHLIB_SUPPORT, "this test requires pathlib") def test_open_and_read_pathlib_path(self): """If ``pathlib.Path`` is available we should be able to open and read.""" from smart_open.smart_open_lib import pathlib @@ -762,7 +746,7 @@ def test_file(self, mock_smart_open): short_path = "~/tmp/test.txt" full_path = os.path.expanduser(short_path) - @mock.patch(_IO_OPEN if six.PY2 else _BUILTIN_OPEN) + @mock.patch(_BUILTIN_OPEN) def test_file_errors(self, mock_smart_open): prefix = "file://" full_path = '/tmp/test.txt' @@ -1003,8 +987,7 @@ def test_file_mode_mock(self): # def test_text(self): - patch = _IO_OPEN if six.PY2 else _BUILTIN_OPEN - with mock.patch(patch, mock.Mock(return_value=self.stringio)) as mock_open: + with mock.patch(_BUILTIN_OPEN, mock.Mock(return_value=self.stringio)) as mock_open: with smart_open.smart_open("blah", "r", encoding='utf-8') as fin: self.assertEqual(fin.read(), self.as_text) mock_open.assert_called_with("blah", "r", buffering=-1, encoding='utf-8') @@ -1033,22 +1016,19 @@ def test_incorrect(self): def test_write_utf8(self): # correct write mode, correct file:// URI - patch = _IO_OPEN if six.PY2 else _BUILTIN_OPEN - with mock.patch(patch, mock.Mock(return_value=self.stringio)) as mock_open: + with mock.patch(_BUILTIN_OPEN, mock.Mock(return_value=self.stringio)) as mock_open: with smart_open.smart_open("blah", "w", 
encoding='utf-8') as fout: mock_open.assert_called_with("blah", "w", buffering=-1, encoding='utf-8') fout.write(self.as_text) def test_write_utf8_absolute_path(self): - patch = _IO_OPEN if six.PY2 else _BUILTIN_OPEN - with mock.patch(patch, mock.Mock(return_value=self.stringio)) as mock_open: + with mock.patch(_BUILTIN_OPEN, mock.Mock(return_value=self.stringio)) as mock_open: with smart_open.smart_open("/some/file.txt", "w", encoding='utf-8') as fout: mock_open.assert_called_with("/some/file.txt", "w", buffering=-1, encoding='utf-8') fout.write(self.as_text) def test_append_utf8(self): - patch = _IO_OPEN if six.PY2 else _BUILTIN_OPEN - with mock.patch(patch, mock.Mock(return_value=self.stringio)) as mock_open: + with mock.patch(_BUILTIN_OPEN, mock.Mock(return_value=self.stringio)) as mock_open: with smart_open.smart_open("/some/file.txt", "w+", encoding='utf-8') as fout: mock_open.assert_called_with("/some/file.txt", "w+", buffering=-1, encoding='utf-8') fout.write(self.as_text) @@ -1326,29 +1306,13 @@ def cleanup_temp_bz2(self, test_file): os.unlink(test_file) def test_can_read_multistream_bz2(self): - if six.PY2: - # this is a backport from Python 3 - from bz2file import BZ2File - else: - from bz2 import BZ2File + from bz2 import BZ2File test_file = self.create_temp_bz2(streams=5) with BZ2File(test_file) as bz2f: self.assertEqual(bz2f.read(), self.TEXT * 5) self.cleanup_temp_bz2(test_file) - def test_python2_stdlib_bz2_cannot_read_multistream(self): - # Multistream bzip is included in Python 3 - if not six.PY2: - return - import bz2 - - test_file = self.create_temp_bz2(streams=5) - bz2f = bz2.BZ2File(test_file) - self.assertNotEqual(bz2f.read(), self.TEXT * 5) - bz2f.close() - self.cleanup_temp_bz2(test_file) - def test_file_smart_open_can_read_multistream_bz2(self): test_file = self.create_temp_bz2(streams=5) with smart_open_lib.smart_open(test_file) as bz2f: diff --git a/smart_open/tests/test_smart_open_old.py b/smart_open/tests/test_smart_open_old.py index 31fe50f8..7fa8c30e 100644 --- a/smart_open/tests/test_smart_open_old.py +++ b/smart_open/tests/test_smart_open_old.py @@ -25,7 +25,6 @@ from moto import mock_s3 import responses import gzip -import six import smart_open from smart_open import smart_open_lib @@ -70,7 +69,6 @@ def test_http_pass(self): self.assertTrue(actual_request.headers['Authorization'].startswith('Basic ')) @responses.activate - @unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet') def test_http_gz(self): """Can open gzip via http?""" fpath = os.path.join(CURR_DIR, 'test_data/crlf_at_1k_boundary.warc.gz') @@ -88,7 +86,6 @@ def test_http_gz(self): self.assertEqual(m.hexdigest(), expected_hash) @responses.activate - @unittest.skipIf(six.PY2, 'gzip support for Py2 is not implemented yet') def test_http_gz_noquerystring(self): """Can open gzip via http?""" fpath = os.path.join(CURR_DIR, 'test_data/crlf_at_1k_boundary.warc.gz') @@ -171,9 +168,6 @@ def test_open_with_keywords_explicit_r(self): actual = fin.read() self.assertEqual(expected, actual) - @unittest.skipUnless( - smart_open_lib.PATHLIB_SUPPORT, - "do not test pathlib support if pathlib or backport are not available") def test_open_and_read_pathlib_path(self): """If ``pathlib.Path`` is available we should be able to open and read.""" from smart_open.smart_open_lib import pathlib @@ -286,7 +280,7 @@ def test_file(self, mock_smart_open): short_path = "~/tmp/test.txt" full_path = os.path.expanduser(short_path) - @mock.patch(_IO_OPEN if six.PY2 else _BUILTIN_OPEN) + @mock.patch(_BUILTIN_OPEN) 
def test_file_errors(self, mock_smart_open): prefix = "file://" full_path = '/tmp/test.txt' @@ -546,8 +540,7 @@ def test_file_mode_mock(self): # def test_text(self): - patch = _IO_OPEN if six.PY2 else _BUILTIN_OPEN - with mock.patch(patch, mock.Mock(return_value=self.stringio)) as mock_open: + with mock.patch(_BUILTIN_OPEN, mock.Mock(return_value=self.stringio)) as mock_open: with smart_open.smart_open("blah", "r", encoding='utf-8') as fin: self.assertEqual(fin.read(), self.as_text) mock_open.assert_called_with("blah", "r", buffering=-1, encoding='utf-8') @@ -576,22 +569,19 @@ def test_incorrect(self): def test_write_utf8(self): # correct write mode, correct file:// URI - patch = _IO_OPEN if six.PY2 else _BUILTIN_OPEN - with mock.patch(patch, mock.Mock(return_value=self.stringio)) as mock_open: + with mock.patch(_BUILTIN_OPEN, mock.Mock(return_value=self.stringio)) as mock_open: with smart_open.smart_open("blah", "w", encoding='utf-8') as fout: mock_open.assert_called_with("blah", "w", buffering=-1, encoding='utf-8') fout.write(self.as_text) def test_write_utf8_absolute_path(self): - patch = _IO_OPEN if six.PY2 else _BUILTIN_OPEN - with mock.patch(patch, mock.Mock(return_value=self.stringio)) as mock_open: + with mock.patch(_BUILTIN_OPEN, mock.Mock(return_value=self.stringio)) as mock_open: with smart_open.smart_open("/some/file.txt", "w", encoding='utf-8') as fout: mock_open.assert_called_with("/some/file.txt", "w", buffering=-1, encoding='utf-8') fout.write(self.as_text) def test_append_utf8(self): - patch = _IO_OPEN if six.PY2 else _BUILTIN_OPEN - with mock.patch(patch, mock.Mock(return_value=self.stringio)) as mock_open: + with mock.patch(_BUILTIN_OPEN, mock.Mock(return_value=self.stringio)) as mock_open: with smart_open.smart_open("/some/file.txt", "w+", encoding='utf-8') as fout: mock_open.assert_called_with("/some/file.txt", "w+", buffering=-1, encoding='utf-8') fout.write(self.as_text) @@ -973,7 +963,6 @@ def test_rw_gzip(self): with smart_open.smart_open(key, "rb") as fin: self.assertEqual(fin.read().decode("utf-8"), text) - @unittest.skipIf(six.PY2, 'this test does not work with Py2') @mock_s3 def test_gzip_write_mode(self): """Should always open in binary mode when writing through a codec.""" @@ -984,7 +973,6 @@ def test_gzip_write_mode(self): smart_open.smart_open("s3://bucket/key.gz", "wb") mock_open.assert_called_with('bucket', 'key.gz', 'wb') - @unittest.skipIf(six.PY2, 'this test does not work with Py2') @mock_s3 def test_gzip_read_mode(self): """Should always open in binary mode when reading through a codec.""" diff --git a/smart_open/tests/test_utils.py b/smart_open/tests/test_utils.py new file mode 100644 index 00000000..63b463d8 --- /dev/null +++ b/smart_open/tests/test_utils.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# +# Copyright (C) 2019 Radim Rehurek +# +# This code is distributed under the terms and conditions +# from the MIT License (MIT). 
+#
+
+import unittest
+
+import smart_open.utils
+
+
+class ClampTest(unittest.TestCase):
+    def test_low(self):
+        self.assertEqual(smart_open.utils.clamp(5, 0, 10), 5)
+
+    def test_high(self):
+        self.assertEqual(smart_open.utils.clamp(11, 0, 10), 10)
+
+    def test_out_of_range(self):
+        self.assertEqual(smart_open.utils.clamp(-1, 0, 10), 0)
diff --git a/smart_open/transport.py b/smart_open/transport.py new file mode 100644 index 00000000..2d43d862 --- /dev/null +++ b/smart_open/transport.py @@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2020 Radim Rehurek
+#
+# This code is distributed under the terms and conditions
+# from the MIT License (MIT).
+#
+"""Maintains a registry of transport mechanisms.
+
+The main entrypoint is :func:`get_transport`. See also :file:`extending.md`.
+
+"""
+import importlib
+import logging
+
+import smart_open.local_file
+
+logger = logging.getLogger(__name__)
+
+NO_SCHEME = ''
+
+_REGISTRY = {NO_SCHEME: smart_open.local_file}
+
+
+def register_transport(submodule):
+    """Register a submodule as a transport mechanism for ``smart_open``.
+
+    This module **must** have:
+
+    - `SCHEME` attribute (or `SCHEMES`, if the submodule supports multiple schemes)
+    - `open` function
+    - `open_uri` function
+    - `parse_uri` function
+
+    Once registered, you can get the submodule by calling :func:`get_transport`.
+
+    """
+    global _REGISTRY
+    if isinstance(submodule, str):
+        try:
+            submodule = importlib.import_module(submodule)
+        except ImportError:
+            logger.warning('unable to import %r, disabling that module', submodule)
+            return
+
+    if hasattr(submodule, 'SCHEME'):
+        schemes = [submodule.SCHEME]
+    elif hasattr(submodule, 'SCHEMES'):
+        schemes = submodule.SCHEMES
+    else:
+        raise ValueError('%r does not have a .SCHEME or .SCHEMES attribute' % submodule)
+
+    for f in ('open', 'open_uri', 'parse_uri'):
+        assert hasattr(submodule, f), '%r is missing %r' % (submodule, f)
+
+    for scheme in schemes:
+        assert scheme not in _REGISTRY
+        _REGISTRY[scheme] = submodule
+
+
+def get_transport(scheme):
+    """Get the submodule that handles transport for the specified scheme.
+
+    This submodule must have been previously registered via :func:`register_transport`.
+
+    """
+    message = "scheme %r is not supported, expected one of %r" % (scheme, SUPPORTED_SCHEMES)
+
+    try:
+        submodule = _REGISTRY[scheme]
+    except KeyError:
+        raise NotImplementedError(message)
+    else:
+        return submodule
+
+
+register_transport(smart_open.local_file)
+register_transport('smart_open.gcs')
+register_transport('smart_open.hdfs')
+register_transport('smart_open.http')
+register_transport('smart_open.s3')
+register_transport('smart_open.ssh')
+register_transport('smart_open.webhdfs')
+
+SUPPORTED_SCHEMES = tuple(sorted(_REGISTRY.keys()))
+"""The transport schemes that the local installation of ``smart_open`` supports."""
diff --git a/smart_open/utils.py b/smart_open/utils.py new file mode 100644 index 00000000..dfd44ecb --- /dev/null +++ b/smart_open/utils.py @@ -0,0 +1,118 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2020 Radim Rehurek
+#
+# This code is distributed under the terms and conditions
+# from the MIT License (MIT).
+#
+
+"""Helper functions for documentation, etc."""
+
+import inspect
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def inspect_kwargs(kallable):
+    #
+    # inspect.getargspec got deprecated in Py3.4, and calling it spews
+    # deprecation warnings that we'd prefer to avoid.
Unfortunately, older + # versions of Python (<3.3) did not have inspect.signature, so we need to + # handle them the old-fashioned getargspec way. + # + try: + signature = inspect.signature(kallable) + except AttributeError: + try: + args, varargs, keywords, defaults = inspect.getargspec(kallable) + except TypeError: + # + # Happens under Py2.7 with mocking. + # + return {} + + if not defaults: + return {} + supported_keywords = args[-len(defaults):] + return dict(zip(supported_keywords, defaults)) + else: + return { + name: param.default + for name, param in signature.parameters.items() + if param.default != inspect.Parameter.empty + } + + +def check_kwargs(kallable, kwargs): + """Check which keyword arguments the callable supports. + + Parameters + ---------- + kallable: callable + A function or method to test + kwargs: dict + The keyword arguments to check. If the callable doesn't support any + of these, a warning message will get printed. + + Returns + ------- + dict + A dictionary of argument names and values supported by the callable. + """ + supported_keywords = sorted(inspect_kwargs(kallable)) + unsupported_keywords = [k for k in sorted(kwargs) if k not in supported_keywords] + supported_kwargs = {k: v for (k, v) in kwargs.items() if k in supported_keywords} + + if unsupported_keywords: + logger.warning('ignoring unsupported keyword arguments: %r', unsupported_keywords) + + return supported_kwargs + + +def clamp(value, minval, maxval): + """Clamp a numeric value to a specific range. + + Parameters + ---------- + value: numeric + The value to clamp. + + minval: numeric + The lower bound. + + maxval: numeric + The upper bound. + + Returns + ------- + numeric + The clamped value. It will be in the range ``[minval, maxval]``. + + """ + return max(min(value, maxval), minval) + + +def make_range_string(start, stop=None): + """Create a byte range specifier in accordance with RFC-2616. + + Parameters + ---------- + start: int + The start of the byte range + + stop: int, optional + The end of the byte range. If unspecified, indicates EOF. + + Returns + ------- + str + A byte range specifier. + + """ + # + # https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 + # + if stop is None: + return 'bytes=%d-' % start + return 'bytes=%d-%d' % (start, stop) diff --git a/smart_open/webhdfs.py b/smart_open/webhdfs.py index d24d3554..d9785ff0 100644 --- a/smart_open/webhdfs.py +++ b/smart_open/webhdfs.py @@ -14,22 +14,35 @@ import io import logging +import urllib.parse import requests -import six -from six.moves.urllib import parse as urlparse -if six.PY2: - import httplib -else: - import http.client as httplib +from smart_open import utils, constants + +import http.client as httplib logger = logging.getLogger(__name__) -WEBHDFS_MIN_PART_SIZE = 50 * 1024**2 # minimum part size for HDFS multipart uploads +SCHEME = 'webhdfs' + +URI_EXAMPLES = ( + 'webhdfs://host:port/path/file', +) + +MIN_PART_SIZE = 50 * 1024**2 # minimum part size for HDFS multipart uploads + +def parse_uri(uri_as_str): + return dict(scheme=SCHEME, uri=uri_as_str) -def open(http_uri, mode, min_part_size=WEBHDFS_MIN_PART_SIZE): + +def open_uri(uri, mode, transport_params): + kwargs = utils.check_kwargs(open, transport_params) + return open(uri, mode, **kwargs) + + +def open(http_uri, mode, min_part_size=MIN_PART_SIZE): """ Parameters ---------- @@ -39,37 +52,51 @@ def open(http_uri, mode, min_part_size=WEBHDFS_MIN_PART_SIZE): For writing only. 
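As an aside, the `open_uri` shim above relies on the new `utils.check_kwargs` to drop transport parameters that the target `open` does not accept, logging a warning instead of raising. A sketch of that behavior, with a made-up callable:

```python
from smart_open import utils


def connect(host, port=8020, timeout=None):
    """Made-up callable: only keyword arguments with defaults count as supported."""
    return host, port, timeout


params = {'port': 50070, 'compression': 'snappy'}  # 'compression' is unsupported
assert utils.check_kwargs(connect, params) == {'port': 50070}
# a warning is logged for the ignored 'compression' keyword
```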
""" - if mode == 'rb': - return BufferedInputBase(http_uri) - elif mode == 'wb': - return BufferedOutputBase(http_uri, min_part_size=min_part_size) + if http_uri.startswith(SCHEME): + http_uri = _convert_to_http_uri(http_uri) + + if mode == constants.READ_BINARY: + fobj = BufferedInputBase(http_uri) + elif mode == constants.WRITE_BINARY: + fobj = BufferedOutputBase(http_uri, min_part_size=min_part_size) else: raise NotImplementedError("webhdfs support for mode %r not implemented" % mode) + fobj.name = http_uri.split('/')[-1] + return fobj -def convert_to_http_uri(parsed_uri): + +def _convert_to_http_uri(webhdfs_url): """ Convert webhdfs uri to http url and return it as text Parameters ---------- - parsed_uri: str - result of urlsplit of webhdfs url + webhdfs_url: str + A URL starting with webhdfs:// """ - netloc = parsed_uri.hostname - if parsed_uri.port: - netloc += ":{}".format(parsed_uri.port) - query = parsed_uri.query - if parsed_uri.username: + split_uri = urllib.parse.urlsplit(webhdfs_url) + netloc = split_uri.hostname + if split_uri.port: + netloc += ":{}".format(split_uri.port) + query = split_uri.query + if split_uri.username: query += ( - ("&" if query else "") + "user.name=" + urlparse.quote(parsed_uri.username) + ("&" if query else "") + "user.name=" + urllib.parse.quote(split_uri.username) ) - return urlparse.urlunsplit( - ("http", netloc, "/webhdfs/v1" + parsed_uri.path, query, "") + return urllib.parse.urlunsplit( + ("http", netloc, "/webhdfs/v1" + split_uri.path, query, "") ) +# +# For old unit tests. +# +def convert_to_http_uri(parsed_uri): + return _convert_to_http_uri(parsed_uri.uri) + + class BufferedInputBase(io.BufferedIOBase): def __init__(self, uri): self._uri = uri @@ -140,7 +167,7 @@ def readline(self): class BufferedOutputBase(io.BufferedIOBase): - def __init__(self, uri, min_part_size=WEBHDFS_MIN_PART_SIZE): + def __init__(self, uri, min_part_size=MIN_PART_SIZE): """ Parameters ---------- @@ -202,7 +229,7 @@ def write(self, b): if self._closed: raise ValueError("I/O operation on closed file") - if not isinstance(b, six.binary_type): + if not isinstance(b, bytes): raise TypeError("input must be a binary string") self.lines.append(b) diff --git a/tox.ini b/tox.ini index 2b51b1b7..a9cf0a23 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,6 @@ [tox] minversion = 2.0 -envlist = py{27,35,36,37}-{test,doctest,integration,benchmark}, sdist, flake8 +envlist = py{35,36,37}-{test,doctest,integration,benchmark}, sdist, flake8 [pytest] addopts = -rfxEXs --durations=20 --showlocals --reruns 3 --reruns-delay 1